Commit 950b086063 ("ruff")
Parent: e65ab14f94
Committed by: Windpicker-owo
@@ -29,20 +29,21 @@ logger = get_logger(__name__)

@dataclass
class HippocampusSampleConfig:
    """Hippocampus sampling configuration"""

    # Bimodal distribution parameters
    recent_mean_hours: float = 12.0  # Mean of the recent distribution (hours)
    recent_std_hours: float = 8.0  # Standard deviation of the recent distribution (hours)
    recent_weight: float = 0.7  # Weight of the recent distribution

    distant_mean_hours: float = 48.0  # Mean of the distant distribution (hours)
    distant_std_hours: float = 24.0  # Standard deviation of the distant distribution (hours)
    distant_weight: float = 0.3  # Weight of the distant distribution

    # Sampling parameters
    total_samples: int = 50  # Total number of samples
    sample_interval: int = 1800  # Sampling interval (seconds)
    max_sample_length: int = 30  # Maximum number of messages per sample
    batch_size: int = 5  # Batch size

    @classmethod
    def from_global_config(cls) -> "HippocampusSampleConfig":
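The two weights above split the sampling budget between a "recent" and a "distant" normal distribution. A minimal sketch of how such a split could be computed directly from total_samples; the helper name split_sample_counts is hypothetical, the sampler derives its own counts internally:

def split_sample_counts(total_samples: int, recent_weight: float, distant_weight: float) -> tuple[int, int]:
    """Split the sampling budget between the recent and the distant distribution."""
    total_weight = recent_weight + distant_weight
    recent = round(total_samples * recent_weight / total_weight)
    return recent, total_samples - recent

# With the defaults above (total_samples=50, weights 0.7 / 0.3) this yields (35, 15).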
@@ -84,12 +85,10 @@ class HippocampusSampler:
        try:
            # Initialize the LLM model
            from src.config.config import model_config

            task_config = getattr(model_config.model_task_config, "utils", None)
            if task_config:
-                self.memory_builder_model = LLMRequest(
-                    model_set=task_config,
-                    request_type="memory.hippocampus_build"
-                )
+                self.memory_builder_model = LLMRequest(model_set=task_config, request_type="memory.hippocampus_build")
                asyncio.create_task(self.start_background_sampling())
                logger.info("✅ Hippocampus sampler initialized successfully")
            else:
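Initialization schedules sampling as a fire-and-forget asyncio task. A minimal sketch of that pattern, assuming a simple periodic loop driven by sample_interval; the loop body and class name are illustrative, not the project's actual scheduling logic:

import asyncio

class PeriodicSampler:
    def __init__(self, sample_interval: int = 1800) -> None:
        self.sample_interval = sample_interval
        self._task: asyncio.Task | None = None

    async def initialize(self) -> None:
        # Keep a reference so the background task is not garbage-collected early.
        self._task = asyncio.create_task(self.start_background_sampling())

    async def start_background_sampling(self) -> None:
        while True:
            await self.run_sampling_cycle()
            await asyncio.sleep(self.sample_interval)

    async def run_sampling_cycle(self) -> None:
        ...  # collect messages around sampled timestamps and build memories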
@@ -107,14 +106,10 @@ class HippocampusSampler:

        # Generate hour offsets from the two normal distributions
        recent_offsets = np.random.normal(
-            loc=self.config.recent_mean_hours,
-            scale=self.config.recent_std_hours,
-            size=recent_samples
+            loc=self.config.recent_mean_hours, scale=self.config.recent_std_hours, size=recent_samples
        )
        distant_offsets = np.random.normal(
-            loc=self.config.distant_mean_hours,
-            scale=self.config.distant_std_hours,
-            size=distant_samples
+            loc=self.config.distant_mean_hours, scale=self.config.distant_std_hours, size=distant_samples
        )

        # Merge the offsets from the two distributions
@@ -122,10 +117,7 @@ class HippocampusSampler:

        # Convert to timestamps (use absolute values to keep every point in the past)
        base_time = datetime.now()
-        timestamps = [
-            base_time - timedelta(hours=abs(offset))
-            for offset in all_offsets
-        ]
+        timestamps = [base_time - timedelta(hours=abs(offset)) for offset in all_offsets]

        # Sort chronologically (earliest to latest)
        return sorted(timestamps)
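Taken together, the two hunks above draw hour offsets from a bimodal (recent + distant) normal mixture and map each offset to a timestamp in the past. A self-contained sketch of that idea, using standalone parameters in place of the sampler's config object:

from datetime import datetime, timedelta

import numpy as np

def sample_past_timestamps(
    recent_n: int = 35,
    distant_n: int = 15,
    recent_mean: float = 12.0,
    recent_std: float = 8.0,
    distant_mean: float = 48.0,
    distant_std: float = 24.0,
) -> list[datetime]:
    recent = np.random.normal(loc=recent_mean, scale=recent_std, size=recent_n)
    distant = np.random.normal(loc=distant_mean, scale=distant_std, size=distant_n)
    base = datetime.now()
    # abs() keeps every sampled point in the past; sort earliest-first.
    return sorted(base - timedelta(hours=abs(float(h))) for h in np.concatenate([recent, distant]))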
@@ -171,7 +163,8 @@ class HippocampusSampler:
            if messages and len(messages) >= 2:  # At least 2 messages are required
                # Filter out messages that have already been memorized
                filtered_messages = [
-                    msg for msg in messages
+                    msg
+                    for msg in messages
                    if msg.get("memorized_times", 0) < 2  # Memorize each message at most twice
                ]

@@ -229,7 +222,7 @@ class HippocampusSampler:
                conversation_text=input_text,
                context=context,
                timestamp=time.time(),
-                bypass_interval=True  # The hippocampus sampler bypasses the build-interval limit
+                bypass_interval=True,  # The hippocampus sampler bypasses the build-interval limit
            )

            if memories:
@@ -367,7 +360,7 @@ class HippocampusSampler:
        max_concurrent = min(5, len(time_samples))  # Raise the concurrency limit to 5

        for i in range(0, len(time_samples), max_concurrent):
-            batch = time_samples[i:i + max_concurrent]
+            batch = time_samples[i : i + max_concurrent]
            tasks = []

            # Create concurrent collection tasks
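This loop walks the sampled timestamps in slices of at most max_concurrent and collects each slice concurrently. A minimal sketch of the pattern with asyncio.gather; the per-timestamp coroutine collect_messages_at is a hypothetical name used only for illustration:

import asyncio
from datetime import datetime
from typing import Any

async def collect_messages_at(timestamp: datetime) -> list[dict[str, Any]]:
    ...  # query the message history around the sampled timestamp
    return []

async def collect_in_batches(time_samples: list[datetime], max_concurrent: int = 5) -> list[list[dict[str, Any]]]:
    collected: list[list[dict[str, Any]]] = []
    for i in range(0, len(time_samples), max_concurrent):
        batch = time_samples[i : i + max_concurrent]
        # One collection coroutine per timestamp in this slice, awaited together.
        results = await asyncio.gather(*(collect_messages_at(ts) for ts in batch))
        collected.extend(r for r in results if r)
    return collected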
@@ -392,7 +385,9 @@ class HippocampusSampler:

        return collected_messages

-    async def _fuse_and_deduplicate_messages(self, collected_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
+    async def _fuse_and_deduplicate_messages(
+        self, collected_messages: list[list[dict[str, Any]]]
+    ) -> list[list[dict[str, Any]]]:
        """Fuse and deduplicate message samples"""
        if not collected_messages:
            return []
@@ -416,7 +411,7 @@ class HippocampusSampler:
                chat_id = message.get("chat_id", "")

                # Simple hash: first 50 characters of content + timestamp (minute precision) + chat ID
-                hash_key = f"{content[:50]}_{int(timestamp//60)}_{chat_id}"
+                hash_key = f"{content[:50]}_{int(timestamp // 60)}_{chat_id}"

                if hash_key not in seen_hashes and len(content.strip()) > 10:
                    seen_hashes.add(hash_key)
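The dedup key combines truncated content, the timestamp rounded down to the minute, and the chat ID, so near-duplicates within the same minute of the same chat collapse to one entry. A small standalone sketch of the same idea; the "content" and "time" field names are assumptions, only "chat_id" appears in the hunk:

from typing import Any

def deduplicate_messages(messages: list[dict[str, Any]], min_length: int = 10) -> list[dict[str, Any]]:
    seen: set[str] = set()
    unique: list[dict[str, Any]] = []
    for msg in messages:
        content = msg.get("content", "")
        key = f"{content[:50]}_{int(msg.get('time', 0) // 60)}_{msg.get('chat_id', '')}"
        # Drop very short messages and anything already seen in the same minute and chat.
        if len(content.strip()) > min_length and key not in seen:
            seen.add(key)
            unique.append(msg)
    return unique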
@@ -448,7 +443,9 @@ class HippocampusSampler:
            # Fall back to the original message groups
            return collected_messages[:5]  # Limit the number returned

-    def _merge_adjacent_messages(self, messages: list[dict[str, Any]], time_gap: int = 1800) -> list[list[dict[str, Any]]]:
+    def _merge_adjacent_messages(
+        self, messages: list[dict[str, Any]], time_gap: int = 1800
+    ) -> list[list[dict[str, Any]]]:
        """Merge messages that fall within the time gap"""
        if not messages:
            return []
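_merge_adjacent_messages groups messages whose timestamps fall within time_gap seconds of each other; the body is not shown in this hunk. A minimal sketch of one plausible grouping strategy, assuming each message carries a numeric "time" field (an assumption, not confirmed by the diff):

from typing import Any

def merge_adjacent_messages(messages: list[dict[str, Any]], time_gap: int = 1800) -> list[list[dict[str, Any]]]:
    if not messages:
        return []
    ordered = sorted(messages, key=lambda m: m.get("time", 0))
    groups: list[list[dict[str, Any]]] = [[ordered[0]]]
    for msg in ordered[1:]:
        # Start a new group when the gap to the previous message exceeds time_gap seconds.
        if msg.get("time", 0) - groups[-1][-1].get("time", 0) <= time_gap:
            groups[-1].append(msg)
        else:
            groups.append([msg])
    return groups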
@@ -479,7 +476,9 @@ class HippocampusSampler:

        return result_groups

-    async def _build_batch_memory(self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]) -> dict[str, Any]:
+    async def _build_batch_memory(
+        self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]
+    ) -> dict[str, Any]:
        """Build memories in batch"""
        if not fused_messages:
            return {"memory_count": 0, "memories": []}
@@ -513,10 +512,7 @@ class HippocampusSampler:

            # Build the memories in a single pass
            memories = await self.memory_system.build_memory_from_conversation(
-                conversation_text=batch_input_text,
-                context=batch_context,
-                timestamp=time.time(),
-                bypass_interval=True
+                conversation_text=batch_input_text, context=batch_context, timestamp=time.time(), bypass_interval=True
            )

            if memories:
@@ -545,11 +541,7 @@ class HippocampusSampler:
            if len(self.last_sample_results) > 10:
                self.last_sample_results.pop(0)

-            return {
-                "memory_count": total_memory_count,
-                "memories": total_memories,
-                "result": result
-            }
+            return {"memory_count": total_memory_count, "memories": total_memories, "result": result}

        except Exception as e:
            logger.error(f"Batch memory construction failed: {e}")
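The two lines at the top of this hunk cap last_sample_results at the ten most recent entries by popping from the front of a list. An equivalent sketch using collections.deque with maxlen, shown only as a design alternative and not as what the commit does:

from collections import deque
from typing import Any

# maxlen=10 discards the oldest entry automatically on append, replacing the
# manual "if len(...) > 10: pop(0)" bookkeeping shown in the hunk above.
last_sample_results: deque[dict[str, Any]] = deque(maxlen=10)
last_sample_results.append({"memory_count": 3, "memories": []})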
@@ -601,11 +593,7 @@ class HippocampusSampler:
            except Exception as e:
                logger.debug(f"Individual construction failed: {e}")

-        return {
-            "memory_count": total_count,
-            "memories": total_memories,
-            "fallback_mode": True
-        }
+        return {"memory_count": total_count, "memories": total_memories, "fallback_mode": True}

    async def process_sample_timestamp(self, target_timestamp: float) -> str | None:
        """Process sampling for a single timestamp (kept as a fallback method)"""
@@ -696,7 +684,9 @@ class HippocampusSampler:
            "performance_metrics": {
                "avg_messages_per_sample": f"{recent_avg_messages:.1f}",
                "avg_memories_per_sample": f"{recent_avg_memory_count:.1f}",
-                "fusion_efficiency": f"{(recent_avg_messages/max(recent_avg_memory_count, 1)):.1f}x" if recent_avg_messages > 0 else "N/A"
+                "fusion_efficiency": f"{(recent_avg_messages / max(recent_avg_memory_count, 1)):.1f}x"
+                if recent_avg_messages > 0
+                else "N/A",
            },
            "config": {
                "sample_interval": self.config.sample_interval,
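fusion_efficiency reports how many collected messages it takes, on average, to produce one memory. A standalone sketch of the same metric; the function name is hypothetical:

def fusion_efficiency(avg_messages_per_sample: float, avg_memories_per_sample: float) -> str:
    # Guard the zero-message case, and avoid dividing by zero memories.
    if avg_messages_per_sample <= 0:
        return "N/A"
    return f"{avg_messages_per_sample / max(avg_memories_per_sample, 1):.1f}x"

# Example: 24 messages and 3 memories per sample on average -> "8.0x".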
@@ -15,6 +15,7 @@

    Returns: a large block of text suitable for embedding directly into a prompt; an empty string if there are no valid memories.
    """

from __future__ import annotations

import time
@@ -24,9 +24,12 @@ from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner

# Memory sampling mode enum
class MemorySamplingMode(Enum):
    """Memory sampling modes"""

    HIPPOCAMPUS = "hippocampus"  # Hippocampus mode: sampling driven by a scheduled task
    IMMEDIATE = "immediate"  # Immediate mode: sample right after a reply
    ALL = "all"  # All modes: use hippocampus and immediate sampling together


from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
@@ -165,7 +168,6 @@ class MemorySystem:
    async def initialize(self):
        """Asynchronously initialize the memory system"""
        try:
            # Initialize the LLM model
            fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None

@@ -264,6 +266,7 @@ class MemorySystem:
        if global_config.memory.enable_hippocampus_sampling:
            try:
                from .hippocampus_sampler import initialize_hippocampus_sampler

                self.hippocampus_sampler = await initialize_hippocampus_sampler(self)
                logger.info("✅ Hippocampus sampler initialized successfully")
            except Exception as e:
@@ -321,7 +324,11 @@ class MemorySystem:
        return []

    async def build_memory_from_conversation(
-        self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None, bypass_interval: bool = False
+        self,
+        conversation_text: str,
+        context: dict[str, Any],
+        timestamp: float | None = None,
+        bypass_interval: bool = False,
    ) -> list[MemoryChunk]:
        """Build memories from a conversation
@@ -560,7 +567,6 @@ class MemorySystem:
        sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "precision")
        current_mode = MemorySamplingMode(sampling_mode)

        context["__sampling_mode"] = current_mode.value
        logger.debug(f"Using memory sampling mode: {current_mode.value}")
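The mode string read from global_config is converted directly into the MemorySamplingMode enum; any string that does not match a member (and the getattr default "precision" is not among the members shown in the earlier hunk, unless more are defined elsewhere) would make MemorySamplingMode(sampling_mode) raise ValueError. A defensive sketch of the conversion, purely illustrative and not part of the commit:

from enum import Enum

class MemorySamplingMode(Enum):
    HIPPOCAMPUS = "hippocampus"
    IMMEDIATE = "immediate"
    ALL = "all"

def resolve_sampling_mode(raw: str, default: MemorySamplingMode = MemorySamplingMode.HIPPOCAMPUS) -> MemorySamplingMode:
    try:
        return MemorySamplingMode(raw)
    except ValueError:
        # Unknown strings (for example a stale config value) fall back to the default.
        return default

# resolve_sampling_mode("immediate") -> MemorySamplingMode.IMMEDIATE
# resolve_sampling_mode("precision") -> falls back to HIPPOCAMPUS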