better 更好的记忆抽取策略,并且移除了无用选项

This commit is contained in:
SengokuCola
2025-03-21 14:37:19 +08:00
parent e0a7bf7e99
commit 6c3afa84c4
14 changed files with 547 additions and 1282 deletions

View File

@@ -18,6 +18,7 @@ from ..chat.utils import (
)
from ..models.utils_model import LLM_request
from src.common.logger import get_module_logger, LogConfig, MEMORY_STYLE_CONFIG
from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler
# 定义日志配置
memory_config = LogConfig(
@@ -195,19 +196,9 @@ class Hippocampus:
return hash(f"{nodes[0]}:{nodes[1]}")
def random_get_msg_snippet(self, target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list:
"""随机抽取一段时间内的消息片段
Args:
- target_timestamp: 目标时间戳
- chat_size: 抽取的消息数量
- max_memorized_time_per_msg: 每条消息的最大记忆次数
Returns:
- list: 抽取出的消息记录列表
"""
try_count = 0
# 最多尝试次抽取
while try_count < 3:
# 最多尝试2次抽取
while try_count < 2:
messages = get_closest_chat_from_db(length=chat_size, timestamp=target_timestamp)
if messages:
# 检查messages是否均没有达到记忆次数限制
@@ -224,54 +215,37 @@ class Hippocampus:
)
return messages
try_count += 1
# 三次尝试均失败
return None
def get_memory_sample(self, chat_size=20, time_frequency=None):
"""获取记忆样本
Returns:
list: 消息记录列表,每个元素是一个消息记录字典列表
"""
def get_memory_sample(self):
# 硬编码:每条消息最大记忆次数
# 如有需求可写入global_config
if time_frequency is None:
time_frequency = {"near": 2, "mid": 4, "far": 3}
max_memorized_time_per_msg = 3
current_timestamp = datetime.datetime.now().timestamp()
# 创建双峰分布的记忆调度器
scheduler = MemoryBuildScheduler(
n_hours1=global_config.memory_build_distribution[0], # 第一个分布均值4小时前
std_hours1=global_config.memory_build_distribution[1], # 第一个分布标准差
weight1=global_config.memory_build_distribution[2], # 第一个分布权重 60%
n_hours2=global_config.memory_build_distribution[3], # 第二个分布均值24小时前
std_hours2=global_config.memory_build_distribution[4], # 第二个分布标准差
weight2=global_config.memory_build_distribution[5], # 第二个分布权重 40%
total_samples=global_config.build_memory_sample_num # 总共生成10个时间点
)
# 生成时间戳数组
timestamps = scheduler.get_timestamp_array()
logger.debug(f"生成的时间戳数组: {timestamps}")
chat_samples = []
# 短期1h 中期4h 长期24h
logger.debug("正在抽取短期消息样本")
for i in range(time_frequency.get("near")):
random_time = current_timestamp - random.randint(1, 3600)
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
for timestamp in timestamps:
messages = self.random_get_msg_snippet(timestamp, global_config.build_memory_sample_length, max_memorized_time_per_msg)
if messages:
logger.debug(f"成功抽取短期消息样本{len(messages)}")
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
logger.debug(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}")
chat_samples.append(messages)
else:
logger.warning(f"{i}次短期消息样本抽取失败")
logger.debug("正在抽取中期消息样本")
for i in range(time_frequency.get("mid")):
random_time = current_timestamp - random.randint(3600, 3600 * 4)
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
if messages:
logger.debug(f"成功抽取中期消息样本{len(messages)}")
chat_samples.append(messages)
else:
logger.warning(f"{i}次中期消息样本抽取失败")
logger.debug("正在抽取长期消息样本")
for i in range(time_frequency.get("far")):
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
if messages:
logger.debug(f"成功抽取长期消息样本{len(messages)}")
chat_samples.append(messages)
else:
logger.warning(f"{i}次长期消息样本抽取失败")
logger.warning(f"时间戳 {timestamp}消息样本抽取失败")
return chat_samples
@@ -372,9 +346,8 @@ class Hippocampus:
)
return topic_num
async def operation_build_memory(self, chat_size=20):
time_frequency = {"near": 1, "mid": 4, "far": 4}
memory_samples = self.get_memory_sample(chat_size, time_frequency)
async def operation_build_memory(self):
memory_samples = self.get_memory_sample()
for i, messages in enumerate(memory_samples, 1):
all_topics = []

View File

@@ -7,11 +7,9 @@ import sys
import time
from collections import Counter
from pathlib import Path
import matplotlib.pyplot as plt
import networkx as nx
from dotenv import load_dotenv
from src.common.logger import get_module_logger
import jieba
# from chat.config import global_config
@@ -19,6 +17,7 @@ import jieba
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.logger import get_module_logger
from src.common.database import db # noqa E402
from src.plugins.memory_system.offline_llm import LLMModel # noqa E402

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,172 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import time
from datetime import datetime, timedelta
class DistributionVisualizer:
def __init__(self, mean=0, std=1, skewness=0, sample_size=10):
"""
初始化分布可视化器
参数:
mean (float): 期望均值
std (float): 标准差
skewness (float): 偏度
sample_size (int): 样本大小
"""
self.mean = mean
self.std = std
self.skewness = skewness
self.sample_size = sample_size
self.samples = None
def generate_samples(self):
"""生成具有指定参数的样本"""
if self.skewness == 0:
# 对于无偏度的情况,直接使用正态分布
self.samples = np.random.normal(loc=self.mean, scale=self.std, size=self.sample_size)
else:
# 使用 scipy.stats 生成具有偏度的分布
self.samples = stats.skewnorm.rvs(a=self.skewness,
loc=self.mean,
scale=self.std,
size=self.sample_size)
def get_weighted_samples(self):
"""获取加权后的样本数列"""
if self.samples is None:
self.generate_samples()
# 将样本值乘以样本大小
return self.samples * self.sample_size
def get_statistics(self):
"""获取分布的统计信息"""
if self.samples is None:
self.generate_samples()
return {
"均值": np.mean(self.samples),
"标准差": np.std(self.samples),
"实际偏度": stats.skew(self.samples)
}
class MemoryBuildScheduler:
def __init__(self,
n_hours1, std_hours1, weight1,
n_hours2, std_hours2, weight2,
total_samples=50):
"""
初始化记忆构建调度器
参数:
n_hours1 (float): 第一个分布的均值(距离现在的小时数)
std_hours1 (float): 第一个分布的标准差(小时)
weight1 (float): 第一个分布的权重
n_hours2 (float): 第二个分布的均值(距离现在的小时数)
std_hours2 (float): 第二个分布的标准差(小时)
weight2 (float): 第二个分布的权重
total_samples (int): 要生成的总时间点数量
"""
# 归一化权重
total_weight = weight1 + weight2
self.weight1 = weight1 / total_weight
self.weight2 = weight2 / total_weight
self.n_hours1 = n_hours1
self.std_hours1 = std_hours1
self.n_hours2 = n_hours2
self.std_hours2 = std_hours2
self.total_samples = total_samples
self.base_time = datetime.now()
def generate_time_samples(self):
"""生成混合分布的时间采样点"""
# 根据权重计算每个分布的样本数
samples1 = int(self.total_samples * self.weight1)
samples2 = self.total_samples - samples1
# 生成两个正态分布的小时偏移
hours_offset1 = np.random.normal(
loc=self.n_hours1,
scale=self.std_hours1,
size=samples1
)
hours_offset2 = np.random.normal(
loc=self.n_hours2,
scale=self.std_hours2,
size=samples2
)
# 合并两个分布的偏移
hours_offset = np.concatenate([hours_offset1, hours_offset2])
# 将偏移转换为实际时间戳(使用绝对值确保时间点在过去)
timestamps = [self.base_time - timedelta(hours=abs(offset)) for offset in hours_offset]
# 按时间排序(从最早到最近)
return sorted(timestamps)
def get_timestamp_array(self):
"""返回时间戳数组"""
timestamps = self.generate_time_samples()
return [int(t.timestamp()) for t in timestamps]
def print_time_samples(timestamps, show_distribution=True):
"""打印时间样本和分布信息"""
print(f"\n生成的{len(timestamps)}个时间点分布:")
print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
print("-" * 50)
now = datetime.now()
time_diffs = []
for i, timestamp in enumerate(timestamps, 1):
hours_diff = (now - timestamp).total_seconds() / 3600
time_diffs.append(hours_diff)
print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
# 打印统计信息
print("\n统计信息:")
print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
print(f"标准差:{np.std(time_diffs):.2f}小时")
print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
if show_distribution:
# 计算时间分布的直方图
hist, bins = np.histogram(time_diffs, bins=40)
print("\n时间分布(每个*代表一个时间点):")
for i in range(len(hist)):
if hist[i] > 0:
print(f"{bins[i]:6.1f}-{bins[i+1]:6.1f}小时: {'*' * int(hist[i])}")
# 使用示例
if __name__ == "__main__":
# 创建一个双峰分布的记忆调度器
scheduler = MemoryBuildScheduler(
n_hours1=12, # 第一个分布均值12小时前
std_hours1=8, # 第一个分布标准差
weight1=0.7, # 第一个分布权重 70%
n_hours2=36, # 第二个分布均值36小时前
std_hours2=24, # 第二个分布标准差
weight2=0.3, # 第二个分布权重 30%
total_samples=50 # 总共生成50个时间点
)
# 生成时间分布
timestamps = scheduler.generate_time_samples()
# 打印结果,包含分布可视化
print_time_samples(timestamps, show_distribution=True)
# 打印时间戳数组
timestamp_array = scheduler.get_timestamp_array()
print("\n时间戳数组Unix时间戳")
print("[", end="")
for i, ts in enumerate(timestamp_array):
if i > 0:
print(", ", end="")
print(ts, end="")
print("]")