style: format code

John Richard
2025-10-02 19:38:39 +08:00
committed by Windpicker-owo
parent e7aaafde2f
commit 00ba07e0e1
111 changed files with 2343 additions and 2316 deletions

View File

@@ -31,9 +31,11 @@ class AntiInjectionStatistics:
try:
async with get_db_session() as session:
# Fetch the latest stats record; create one if none exists
stats = (await session.execute(
select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())
)).scalars().first()
stats = (
(await session.execute(select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())))
.scalars()
.first()
)
if not stats:
stats = AntiInjectionStats()
session.add(stats)
@@ -49,9 +51,11 @@ class AntiInjectionStatistics:
"""更新统计数据"""
try:
async with get_db_session() as session:
stats = (await session.execute(
select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())
)).scalars().first()
stats = (
(await session.execute(select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())))
.scalars()
.first()
)
if not stats:
stats = AntiInjectionStats()
session.add(stats)
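Both hunks above reflow the same get-or-create query. A minimal standalone sketch of that pattern, fetch the newest row and create one if the table is empty, assuming SQLAlchemy 2.x async sessions and the AntiInjectionStats model from this file:

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

async def get_or_create_latest_stats(session: AsyncSession) -> AntiInjectionStats:
    # Newest row first; .scalars().first() yields None on an empty table
    result = await session.execute(
        select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())
    )
    stats = result.scalars().first()
    if stats is None:
        stats = AntiInjectionStats()
        session.add(stats)  # persisted when the enclosing session commits
    return stats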

View File

@@ -2,13 +2,13 @@ from typing import Dict, List, Optional, Any
import time
from src.plugin_system.base.base_chatter import BaseChatter
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.plugin_system.base.component_types import ChatType, ComponentType
from src.plugin_system.base.component_types import ChatType
from src.common.logger import get_logger
logger = get_logger("chatter_manager")
class ChatterManager:
def __init__(self, action_manager: ChatterActionManager):
self.action_manager = action_manager
@@ -27,6 +27,7 @@ class ChatterManager:
"""从组件注册表自动注册已注册的chatter组件"""
try:
from src.plugin_system.core.component_registry import component_registry
# Fetch all components of type CHATTER
chatter_components = component_registry.get_enabled_chatter_registry()
for chatter_name, chatter_class in chatter_components.items():
@@ -70,7 +71,7 @@ class ChatterManager:
inactive_streams = []
for stream_id, instance in self.instances.items():
if hasattr(instance, 'get_activity_time'):
if hasattr(instance, "get_activity_time"):
activity_time = instance.get_activity_time()
if (current_time - activity_time) > max_inactive_seconds:
inactive_streams.append(stream_id)
@@ -91,6 +92,7 @@ class ChatterManager:
if not chatter_class:
# If no exact match is found, try a chatter that supports ChatType.ALL
from src.plugin_system.base.component_types import ChatType
all_chatter_class = self.get_chatter_class(ChatType.ALL)
if all_chatter_class:
chatter_class = all_chatter_class
@@ -110,6 +112,7 @@ class ChatterManager:
# Fetch the latest chat_stream from mood_manager and sync it back to StreamContext
try:
from src.mood.mood_manager import mood_manager
mood = mood_manager.get_mood_by_chat_id(stream_id)
if mood and mood.chat_stream:
context.chat_stream = mood.chat_stream
@@ -125,6 +128,7 @@ class ChatterManager:
# After processing completes, clear the stream's unread messages
try:
from src.chat.message_manager.message_manager import message_manager
await message_manager.clear_stream_unread_messages(stream_id)
except Exception as clear_e:
logger.error(f"清除流 {stream_id} 未读消息时发生错误: {clear_e}")
@@ -149,4 +153,4 @@ class ChatterManager:
"streams_processed": 0,
"successful_executions": 0,
"failed_executions": 0,
}
}
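The hasattr quote-style fix above sits inside ChatterManager's inactivity sweep. A minimal sketch of that sweep; the names instances, get_activity_time, and max_inactive_seconds come from the hunk, everything else is assumed:

import time

def cleanup_inactive_streams(self, max_inactive_seconds: float = 3600.0) -> int:
    current_time = time.time()
    inactive_streams = []
    for stream_id, instance in self.instances.items():
        if hasattr(instance, "get_activity_time"):
            if (current_time - instance.get_activity_time()) > max_inactive_seconds:
                inactive_streams.append(stream_id)
    for stream_id in inactive_streams:
        del self.instances[stream_id]  # delete after iteration, not while iterating the dict
    return len(inactive_streams)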

View File

@@ -3,7 +3,6 @@
Emoji send-history tracking module
"""
import os
from typing import List, Dict
from collections import deque
from typing import List, Dict

View File

@@ -204,9 +204,7 @@ class MaiEmoji:
# 2. Delete the database record
try:
async with get_db_session() as session:
result = await session.execute(
select(Emoji).where(Emoji.emoji_hash == self.hash)
)
result = await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash))
will_delete_emoji = result.scalar_one_or_none()
if will_delete_emoji is None:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
@@ -947,10 +945,7 @@ class EmojiManager:
existing_description = None
try:
async with get_db_session() as session:
stmt = select(Images).where(
Images.emoji_hash == image_hash,
Images.type == "emoji"
)
stmt = select(Images).where(Images.emoji_hash == image_hash, Images.type == "emoji")
result = await session.execute(stmt)
existing_image = result.scalar_one_or_none()
if existing_image and existing_image.description:

View File

@@ -12,7 +12,7 @@ from .energy_manager import (
ActivityEnergyCalculator,
RecencyEnergyCalculator,
RelationshipEnergyCalculator,
energy_manager
energy_manager,
)
__all__ = [
@@ -24,5 +24,5 @@ __all__ = [
"ActivityEnergyCalculator",
"RecencyEnergyCalculator",
"RelationshipEnergyCalculator",
"energy_manager"
]
"energy_manager",
]

View File

@@ -17,16 +17,18 @@ logger = get_logger("energy_system")
class EnergyLevel(Enum):
"""能量等级"""
VERY_LOW = 0.1 # 非常低
LOW = 0.3 # 低
NORMAL = 0.5 # 正常
HIGH = 0.7 # 高
VERY_HIGH = 0.9 # 非常高
LOW = 0.3 # 低
NORMAL = 0.5 # 正常
HIGH = 0.7 # 高
VERY_HIGH = 0.9 # 非常高
@dataclass
class EnergyComponent:
"""能量组件"""
name: str
value: float
weight: float = 1.0
@@ -47,6 +49,7 @@ class EnergyComponent:
class EnergyContext(TypedDict):
"""能量计算上下文"""
stream_id: str
messages: List[Any]
user_id: Optional[str]
@@ -54,6 +57,7 @@ class EnergyContext(TypedDict):
class EnergyResult(TypedDict):
"""能量计算结果"""
energy: float
level: EnergyLevel
distribution_interval: float
@@ -114,12 +118,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
"""活跃度能量计算器"""
def __init__(self):
self.action_weights = {
"reply": 0.4,
"react": 0.3,
"mention": 0.2,
"other": 0.1
}
self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1}
def calculate(self, context: Dict[str, Any]) -> float:
"""基于活跃度计算能量"""
@@ -188,7 +187,7 @@ class RecencyEnergyCalculator(EnergyCalculator):
else:
recency_score = 0.1
logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age/3600:.1f}小时)")
logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age / 3600:.1f}小时)")
return recency_score
def get_weight(self) -> float:
@@ -236,11 +235,7 @@ class EnergyManager:
self.cache_ttl: int = 60 # 1-minute cache
# AFC threshold configuration
self.thresholds: Dict[str, float] = {
"high_match": 0.8,
"reply": 0.4,
"non_reply": 0.2
}
self.thresholds: Dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2}
# Statistics
self.stats: Dict[str, Union[int, float, str]] = {
@@ -260,9 +255,13 @@ class EnergyManager:
"""从配置加载AFC阈值"""
try:
if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None:
self.thresholds["high_match"] = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
self.thresholds["high_match"] = getattr(
global_config.affinity_flow, "high_match_interest_threshold", 0.8
)
self.thresholds["reply"] = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
self.thresholds["non_reply"] = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
self.thresholds["non_reply"] = getattr(
global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2
)
# Keep the threshold ordering sensible
self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1)
@@ -306,6 +305,7 @@ class EnergyManager:
# Support both sync and async calculators
if callable(calculator.calculate):
import inspect
if inspect.iscoroutinefunction(calculator.calculate):
score = await calculator.calculate(context)
else:
@@ -347,11 +347,12 @@ class EnergyManager:
calculation_time = time.time() - start_time
total_calculations = self.stats["total_calculations"]
self.stats["average_calculation_time"] = (
(self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time)
/ total_calculations
)
self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time
) / total_calculations
logger.debug(f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)")
logger.debug(
f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)"
)
return final_energy
def _apply_threshold_adjustment(self, energy: float) -> float:
@@ -405,6 +406,7 @@ class EnergyManager:
# Add random jitter to avoid synchronization
import random
jitter = random.uniform(0.8, 1.2)
final_interval = base_interval * jitter
@@ -424,7 +426,8 @@ class EnergyManager:
"""清理过期缓存"""
current_time = time.time()
expired_keys = [
stream_id for stream_id, (_, timestamp) in self.energy_cache.items()
stream_id
for stream_id, (_, timestamp) in self.energy_cache.items()
if current_time - timestamp > self.cache_ttl
]
@@ -479,4 +482,4 @@ class EnergyManager:
# Global energy-manager instance
energy_manager = EnergyManager()
energy_manager = EnergyManager()
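Scattered across the hunks above are the moving parts of the energy pipeline: a weighted sum over per-calculator scores and a jittered distribution interval. A condensed sketch; the calculator protocol (calculate, get_weight) and the 0.8-1.2 jitter come from the diff, while the aggregation details are assumed:

import random

def aggregate_energy(calculators, context) -> float:
    # Weighted average over all registered calculators
    total, weight_sum = 0.0, 0.0
    for calc in calculators:
        weight = calc.get_weight()
        total += calc.calculate(context) * weight
        weight_sum += weight
    return total / weight_sum if weight_sum else 0.0

def distribution_interval(base_interval: float) -> float:
    # Random jitter keeps many chat streams from waking up in lockstep
    return base_interval * random.uniform(0.8, 1.2)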

View File

@@ -352,7 +352,7 @@ class ExpressionLearner:
create_date=current_time, # manually set the creation date
)
session.add(new_expression)
# Cap the maximum count
exprs_result = await session.execute(
select(Expression)
@@ -456,11 +456,10 @@ class ExpressionLearnerManager:
self._ensure_expression_directories()
async def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
await self._auto_migrate_json_to_db()
await self._migrate_old_data_create_date()
if chat_id not in self.expression_learners:
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
return self.expression_learners[chat_id]
@@ -604,7 +603,9 @@ class ExpressionLearnerManager:
try:
async with get_db_session() as session:
# Find all expressions whose create_date is empty
old_expressions_result = await session.execute(select(Expression).where(Expression.create_date.is_(None)))
old_expressions_result = await session.execute(
select(Expression).where(Expression.create_date.is_(None))
)
old_expressions = old_expressions_result.scalars().all()
updated_count = 0

View File

@@ -131,7 +131,9 @@ class BotInterestManager:
self.current_interests = generated_interests
active_count = len(generated_interests.get_active_tags())
logger.info(f"成功生成 {active_count} 个新兴趣标签。")
tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in generated_interests.get_active_tags()]
tags_info = [
f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in generated_interests.get_active_tags()
]
tags_str = "\n".join(tags_info)
logger.info(f"当前兴趣标签:\n{tags_str}")
@@ -639,11 +641,19 @@ class BotInterestManager:
async with get_db_session() as session:
# Query the latest interest-tag configuration
db_interests = (await session.execute(
select(DBBotPersonalityInterests)
.where(DBBotPersonalityInterests.personality_id == personality_id)
.order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
)).scalars().first()
db_interests = (
(
await session.execute(
select(DBBotPersonalityInterests)
.where(DBBotPersonalityInterests.personality_id == personality_id)
.order_by(
DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc()
)
)
)
.scalars()
.first()
)
if db_interests:
logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}")
@@ -728,10 +738,17 @@ class BotInterestManager:
async with get_db_session() as session:
# Check whether a record with the same personality_id already exists
existing_record = (await session.execute(
select(DBBotPersonalityInterests)
.where(DBBotPersonalityInterests.personality_id == interests.personality_id)
)).scalars().first()
existing_record = (
(
await session.execute(
select(DBBotPersonalityInterests).where(
DBBotPersonalityInterests.personality_id == interests.personality_id
)
)
)
.scalars()
.first()
)
if existing_record:
# Update the existing record
@@ -763,10 +780,17 @@ class BotInterestManager:
# Verify that the save succeeded
async with get_db_session() as session:
saved_record = (await session.execute(
select(DBBotPersonalityInterests)
.where(DBBotPersonalityInterests.personality_id == interests.personality_id)
)).scalars().first()
saved_record = (
(
await session.execute(
select(DBBotPersonalityInterests).where(
DBBotPersonalityInterests.personality_id == interests.personality_id
)
)
)
.scalars()
.first()
)
if saved_record:
logger.info(f"✅ 验证成功数据库中存在personality_id为 {interests.personality_id} 的记录")
logger.info(f" 版本: {saved_record.version}")

View File

@@ -161,7 +161,7 @@ class EmbeddingStore:
@staticmethod
def _get_embeddings_batch_threaded(
strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
) -> List[Tuple[str, List[float]]]:
"""使用多线程批量获取嵌入向量

View File

@@ -101,7 +101,7 @@ class QAManager:
async def get_knowledge(self, question: str) -> Optional[Dict[str, Any]]:
"""
Fetch knowledge and return a structured dict
Args:
question: the question asked by the user
@@ -114,30 +114,27 @@ class QAManager:
return None
query_res = processed_result[0]
knowledge_items = []
for res_hash, relevance, *_ in query_res:
if store_item := self.embed_manager.paragraphs_embedding_store.store.get(res_hash):
knowledge_items.append({
"content": store_item.str,
"source": "内部知识库",
"relevance": f"{relevance:.4f}"
})
knowledge_items.append(
{"content": store_item.str, "source": "内部知识库", "relevance": f"{relevance:.4f}"}
)
if not knowledge_items:
return None
# Use the LLM to generate a summary
knowledge_text_for_summary = "\n\n".join([item['content'] for item in knowledge_items[:5]]) # summarize at most the first 5 items
summary_prompt = f"根据以下信息,为问题 '{question}' 生成一个简洁的、不超过50字的摘要\n\n{knowledge_text_for_summary}"
knowledge_text_for_summary = "\n\n".join([item["content"] for item in knowledge_items[:5]]) # summarize at most the first 5 items
summary_prompt = (
f"根据以下信息,为问题 '{question}' 生成一个简洁的、不超过50字的摘要\n\n{knowledge_text_for_summary}"
)
try:
summary, (_, _, _) = await self.qa_model.generate_response_async(summary_prompt)
except Exception as e:
logger.error(f"生成知识摘要失败: {e}")
summary = "无法生成摘要。"
return {
"knowledge_items": knowledge_items,
"summary": summary.strip() if summary else "没有可用的摘要。"
}
return {"knowledge_items": knowledge_items, "summary": summary.strip() if summary else "没有可用的摘要。"}

View File

@@ -12,44 +12,23 @@ from .memory_chunk import (
MemoryType,
ImportanceLevel,
ConfidenceLevel,
create_memory_chunk
create_memory_chunk,
)
# Forgetting engine
from .memory_forgetting_engine import (
MemoryForgettingEngine,
ForgettingConfig,
get_memory_forgetting_engine
)
from .memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig, get_memory_forgetting_engine
# Vector DB storage system
from .vector_memory_storage_v2 import (
VectorMemoryStorage,
VectorStorageConfig,
get_vector_memory_storage
)
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
# Memory core system
from .memory_system import (
MemorySystem,
MemorySystemConfig,
get_memory_system,
initialize_memory_system
)
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
# Memory manager
from .memory_manager import (
MemoryManager,
MemoryResult,
memory_manager
)
from .memory_manager import MemoryManager, MemoryResult, memory_manager
# Activator
from .enhanced_memory_activator import (
MemoryActivator,
memory_activator,
enhanced_memory_activator
)
from .enhanced_memory_activator import MemoryActivator, memory_activator, enhanced_memory_activator
# Compatibility aliases
from .memory_chunk import MemoryChunk as Memory
@@ -64,28 +43,23 @@ __all__ = [
"ImportanceLevel",
"ConfidenceLevel",
"create_memory_chunk",
# Forgetting engine
"MemoryForgettingEngine",
"ForgettingConfig",
"get_memory_forgetting_engine",
# Vector DB storage
"VectorMemoryStorage",
"VectorStorageConfig",
"VectorStorageConfig",
"get_vector_memory_storage",
# Memory system
"MemorySystem",
"MemorySystemConfig",
"get_memory_system",
"initialize_memory_system",
# Memory manager
"MemoryManager",
"MemoryResult",
"MemoryResult",
"memory_manager",
# Activator
"MemoryActivator",
"memory_activator",
@@ -95,4 +69,4 @@ __all__ = [
# Version info
__version__ = "3.0.0"
__author__ = "MoFox Team"
__description__ = "简化记忆系统 - 统一记忆架构与智能遗忘机制"
__description__ = "简化记忆系统 - 统一记忆架构与智能遗忘机制"

View File

@@ -4,15 +4,14 @@
Integrate the enhanced memory system into the existing MoFox Bot architecture
"""
import asyncio
import time
from typing import Dict, List, Optional, Any, Tuple
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from src.common.logger import get_logger
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_formatter import MemoryFormatter, FormatterConfig, format_memories_for_llm
from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@@ -36,6 +35,7 @@ MEMORY_TYPE_LABELS = {
@dataclass
class AdapterConfig:
"""适配器配置"""
enable_enhanced_memory: bool = True
integration_mode: str = "enhanced_only" # replace, enhanced_only
auto_migration: bool = True
@@ -61,7 +61,7 @@ class EnhancedMemoryAdapter:
"hybrid_used": 0,
"memories_created": 0,
"memories_retrieved": 0,
"average_processing_time": 0.0
"average_processing_time": 0.0,
}
async def initialize(self):
@@ -79,14 +79,11 @@ class EnhancedMemoryAdapter:
memory_value_threshold=self.config.memory_value_threshold,
fusion_threshold=self.config.fusion_threshold,
max_retrieval_results=self.config.max_retrieval_results,
enable_learning=True # enable learning
enable_learning=True, # enable learning
)
# Create the integration layer
self.integration_layer = MemoryIntegrationLayer(
llm_model=self.llm_model,
config=integration_config
)
self.integration_layer = MemoryIntegrationLayer(llm_model=self.llm_model, config=integration_config)
# Initialize the integration layer
await self.integration_layer.initialize()
@@ -99,10 +96,7 @@ class EnhancedMemoryAdapter:
# If initialization fails, disable enhanced memory
self.config.enable_enhanced_memory = False
async def process_conversation_memory(
self,
context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
async def process_conversation_memory(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""处理对话记忆,以上下文为唯一输入"""
if not self._initialized or not self.config.enable_enhanced_memory:
return {"success": False, "error": "Enhanced memory not available"}
@@ -152,11 +146,7 @@ class EnhancedMemoryAdapter:
return {"success": False, "error": str(e)}
async def retrieve_relevant_memories(
self,
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None
self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: Optional[int] = None
) -> List[MemoryChunk]:
"""检索相关记忆"""
if not self._initialized or not self.config.enable_enhanced_memory:
@@ -164,9 +154,7 @@ class EnhancedMemoryAdapter:
try:
limit = limit or self.config.max_retrieval_results
memories = await self.integration_layer.retrieve_relevant_memories(
query, None, context, limit
)
memories = await self.integration_layer.retrieve_relevant_memories(query, None, context, limit)
self.adapter_stats["memories_retrieved"] += len(memories)
logger.debug(f"检索到 {len(memories)} 条相关记忆")
@@ -178,11 +166,7 @@ class EnhancedMemoryAdapter:
return []
async def get_memory_context_for_prompt(
self,
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
max_memories: int = 5
self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5
) -> str:
"""获取用于提示词的记忆上下文"""
memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
@@ -197,15 +181,11 @@ class EnhancedMemoryAdapter:
include_confidence=False,
use_emoji_icons=True,
group_by_type=False,
max_display_length=150
)
return format_memories_for_llm(
memories=memories,
query_context=query,
config=formatter_config
max_display_length=150,
)
return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config)
async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
"""获取增强记忆系统摘要"""
if not self._initialized or not self.config.enable_enhanced_memory:
@@ -227,7 +207,7 @@ class EnhancedMemoryAdapter:
"adapter_stats": adapter_stats,
"integration_stats": integration_stats,
"total_memories_created": adapter_stats["memories_created"],
"total_memories_retrieved": adapter_stats["memories_retrieved"]
"total_memories_retrieved": adapter_stats["memories_retrieved"],
}
except Exception as e:
@@ -285,12 +265,12 @@ async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAd
from src.config.config import global_config
adapter_config = AdapterConfig(
enable_enhanced_memory=getattr(global_config.memory, 'enable_enhanced_memory', True),
integration_mode=getattr(global_config.memory, 'enhanced_memory_mode', 'enhanced_only'),
auto_migration=getattr(global_config.memory, 'enable_memory_migration', True),
memory_value_threshold=getattr(global_config.memory, 'memory_value_threshold', 0.6),
fusion_threshold=getattr(global_config.memory, 'fusion_threshold', 0.85),
max_retrieval_results=getattr(global_config.memory, 'max_retrieval_results', 10)
enable_enhanced_memory=getattr(global_config.memory, "enable_enhanced_memory", True),
integration_mode=getattr(global_config.memory, "enhanced_memory_mode", "enhanced_only"),
auto_migration=getattr(global_config.memory, "enable_memory_migration", True),
memory_value_threshold=getattr(global_config.memory, "memory_value_threshold", 0.6),
fusion_threshold=getattr(global_config.memory, "fusion_threshold", 0.85),
max_retrieval_results=getattr(global_config.memory, "max_retrieval_results", 10),
)
_enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config)
@@ -312,13 +292,13 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest):
async def process_conversation_with_enhanced_memory(
context: Dict[str, Any],
llm_model: Optional[LLMRequest] = None
context: Dict[str, Any], llm_model: Optional[LLMRequest] = None
) -> Dict[str, Any]:
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
if not llm_model:
# Get the default LLM model
from src.llm_models.utils_model import get_global_llm_model
llm_model = get_global_llm_model()
try:
@@ -345,12 +325,13 @@ async def retrieve_memories_with_enhanced_system(
user_id: str,
context: Optional[Dict[str, Any]] = None,
limit: int = 10,
llm_model: Optional[LLMRequest] = None
llm_model: Optional[LLMRequest] = None,
) -> List[MemoryChunk]:
"""使用增强记忆系统检索记忆"""
if not llm_model:
# Get the default LLM model
from src.llm_models.utils_model import get_global_llm_model
llm_model = get_global_llm_model()
try:
@@ -366,12 +347,13 @@ async def get_memory_context_for_prompt(
user_id: str,
context: Optional[Dict[str, Any]] = None,
max_memories: int = 5,
llm_model: Optional[LLMRequest] = None
llm_model: Optional[LLMRequest] = None,
) -> str:
"""获取用于提示词的记忆上下文"""
if not llm_model:
# Get the default LLM model
from src.llm_models.utils_model import get_global_llm_model
llm_model = get_global_llm_model()
try:
@@ -379,4 +361,4 @@ async def get_memory_context_for_prompt(
return await adapter.get_memory_context_for_prompt(query, user_id, context, max_memories)
except Exception as e:
logger.error(f"获取记忆上下文失败: {e}", exc_info=True)
return ""
return ""

View File

@@ -4,7 +4,6 @@
Automatically build and retrieve memories during message processing
"""
import asyncio
from typing import Dict, List, Any, Optional
from datetime import datetime
@@ -19,8 +18,7 @@ class EnhancedMemoryHooks:
"""增强记忆系统钩子 - 自动处理消息的记忆构建和检索"""
def __init__(self):
self.enabled = (global_config.memory.enable_memory and
global_config.memory.enable_enhanced_memory)
self.enabled = global_config.memory.enable_memory and global_config.memory.enable_enhanced_memory
self.processed_messages = set() # avoid reprocessing
async def process_message_for_memory(
@@ -29,7 +27,7 @@ class EnhancedMemoryHooks:
user_id: str,
chat_id: str,
message_id: str,
context: Optional[Dict[str, Any]] = None
context: Optional[Dict[str, Any]] = None,
) -> bool:
"""
Process a message and build memories
@@ -76,7 +74,7 @@ class EnhancedMemoryHooks:
"timestamp": datetime.now().timestamp(),
"message_type": "user_message",
**bot_context,
**(context or {})
**(context or {}),
}
# Process the conversation and build memories
@@ -84,7 +82,7 @@ class EnhancedMemoryHooks:
conversation_text=message_content,
context=memory_context,
user_id=user_id,
timestamp=memory_context["timestamp"]
timestamp=memory_context["timestamp"],
)
# Mark the message as processed
@@ -108,7 +106,7 @@ class EnhancedMemoryHooks:
user_id: str,
chat_id: str,
limit: int = 5,
extra_context: Optional[Dict[str, Any]] = None
extra_context: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Fetch relevant memories for a reply
@@ -134,9 +132,7 @@ class EnhancedMemoryHooks:
context = {
"chat_id": chat_id,
"query_intent": "response_generation",
"expected_memory_types": [
"personal_fact", "event", "preference", "opinion"
]
"expected_memory_types": ["personal_fact", "event", "preference", "opinion"],
}
if extra_context:
@@ -144,10 +140,7 @@ class EnhancedMemoryHooks:
# Fetch relevant memories
enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
query_text=query_text,
user_id=user_id,
context=context,
limit=limit
query_text=query_text, user_id=user_id, context=context, limit=limit
)
# Convert to dict format
@@ -199,4 +192,4 @@ class EnhancedMemoryHooks:
# Create the global instance
enhanced_memory_hooks = EnhancedMemoryHooks()
enhanced_memory_hooks = EnhancedMemoryHooks()

View File

@@ -4,7 +4,6 @@
Seamlessly integrate enhanced memory features into the existing system
"""
import asyncio
from typing import Dict, Any, Optional
from src.common.logger import get_logger
@@ -14,11 +13,7 @@ logger = get_logger(__name__)
async def process_user_message_memory(
message_content: str,
user_id: str,
chat_id: str,
message_id: str,
context: Optional[Dict[str, Any]] = None
message_content: str, user_id: str, chat_id: str, message_id: str, context: Optional[Dict[str, Any]] = None
) -> bool:
"""
Process a user message and build memories
@@ -35,11 +30,7 @@ async def process_user_message_memory(
"""
try:
success = await enhanced_memory_hooks.process_message_for_memory(
message_content=message_content,
user_id=user_id,
chat_id=chat_id,
message_id=message_id,
context=context
message_content=message_content, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context
)
if success:
@@ -53,11 +44,7 @@ async def process_user_message_memory(
async def get_relevant_memories_for_response(
query_text: str,
user_id: str,
chat_id: str,
limit: int = 5,
extra_context: Optional[Dict[str, Any]] = None
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Fetch relevant memories for a reply
@@ -74,29 +61,17 @@ async def get_relevant_memories_for_response(
"""
try:
memories = await enhanced_memory_hooks.get_memory_for_response(
query_text=query_text,
user_id=user_id,
chat_id=chat_id,
limit=limit,
extra_context=extra_context
query_text=query_text, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=extra_context
)
result = {
"has_memories": len(memories) > 0,
"memories": memories,
"memory_count": len(memories)
}
result = {"has_memories": len(memories) > 0, "memories": memories, "memory_count": len(memories)}
logger.debug(f"为回复获取到 {len(memories)} 条相关记忆")
return result
except Exception as e:
logger.error(f"获取回复记忆失败: {e}")
return {
"has_memories": False,
"memories": [],
"memory_count": 0
}
return {"has_memories": False, "memories": [], "memory_count": 0}
def format_memories_for_prompt(memories: Dict[str, Any]) -> str:
@@ -152,16 +127,13 @@ def get_memory_system_status() -> Dict[str, Any]:
"enabled": enhanced_memory_hooks.enabled,
"enhanced_system_initialized": enhanced_memory_manager.is_initialized,
"processed_messages_count": len(enhanced_memory_hooks.processed_messages),
"system_type": "enhanced_memory_system"
"system_type": "enhanced_memory_system",
}
# Convenience functions
async def remember_message(
message: str,
user_id: str = "default_user",
chat_id: str = "default_chat",
context: Optional[Dict[str, Any]] = None
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: Optional[Dict[str, Any]] = None
) -> bool:
"""
Convenience function for building memories
@@ -175,13 +147,10 @@ async def remember_message(
bool: whether it succeeded
"""
import uuid
message_id = str(uuid.uuid4())
return await process_user_message_memory(
message_content=message,
user_id=user_id,
chat_id=chat_id,
message_id=message_id,
context=context
message_content=message, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context
)
@@ -190,7 +159,7 @@ async def recall_memories(
user_id: str = "default_user",
chat_id: str = "default_chat",
limit: int = 5,
context: Optional[Dict[str, Any]] = None
context: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
Convenience function for retrieving memories
@@ -205,9 +174,5 @@ async def recall_memories(
Dict: memory info
"""
return await get_relevant_memories_for_response(
query_text=query,
user_id=user_id,
chat_id=chat_id,
limit=limit,
extra_context=context
)
query_text=query, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=context
)

View File

@@ -18,32 +18,34 @@ logger = get_logger(__name__)
class IntentType(Enum):
"""对话意图类型"""
FACT_QUERY = "fact_query" # 事实查询
EVENT_RECALL = "event_recall" # 事件回忆
FACT_QUERY = "fact_query" # 事实查询
EVENT_RECALL = "event_recall" # 事件回忆
PREFERENCE_CHECK = "preference_check" # 偏好检查
GENERAL_CHAT = "general_chat" # 一般对话
UNKNOWN = "unknown" # 未知意图
GENERAL_CHAT = "general_chat" # 一般对话
UNKNOWN = "unknown" # 未知意图
@dataclass
class ReRankingConfig:
"""重排序配置"""
# 权重配置 (w1 + w2 + w3 + w4 = 1.0)
semantic_weight: float = 0.5 # 语义相似度权重
recency_weight: float = 0.2 # 时效性权重
usage_freq_weight: float = 0.2 # 使用频率权重
type_match_weight: float = 0.1 # 类型匹配权重
semantic_weight: float = 0.5 # 语义相似度权重
recency_weight: float = 0.2 # 时效性权重
usage_freq_weight: float = 0.2 # 使用频率权重
type_match_weight: float = 0.1 # 类型匹配权重
# 时效性衰减参数
recency_decay_rate: float = 0.1 # 时效性衰减率 (天)
recency_decay_rate: float = 0.1 # 时效性衰减率 (天)
# 使用频率计算参数
freq_log_base: float = 2.0 # 对数底数
freq_max_score: float = 5.0 # 最大频率得分
freq_log_base: float = 2.0 # 对数底数
freq_max_score: float = 5.0 # 最大频率得分
# 类型匹配权重映射
type_match_weights: Dict[str, Dict[str, float]] = None
def __post_init__(self):
"""初始化类型匹配权重"""
if self.type_match_weights is None:
@@ -53,102 +55,150 @@ class ReRankingConfig:
MemoryType.KNOWLEDGE.value: 0.8,
MemoryType.PREFERENCE.value: 0.5,
MemoryType.EVENT.value: 0.3,
"default": 0.3
"default": 0.3,
},
IntentType.EVENT_RECALL.value: {
MemoryType.EVENT.value: 1.0,
MemoryType.EXPERIENCE.value: 0.8,
MemoryType.EMOTION.value: 0.6,
MemoryType.PERSONAL_FACT.value: 0.5,
"default": 0.5
"default": 0.5,
},
IntentType.PREFERENCE_CHECK.value: {
MemoryType.PREFERENCE.value: 1.0,
MemoryType.OPINION.value: 0.8,
MemoryType.GOAL.value: 0.6,
MemoryType.PERSONAL_FACT.value: 0.4,
"default": 0.4
"default": 0.4,
},
IntentType.GENERAL_CHAT.value: {
"default": 0.8
},
IntentType.UNKNOWN.value: {
"default": 0.8
}
IntentType.GENERAL_CHAT.value: {"default": 0.8},
IntentType.UNKNOWN.value: {"default": 0.8},
}
class IntentClassifier:
"""轻量级意图识别器"""
def __init__(self):
# Keyword pattern-matching rules
self.patterns = {
IntentType.FACT_QUERY: [
# Chinese patterns
"我是", "我的", "我叫", "我在", "我住在", "我的职业", "我的工作",
"什么时候", "在哪里", "是什么", "多少", "几岁", "年龄",
"我是",
"我的",
"我叫",
"我在",
"我住在",
"我的职业",
"我的工作",
"什么时候",
"在哪里",
"是什么",
"多少",
"几岁",
"年龄",
# English patterns
"what is", "where is", "when is", "how old", "my name", "i am", "i live"
"what is",
"where is",
"when is",
"how old",
"my name",
"i am",
"i live",
],
IntentType.EVENT_RECALL: [
# Chinese patterns
"记得", "想起", "还记得", "那次", "上次", "之前", "以前", "曾经",
"发生过", "经历", "做过", "去过", "见过",
"记得",
"想起",
"还记得",
"那次",
"上次",
"之前",
"以前",
"曾经",
"发生过",
"经历",
"做过",
"去过",
"见过",
# English patterns
"remember", "recall", "last time", "before", "previously", "happened", "experience"
"remember",
"recall",
"last time",
"before",
"previously",
"happened",
"experience",
],
IntentType.PREFERENCE_CHECK: [
# Chinese patterns
"喜欢", "不喜欢", "偏好", "爱好", "兴趣", "讨厌", "最爱", "最喜欢",
"习惯", "通常", "一般", "倾向于", "喜欢",
"喜欢",
"喜欢",
"偏好",
"爱好",
"兴趣",
"讨厌",
"最爱",
"最喜欢",
"习惯",
"通常",
"一般",
"倾向于",
"更喜欢",
# English patterns
"like", "love", "hate", "prefer", "favorite", "usually", "tend to", "interest"
]
"like",
"love",
"hate",
"prefer",
"favorite",
"usually",
"tend to",
"interest",
],
}
def classify_intent(self, query: str, context: Dict[str, Any]) -> IntentType:
"""识别对话意图"""
if not query:
return IntentType.UNKNOWN
query_lower = query.lower()
# Tally match scores for each intent
intent_scores = {intent: 0 for intent in IntentType}
for intent, patterns in self.patterns.items():
for pattern in patterns:
if pattern in query_lower:
intent_scores[intent] += 1
# Return the highest-scoring intent
max_score = max(intent_scores.values())
if max_score == 0:
return IntentType.GENERAL_CHAT
for intent, score in intent_scores.items():
if score == max_score:
return intent
return IntentType.GENERAL_CHAT
class EnhancedReRanker:
"""增强重排序器 - 实现文档设计的多维度评分模型"""
def __init__(self, config: Optional[ReRankingConfig] = None):
self.config = config or ReRankingConfig()
self.intent_classifier = IntentClassifier()
# Verify that the weights sum to 1.0
total_weight = (
self.config.semantic_weight +
self.config.recency_weight +
self.config.usage_freq_weight +
self.config.type_match_weight
self.config.semantic_weight
+ self.config.recency_weight
+ self.config.usage_freq_weight
+ self.config.type_match_weight
)
if abs(total_weight - 1.0) > 0.01:
logger.warning(f"重排序权重和不为1.0: {total_weight}, 将进行归一化")
# Normalize the weights
@@ -156,94 +206,94 @@ class EnhancedReRanker:
self.config.recency_weight /= total_weight
self.config.usage_freq_weight /= total_weight
self.config.type_match_weight /= total_weight
def rerank_memories(
self,
query: str,
candidate_memories: List[Tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity)
context: Dict[str, Any],
limit: int = 10
limit: int = 10,
) -> List[Tuple[str, MemoryChunk, float]]:
"""
Re-rank the candidate memories
Args:
query: query text
candidate_memories: candidate memory list [(memory_id, memory, vector_similarity)]
context: context info
limit: maximum number of results to return
Returns:
Re-ranked memory list [(memory_id, memory, final_score)]
"""
if not candidate_memories:
return []
# Classify the query intent
intent = self.intent_classifier.classify_intent(query, context)
logger.debug(f"识别到查询意图: {intent.value}")
# Compute the final score for each candidate memory
scored_memories = []
current_time = time.time()
for memory_id, memory, vector_sim in candidate_memories:
try:
# 1. Semantic similarity score (already normalized to [0,1])
semantic_score = self._normalize_similarity(vector_sim)
# 2. Recency score
recency_score = self._calculate_recency_score(memory, current_time)
# 3. Usage-frequency score
usage_freq_score = self._calculate_usage_frequency_score(memory)
# 4. Type-match score
type_match_score = self._calculate_type_match_score(memory, intent)
# Compute the final score
final_score = (
self.config.semantic_weight * semantic_score +
self.config.recency_weight * recency_score +
self.config.usage_freq_weight * usage_freq_score +
self.config.type_match_weight * type_match_score
self.config.semantic_weight * semantic_score
+ self.config.recency_weight * recency_score
+ self.config.usage_freq_weight * usage_freq_score
+ self.config.type_match_weight * type_match_score
)
scored_memories.append((memory_id, memory, final_score))
# Log debug info
logger.debug(
f"记忆评分 {memory_id[:8]}: semantic={semantic_score:.3f}, "
f"recency={recency_score:.3f}, freq={usage_freq_score:.3f}, "
f"type={type_match_score:.3f}, final={final_score:.3f}"
)
except Exception as e:
logger.error(f"计算记忆 {memory_id} 得分时出错: {e}")
# Fall back to the vector similarity as the score
scored_memories.append((memory_id, memory, vector_sim))
# Sort by final score in descending order
scored_memories.sort(key=lambda x: x[2], reverse=True)
# Return the top-N results
result = scored_memories[:limit]
highest_score = result[0][2] if result else 0.0
logger.info(
f"重排序完成: 候选={len(candidate_memories)}, 返回={len(result)}, "
f"意图={intent.value}, 最高分={highest_score:.3f}"
)
return result
def _normalize_similarity(self, raw_similarity: float) -> float:
"""归一化相似度到[0,1]区间"""
# 假设原始相似度已经在[-1,1]或[0,1]区间
if raw_similarity < 0:
return (raw_similarity + 1) / 2 # 从[-1,1]映射到[0,1]
return min(1.0, max(0.0, raw_similarity)) # 确保在[0,1]区间
def _calculate_recency_score(self, memory: MemoryChunk, current_time: float) -> float:
"""
Compute the recency score
@@ -251,13 +301,13 @@ class EnhancedReRanker:
"""
last_accessed = memory.metadata.last_accessed or memory.metadata.created_at
days_old = (current_time - last_accessed) / (24 * 3600) # convert to days
if days_old < 0:
days_old = 0 # guard against clock anomalies
score = 1 / (1 + self.config.recency_decay_rate * days_old)
return min(1.0, max(0.0, score))
def _calculate_usage_frequency_score(self, memory: MemoryChunk) -> float:
"""
Compute the usage-frequency score
@@ -266,22 +316,22 @@ class EnhancedReRanker:
access_count = memory.metadata.access_count
if access_count <= 0:
return 0.0
log_count = math.log2(access_count + 1)
score = log_count / self.config.freq_max_score
return min(1.0, max(0.0, score))
def _calculate_type_match_score(self, memory: MemoryChunk, intent: IntentType) -> float:
"""计算类型匹配得分"""
memory_type = memory.memory_type.value
intent_value = intent.value
# Get the type-weight map for this intent
type_weights = self.config.type_match_weights.get(intent_value, {})
# Look up the weight for the specific type, falling back to the default
score = type_weights.get(memory_type, type_weights.get("default", 0.8))
return min(1.0, max(0.0, score))
@@ -294,7 +344,7 @@ def rerank_candidate_memories(
candidate_memories: List[Tuple[str, MemoryChunk, float]],
context: Dict[str, Any],
limit: int = 10,
config: Optional[ReRankingConfig] = None
config: Optional[ReRankingConfig] = None,
) -> List[Tuple[str, MemoryChunk, float]]:
"""
Convenience function: re-rank the candidate memories
@@ -303,5 +353,5 @@ def rerank_candidate_memories(
reranker = EnhancedReRanker(config)
else:
reranker = default_reranker
return reranker.rerank_memories(query, candidate_memories, context, limit)
return reranker.rerank_memories(query, candidate_memories, context, limit)
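Read together, the hunks above define the scoring model: final = w_sem * semantic + w_rec * recency + w_freq * freq + w_type * type_match, where recency = 1 / (1 + recency_decay_rate * days_old) and freq = min(1, log2(access_count + 1) / freq_max_score). A worked numeric sketch using the ReRankingConfig defaults shown in the diff:

import math

w_sem, w_rec, w_freq, w_type = 0.5, 0.2, 0.2, 0.1  # defaults from ReRankingConfig

semantic = 0.82                          # normalized vector similarity
recency = 1 / (1 + 0.1 * 3)              # 3 days since last access -> ~0.769
freq = min(1.0, math.log2(7 + 1) / 5.0)  # 7 accesses, freq_max_score=5.0 -> 0.6
type_match = 1.0                         # e.g. an EVENT memory under EVENT_RECALL intent

final = w_sem * semantic + w_rec * recency + w_freq * freq + w_type * type_match
print(f"{final:.3f}")  # ~0.784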

View File

@@ -12,7 +12,7 @@ from enum import Enum
from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@@ -20,13 +20,15 @@ logger = get_logger(__name__)
class IntegrationMode(Enum):
"""集成模式"""
REPLACE = "replace" # 完全替换现有记忆系统
REPLACE = "replace" # 完全替换现有记忆系统
ENHANCED_ONLY = "enhanced_only" # 仅使用增强记忆系统
@dataclass
class IntegrationConfig:
"""集成配置"""
mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY
enable_enhanced_memory: bool = True
memory_value_threshold: float = 0.6
@@ -51,7 +53,7 @@ class MemoryIntegrationLayer:
"enhanced_queries": 0,
"memory_creations": 0,
"average_response_time": 0.0,
"success_rate": 0.0
"success_rate": 0.0,
}
# Initialization lock
@@ -88,6 +90,7 @@ class MemoryIntegrationLayer:
# Create the enhanced memory-system configuration
from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig
memory_config = MemorySystemConfig.from_global_config()
# Override selected values with the integration config
@@ -96,9 +99,7 @@ class MemoryIntegrationLayer:
memory_config.final_recall_limit = self.config.max_retrieval_results
# Create the enhanced memory system
self.enhanced_memory = EnhancedMemorySystem(
config=memory_config
)
self.enhanced_memory = EnhancedMemorySystem(config=memory_config)
# If an LLM model is provided externally, inject it into the system
if self.llm_model is not None:
@@ -112,10 +113,7 @@ class MemoryIntegrationLayer:
logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
raise
async def process_conversation(
self,
context: Dict[str, Any]
) -> Dict[str, Any]:
async def process_conversation(self, context: Dict[str, Any]) -> Dict[str, Any]:
"""处理对话记忆,仅使用上下文信息"""
if not self._initialized or not self.enhanced_memory:
return {"success": False, "error": "Memory system not available"}
@@ -154,7 +152,7 @@ class MemoryIntegrationLayer:
query: str,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None
limit: Optional[int] = None,
) -> List[MemoryChunk]:
"""检索相关记忆"""
if not self._initialized or not self.enhanced_memory:
@@ -163,10 +161,7 @@ class MemoryIntegrationLayer:
try:
limit = limit or self.config.max_retrieval_results
memories = await self.enhanced_memory.retrieve_relevant_memories(
query=query,
user_id=None,
context=context or {},
limit=limit
query=query, user_id=None, context=context or {}, limit=limit
)
memory_count = len(memories)
@@ -191,7 +186,7 @@ class MemoryIntegrationLayer:
"status": "initialized",
"mode": self.config.mode.value,
"enhanced_memory": enhanced_status,
"integration_stats": self.integration_stats.copy()
"integration_stats": self.integration_stats.copy(),
}
except Exception as e:
@@ -248,4 +243,4 @@ class MemoryIntegrationLayer:
logger.info("✅ 记忆系统集成层已关闭")
except Exception as e:
logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True)
logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True)

View File

@@ -4,17 +4,15 @@
Provide seamless integration points with the existing MoFox Bot system
"""
import asyncio
import time
from typing import Dict, List, Optional, Any, Callable
from typing import Dict, Optional, Any
from dataclasses import dataclass
from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_adapter import (
get_enhanced_memory_adapter,
process_conversation_with_enhanced_memory,
retrieve_memories_with_enhanced_system,
get_memory_context_for_prompt
get_memory_context_for_prompt,
)
logger = get_logger(__name__)
@@ -23,6 +21,7 @@ logger = get_logger(__name__)
@dataclass
class HookResult:
"""钩子执行结果"""
success: bool
data: Any = None
error: Optional[str] = None
@@ -39,7 +38,7 @@ class MemoryIntegrationHooks:
"memory_retrieval_hooks": 0,
"prompt_enhancement_hooks": 0,
"total_hook_executions": 0,
"average_hook_time": 0.0
"average_hook_time": 0.0,
}
async def register_hooks(self):
@@ -130,10 +129,7 @@ class MemoryIntegrationHooks:
from src.plugin_system.base.component_types import EventType
# Register the post-message-processing event
event_manager.subscribe(
EventType.MESSAGE_PROCESSED,
self._on_message_processed_handler
)
event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler)
logger.debug("已注册到事件系统的消息处理钩子")
except ImportError:
@@ -144,10 +140,8 @@ class MemoryIntegrationHooks:
from src.chat.message_manager import message_manager
# If the message manager supports hook registration
if hasattr(message_manager, 'register_post_process_hook'):
message_manager.register_post_process_hook(
self._on_message_processed_hook
)
if hasattr(message_manager, "register_post_process_hook"):
message_manager.register_post_process_hook(self._on_message_processed_hook)
logger.debug("已注册到消息管理器的处理钩子")
except ImportError:
@@ -164,10 +158,8 @@ class MemoryIntegrationHooks:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
if hasattr(chat_manager, 'register_save_hook'):
chat_manager.register_save_hook(
self._on_chat_stream_save_hook
)
if hasattr(chat_manager, "register_save_hook"):
chat_manager.register_save_hook(self._on_chat_stream_save_hook)
logger.debug("已注册到聊天流管理器的保存钩子")
except ImportError:
@@ -183,10 +175,8 @@ class MemoryIntegrationHooks:
try:
from src.chat.replyer.default_generator import default_generator
if hasattr(default_generator, 'register_pre_generation_hook'):
default_generator.register_pre_generation_hook(
self._on_pre_response_hook
)
if hasattr(default_generator, "register_pre_generation_hook"):
default_generator.register_pre_generation_hook(self._on_pre_response_hook)
logger.debug("已注册到回复生成器的前置钩子")
except ImportError:
@@ -202,10 +192,8 @@ class MemoryIntegrationHooks:
try:
from src.chat.knowledge.knowledge_lib import knowledge_manager
if hasattr(knowledge_manager, 'register_query_enhancer'):
knowledge_manager.register_query_enhancer(
self._on_knowledge_query_hook
)
if hasattr(knowledge_manager, "register_query_enhancer"):
knowledge_manager.register_query_enhancer(self._on_knowledge_query_hook)
logger.debug("已注册到知识库的查询增强钩子")
except ImportError:
@@ -221,10 +209,8 @@ class MemoryIntegrationHooks:
try:
from src.chat.utils.prompt import prompt_manager
if hasattr(prompt_manager, 'register_enhancer'):
prompt_manager.register_enhancer(
self._on_prompt_building_hook
)
if hasattr(prompt_manager, "register_enhancer"):
prompt_manager.register_enhancer(self._on_prompt_building_hook)
logger.debug("已注册到提示词管理器的增强钩子")
except ImportError:
@@ -278,7 +264,7 @@ class MemoryIntegrationHooks:
"platform": message_info.get("platform", "unknown"),
"interest_value": message_data.get("interest_value", 0.0),
"keywords": message_data.get("key_words", []),
"timestamp": message_data.get("time", time.time())
"timestamp": message_data.get("time", time.time()),
}
# Process the conversation with the enhanced memory system
@@ -296,7 +282,7 @@ class MemoryIntegrationHooks:
return HookResult(success=True, data=result, processing_time=processing_time)
else:
logger.warning(f"消息处理钩子执行失败: {result.get('error')}")
return HookResult(success=False, error=result.get('error'), processing_time=processing_time)
return HookResult(success=False, error=result.get("error"), processing_time=processing_time)
except Exception as e:
processing_time = time.time() - start_time
@@ -334,7 +320,7 @@ class MemoryIntegrationHooks:
"stream_id": chat_stream_data.get("stream_id"),
"platform": chat_stream_data.get("platform", "unknown"),
"message_count": len(messages),
"timestamp": time.time()
"timestamp": time.time(),
}
# Process the conversation with the enhanced memory system
@@ -352,7 +338,7 @@ class MemoryIntegrationHooks:
return HookResult(success=True, data=result, processing_time=processing_time)
else:
logger.warning(f"聊天流保存钩子执行失败: {result.get('error')}")
return HookResult(success=False, error=result.get('error'), processing_time=processing_time)
return HookResult(success=False, error=result.get("error"), processing_time=processing_time)
except Exception as e:
processing_time = time.time() - start_time
@@ -375,9 +361,7 @@ class MemoryIntegrationHooks:
return HookResult(success=True, data="No query provided")
# Retrieve relevant memories
memories = await retrieve_memories_with_enhanced_system(
query, user_id, context, limit=5
)
memories = await retrieve_memories_with_enhanced_system(query, user_id, context, limit=5)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
@@ -411,9 +395,7 @@ class MemoryIntegrationHooks:
return HookResult(success=True, data="No query provided")
# Get the memory context and enhance the query
memory_context = await get_memory_context_for_prompt(
query, user_id, context, max_memories=3
)
memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=3)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
@@ -445,9 +427,7 @@ class MemoryIntegrationHooks:
return HookResult(success=True, data="No query provided")
# Get the memory context
memory_context = await get_memory_context_for_prompt(
query, user_id, context, max_memories=5
)
memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=5)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
@@ -499,6 +479,7 @@ class MemoryMaintenanceTask:
# Get the adapter instance
try:
from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter
if _enhanced_memory_adapter:
await _enhanced_memory_adapter.maintenance()
logger.info("✅ 增强记忆系统维护任务完成")
@@ -543,4 +524,4 @@ async def initialize_memory_integration_hooks():
return hooks
except Exception as e:
logger.error(f"❌ 记忆集成钩子初始化失败: {e}", exc_info=True)
return None
return None

View File

@@ -4,12 +4,10 @@
Provide multi-dimensional, precise filtering and querying for the memory system
"""
import os
import time
import orjson
from typing import Dict, List, Optional, Tuple, Set, Any, Union
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum
import threading
from collections import defaultdict
@@ -23,23 +21,25 @@ logger = get_logger(__name__)
class IndexType(Enum):
"""索引类型"""
MEMORY_TYPE = "memory_type" # 记忆类型索引
USER_ID = "user_id" # 用户ID索引
SUBJECT = "subject" # 主体索引
KEYWORD = "keyword" # 关键词索引
TAG = "tag" # 标签索引
CATEGORY = "category" # 分类索引
TIMESTAMP = "timestamp" # 时间索引
CONFIDENCE = "confidence" # 置信度索引
IMPORTANCE = "importance" # 重要性索引
MEMORY_TYPE = "memory_type" # 记忆类型索引
USER_ID = "user_id" # 用户ID索引
SUBJECT = "subject" # 主体索引
KEYWORD = "keyword" # 关键词索引
TAG = "tag" # 标签索引
CATEGORY = "category" # 分类索引
TIMESTAMP = "timestamp" # 时间索引
CONFIDENCE = "confidence" # 置信度索引
IMPORTANCE = "importance" # 重要性索引
RELATIONSHIP_SCORE = "relationship_score" # 关系分索引
ACCESS_FREQUENCY = "access_frequency" # 访问频率索引
SEMANTIC_HASH = "semantic_hash" # 语义哈希索引
SEMANTIC_HASH = "semantic_hash" # 语义哈希索引
@dataclass
class IndexQuery:
"""索引查询条件"""
user_ids: Optional[List[str]] = None
memory_types: Optional[List[MemoryType]] = None
subjects: Optional[List[str]] = None
@@ -61,6 +61,7 @@ class IndexQuery:
@dataclass
class IndexResult:
"""索引结果"""
memory_ids: List[str]
total_count: int
query_time: float
@@ -102,7 +103,7 @@ class MetadataIndexManager:
"average_query_time": 0.0,
"total_queries": 0,
"cache_hit_rate": 0.0,
"cache_hits": 0
"cache_hits": 0,
}
# Thread lock
@@ -171,9 +172,8 @@ class MetadataIndexManager:
index_time = time.time() - start_time
self.index_stats["index_build_time"] = (
(self.index_stats["index_build_time"] * (len(memories) - 1) + index_time) /
len(memories)
)
self.index_stats["index_build_time"] * (len(memories) - 1) + index_time
) / len(memories)
logger.debug(f"元数据索引完成,{len(memories)} 条记忆,耗时 {index_time:.3f}")
@@ -258,7 +258,7 @@ class MetadataIndexManager:
"relationship_score": memory.metadata.relationship_score,
"relevance_score": memory.metadata.relevance_score,
"semantic_hash": memory.semantic_hash,
"subjects": memory.subjects
"subjects": memory.subjects,
}
# Memory-type index
@@ -355,21 +355,20 @@ class MetadataIndexManager:
# Apply the limit
if query.limit and len(filtered_ids) > query.limit:
filtered_ids = filtered_ids[:query.limit]
filtered_ids = filtered_ids[: query.limit]
# Record query statistics
query_time = time.time() - start_time
self.index_stats["total_queries"] += 1
self.index_stats["average_query_time"] = (
(self.index_stats["average_query_time"] * (self.index_stats["total_queries"] - 1) + query_time) /
self.index_stats["total_queries"]
)
self.index_stats["average_query_time"] * (self.index_stats["total_queries"] - 1) + query_time
) / self.index_stats["total_queries"]
return IndexResult(
memory_ids=filtered_ids,
total_count=len(filtered_ids),
query_time=query_time,
filtered_by=self._get_applied_filters(query)
filtered_by=self._get_applied_filters(query),
)
except Exception as e:
@@ -486,15 +485,15 @@ class MetadataIndexManager:
if query.time_range:
start_time, end_time = query.time_range
filtered_ids = [
memory_id for memory_id in filtered_ids
if self._is_in_time_range(memory_id, start_time, end_time)
memory_id for memory_id in filtered_ids if self._is_in_time_range(memory_id, start_time, end_time)
]
# Confidence filter
if query.confidence_levels:
confidence_set = set(query.confidence_levels)
filtered_ids = [
memory_id for memory_id in filtered_ids
memory_id
for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["confidence"] in confidence_set
]
@@ -502,27 +501,31 @@ class MetadataIndexManager:
if query.importance_levels:
importance_set = set(query.importance_levels)
filtered_ids = [
memory_id for memory_id in filtered_ids
memory_id
for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["importance"] in importance_set
]
# Relationship-score range filter
if query.min_relationship_score is not None:
filtered_ids = [
memory_id for memory_id in filtered_ids
memory_id
for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["relationship_score"] >= query.min_relationship_score
]
if query.max_relationship_score is not None:
filtered_ids = [
memory_id for memory_id in filtered_ids
memory_id
for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["relationship_score"] <= query.max_relationship_score
]
# Minimum access-count filter
if query.min_access_count is not None:
filtered_ids = [
memory_id for memory_id in filtered_ids
memory_id
for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["access_count"] >= query.min_access_count
]
@@ -530,7 +533,8 @@ class MetadataIndexManager:
if query.semantic_hashes:
hash_set = set(query.semantic_hashes)
filtered_ids = [
memory_id for memory_id in filtered_ids
memory_id
for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["semantic_hash"] in hash_set
]
@@ -560,8 +564,7 @@ class MetadataIndexManager:
elif sort_by == "relevance_score":
# Sort by relevance
memory_ids.sort(
key=lambda mid: self.memory_metadata_cache[mid]["relevance_score"],
reverse=(sort_order == "desc")
key=lambda mid: self.memory_metadata_cache[mid]["relevance_score"], reverse=(sort_order == "desc")
)
elif sort_by == "relationship_score":
@@ -574,8 +577,7 @@ class MetadataIndexManager:
elif sort_by == "last_accessed":
# 按最后访问时间排序
memory_ids.sort(
key=lambda mid: self.memory_metadata_cache[mid]["last_accessed"],
reverse=(sort_order == "desc")
key=lambda mid: self.memory_metadata_cache[mid]["last_accessed"], reverse=(sort_order == "desc")
)
return memory_ids
@@ -665,7 +667,9 @@ class MetadataIndexManager:
self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid != memory_id]
# Remove from the access-frequency index
self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid != memory_id]
self.access_frequency_index = [
(count, mid) for count, mid in self.access_frequency_index if mid != memory_id
]
# Note: the keyword, tag, and category indices need the original memory; simplified here
# A real implementation might reload memories or maintain reverse indices
@@ -704,7 +708,7 @@ class MetadataIndexManager:
"average_importance": 0.0,
"average_relationship_score": 0.0,
"top_keywords": [],
"top_tags": []
"top_tags": [],
}
if user_id:
@@ -789,23 +793,23 @@ class MetadataIndexManager:
indices_data[index_type.value] = serialized_index
indices_file = self.index_path / "indices.json"
with open(indices_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(indices_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(indices_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(indices_data, option=orjson.OPT_INDENT_2).decode("utf-8"))
# Save the time index
time_index_file = self.index_path / "time_index.json"
with open(time_index_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.time_index, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(time_index_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.time_index, option=orjson.OPT_INDENT_2).decode("utf-8"))
# Save the relationship-score index
relationship_index_file = self.index_path / "relationship_index.json"
with open(relationship_index_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.relationship_index, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(relationship_index_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.relationship_index, option=orjson.OPT_INDENT_2).decode("utf-8"))
# Save the access-frequency index
access_frequency_index_file = self.index_path / "access_frequency_index.json"
with open(access_frequency_index_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.access_frequency_index, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(access_frequency_index_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.access_frequency_index, option=orjson.OPT_INDENT_2).decode("utf-8"))
# Save the metadata cache
metadata_cache_file = self.index_path / "metadata_cache.json"
@@ -813,13 +817,13 @@ class MetadataIndexManager:
memory_id: self._serialize_metadata_entry(metadata)
for memory_id, metadata in self.memory_metadata_cache.items()
}
with open(metadata_cache_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(metadata_cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode("utf-8"))
# Save statistics
stats_file = self.index_path / "index_stats.json"
with open(stats_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.index_stats, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(stats_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.index_stats, option=orjson.OPT_INDENT_2).decode("utf-8"))
self._dirty = False
logger.info("✅ 元数据索引保存完成")
@@ -835,7 +839,7 @@ class MetadataIndexManager:
# Load each index type
indices_file = self.index_path / "indices.json"
if indices_file.exists():
with open(indices_file, 'r', encoding='utf-8') as f:
with open(indices_file, "r", encoding="utf-8") as f:
indices_data = orjson.loads(f.read())
for index_type_value, index_data in indices_data.items():
@@ -849,25 +853,25 @@ class MetadataIndexManager:
# Load the time index
time_index_file = self.index_path / "time_index.json"
if time_index_file.exists():
with open(time_index_file, 'r', encoding='utf-8') as f:
with open(time_index_file, "r", encoding="utf-8") as f:
self.time_index = orjson.loads(f.read())
# Load the relationship-score index
relationship_index_file = self.index_path / "relationship_index.json"
if relationship_index_file.exists():
with open(relationship_index_file, 'r', encoding='utf-8') as f:
with open(relationship_index_file, "r", encoding="utf-8") as f:
self.relationship_index = orjson.loads(f.read())
# Load the access-frequency index
access_frequency_index_file = self.index_path / "access_frequency_index.json"
if access_frequency_index_file.exists():
with open(access_frequency_index_file, 'r', encoding='utf-8') as f:
with open(access_frequency_index_file, "r", encoding="utf-8") as f:
self.access_frequency_index = orjson.loads(f.read())
# Load the metadata cache
metadata_cache_file = self.index_path / "metadata_cache.json"
if metadata_cache_file.exists():
with open(metadata_cache_file, 'r', encoding='utf-8') as f:
with open(metadata_cache_file, "r", encoding="utf-8") as f:
cache_data = orjson.loads(f.read())
# Convert confidence and importance values to enum types
@@ -910,7 +914,7 @@ class MetadataIndexManager:
# Load statistics
stats_file = self.index_path / "index_stats.json"
if stats_file.exists():
with open(stats_file, 'r', encoding='utf-8') as f:
with open(stats_file, "r", encoding="utf-8") as f:
self.index_stats = orjson.loads(f.read())
# Update the memory count
@@ -937,9 +941,7 @@ class MetadataIndexManager:
# Update statistics
if self.index_stats["total_queries"] > 0:
self.index_stats["cache_hit_rate"] = (
self.index_stats["cache_hits"] / self.index_stats["total_queries"]
)
self.index_stats["cache_hit_rate"] = self.index_stats["cache_hits"] / self.index_stats["total_queries"]
logger.info("✅ 元数据索引优化完成")
@@ -967,7 +969,9 @@ class MetadataIndexManager:
self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid in valid_memory_ids]
# 清理访问频率索引中的无效引用
self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid in valid_memory_ids]
self.access_frequency_index = [
(count, mid) for count, mid in self.access_frequency_index if mid in valid_memory_ids
]
# 更新总记忆数
self.index_stats["total_memories"] = len(valid_memory_ids)
@@ -1017,7 +1021,7 @@ class MetadataIndexManager:
"categories": len(self.indices[IndexType.CATEGORY]),
"confidence_levels": len(self.indices[IndexType.CONFIDENCE]),
"importance_levels": len(self.indices[IndexType.IMPORTANCE]),
"semantic_hashes": len(self.indices[IndexType.SEMANTIC_HASH])
"semantic_hashes": len(self.indices[IndexType.SEMANTIC_HASH]),
}
return stats
return stats
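The hunks above repeatedly reformat one and the same orjson persistence idiom. For reference, a minimal standalone sketch of that save/load round trip (the path and field names here are illustrative, not the module's actual ones):

```python
# Minimal sketch of the orjson persistence pattern used by MetadataIndexManager.
# orjson.dumps returns bytes, so it is decoded before writing as UTF-8 text.
from pathlib import Path
import orjson

index_path = Path("data/memory_index")  # assumed location
index_path.mkdir(parents=True, exist_ok=True)

stats = {"total_queries": 0, "cache_hits": 0}

# Save: pretty-printed JSON
stats_file = index_path / "index_stats.json"
with open(stats_file, "w", encoding="utf-8") as f:
    f.write(orjson.dumps(stats, option=orjson.OPT_INDENT_2).decode("utf-8"))

# Load: orjson.loads accepts the str read back from disk
if stats_file.exists():
    with open(stats_file, "r", encoding="utf-8") as f:
        stats = orjson.loads(f.read())
```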

View File

@@ -5,15 +5,13 @@
"""
import time
import asyncio
from typing import Dict, List, Optional, Tuple, Set, Any
from typing import Dict, List, Optional, Set, Any
from dataclasses import dataclass, field
from enum import Enum
import numpy as np
import orjson
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.enhanced_reranker import EnhancedReRanker, ReRankingConfig
logger = get_logger(__name__)
@@ -21,30 +19,32 @@ logger = get_logger(__name__)
class RetrievalStage(Enum):
"""检索阶段"""
METADATA_FILTERING = "metadata_filtering" # 元数据过滤阶段
VECTOR_SEARCH = "vector_search" # 向量搜索阶段
SEMANTIC_RERANKING = "semantic_reranking" # 语义重排序阶段
CONTEXTUAL_FILTERING = "contextual_filtering" # 上下文过滤阶段
METADATA_FILTERING = "metadata_filtering" # 元数据过滤阶段
VECTOR_SEARCH = "vector_search" # 向量搜索阶段
SEMANTIC_RERANKING = "semantic_reranking" # 语义重排序阶段
CONTEXTUAL_FILTERING = "contextual_filtering" # 上下文过滤阶段
@dataclass
class RetrievalConfig:
"""检索配置"""
# 各阶段配置 - 优化召回率
metadata_filter_limit: int = 150 # 元数据过滤阶段返回数量(增加)
vector_search_limit: int = 80 # 向量搜索阶段返回数量(增加)
semantic_rerank_limit: int = 30 # 语义重排序阶段返回数量(增加)
final_result_limit: int = 10 # 最终结果数量
metadata_filter_limit: int = 150 # 元数据过滤阶段返回数量(增加)
vector_search_limit: int = 80 # 向量搜索阶段返回数量(增加)
semantic_rerank_limit: int = 30 # 语义重排序阶段返回数量(增加)
final_result_limit: int = 10 # 最终结果数量
# 相似度阈值 - 优化召回率
vector_similarity_threshold: float = 0.5 # 向量相似度阈值(降低以提升召回率)
vector_similarity_threshold: float = 0.5 # 向量相似度阈值(降低以提升召回率)
semantic_similarity_threshold: float = 0.05 # 语义相似度阈值(保持较低以获得更多相关记忆)
# 权重配置
vector_weight: float = 0.4 # 向量相似度权重
semantic_weight: float = 0.3 # 语义相似度权重
context_weight: float = 0.2 # 上下文权重
recency_weight: float = 0.1 # 时效性权重
vector_weight: float = 0.4 # 向量相似度权重
semantic_weight: float = 0.3 # 语义相似度权重
context_weight: float = 0.2 # 上下文权重
recency_weight: float = 0.1 # 时效性权重
@classmethod
def from_global_config(cls):
@@ -53,26 +53,25 @@ class RetrievalConfig:
return cls(
# 各阶段配置 - 优化召回率
metadata_filter_limit=max(150, global_config.memory.metadata_filter_limit), # 增加候选池
vector_search_limit=max(80, global_config.memory.vector_search_limit), # 增加向量搜索结果
semantic_rerank_limit=max(30, global_config.memory.semantic_rerank_limit), # 增加重排序候选
metadata_filter_limit=max(150, global_config.memory.metadata_filter_limit), # 增加候选池
vector_search_limit=max(80, global_config.memory.vector_search_limit), # 增加向量搜索结果
semantic_rerank_limit=max(30, global_config.memory.semantic_rerank_limit), # 增加重排序候选
final_result_limit=global_config.memory.final_result_limit,
# 相似度阈值 - 优化召回率
vector_similarity_threshold=max(0.5, global_config.memory.vector_similarity_threshold), # 确保不低于0.5
semantic_similarity_threshold=0.05, # 进一步降低以提升召回率
# 权重配置
vector_weight=global_config.memory.vector_weight,
semantic_weight=global_config.memory.semantic_weight,
context_weight=global_config.memory.context_weight,
recency_weight=global_config.memory.recency_weight
recency_weight=global_config.memory.recency_weight,
)
@dataclass
class StageResult:
"""阶段结果"""
stage: RetrievalStage
memory_ids: List[str]
processing_time: float
@@ -84,6 +83,7 @@ class StageResult:
@dataclass
class RetrievalResult:
"""检索结果"""
query: str
user_id: str
final_memories: List[MemoryChunk]
@@ -98,16 +98,16 @@ class MultiStageRetrieval:
def __init__(self, config: Optional[RetrievalConfig] = None):
self.config = config or RetrievalConfig.from_global_config()
# 初始化增强重排序器
reranker_config = ReRankingConfig(
semantic_weight=self.config.vector_weight,
recency_weight=self.config.recency_weight,
usage_freq_weight=0.2, # 新增的使用频率权重
type_match_weight=0.1 # 新增的类型匹配权重
type_match_weight=0.1, # 新增的类型匹配权重
)
self.reranker = EnhancedReRanker(reranker_config)
self.retrieval_stats = {
"total_queries": 0,
"average_retrieval_time": 0.0,
@@ -116,8 +116,8 @@ class MultiStageRetrieval:
"vector_search": {"calls": 0, "avg_time": 0.0},
"semantic_reranking": {"calls": 0, "avg_time": 0.0},
"contextual_filtering": {"calls": 0, "avg_time": 0.0},
"enhanced_reranking": {"calls": 0, "avg_time": 0.0} # 新增统计
}
"enhanced_reranking": {"calls": 0, "avg_time": 0.0}, # 新增统计
},
}
async def retrieve_memories(
@@ -128,7 +128,7 @@ class MultiStageRetrieval:
metadata_index,
vector_storage,
all_memories_cache: Dict[str, MemoryChunk],
limit: Optional[int] = None
limit: Optional[int] = None,
) -> RetrievalResult:
"""多阶段记忆检索"""
start_time = time.time()
@@ -143,31 +143,39 @@ class MultiStageRetrieval:
# 阶段1元数据过滤
stage1_result = await self._metadata_filtering_stage(
query, user_id, context, metadata_index, all_memories_cache,
debug_log=memory_debug_info
query, user_id, context, metadata_index, all_memories_cache, debug_log=memory_debug_info
)
stage_results.append(stage1_result)
current_memory_ids.update(stage1_result.memory_ids)
# 阶段2向量搜索
stage2_result = await self._vector_search_stage(
query, user_id, context, vector_storage, current_memory_ids, all_memories_cache,
debug_log=memory_debug_info
query,
user_id,
context,
vector_storage,
current_memory_ids,
all_memories_cache,
debug_log=memory_debug_info,
)
stage_results.append(stage2_result)
current_memory_ids.update(stage2_result.memory_ids)
# 阶段3语义重排序
stage3_result = await self._semantic_reranking_stage(
query, user_id, context, current_memory_ids, all_memories_cache,
debug_log=memory_debug_info
query, user_id, context, current_memory_ids, all_memories_cache, debug_log=memory_debug_info
)
stage_results.append(stage3_result)
# 阶段4上下文过滤
stage4_result = await self._contextual_filtering_stage(
query, user_id, context, stage3_result.memory_ids, all_memories_cache, limit,
debug_log=memory_debug_info
query,
user_id,
context,
stage3_result.memory_ids,
all_memories_cache,
limit,
debug_log=memory_debug_info,
)
stage_results.append(stage4_result)
@@ -176,18 +184,27 @@ class MultiStageRetrieval:
logger.debug(f"上下文过滤结果过少({len(stage4_result.memory_ids)}),启用回退机制")
# 回退到更宽松的检索策略
fallback_result = await self._fallback_retrieval_stage(
query, user_id, context, all_memories_cache, limit,
query,
user_id,
context,
all_memories_cache,
limit,
excluded_ids=set(stage4_result.memory_ids),
debug_log=memory_debug_info
debug_log=memory_debug_info,
)
if fallback_result.memory_ids:
stage4_result.memory_ids.extend(fallback_result.memory_ids[:limit - len(stage4_result.memory_ids)])
stage4_result.memory_ids.extend(fallback_result.memory_ids[: limit - len(stage4_result.memory_ids)])
logger.debug(f"回退机制补充了 {len(fallback_result.memory_ids)} 条记忆")
# 阶段5增强重排序 (新增)
stage5_result = await self._enhanced_reranking_stage(
query, user_id, context, stage4_result.memory_ids, all_memories_cache, limit,
debug_log=memory_debug_info
query,
user_id,
context,
stage4_result.memory_ids,
all_memories_cache,
limit,
debug_log=memory_debug_info,
)
stage_results.append(stage5_result)
@@ -226,13 +243,21 @@ class MultiStageRetrieval:
"semantic_score": trace.get("semantic_stage", {}).get("score"),
"context_score": trace.get("context_stage", {}).get("context_score"),
"final_score": trace.get("context_stage", {}).get("final_score"),
"status": trace.get("context_stage", {}).get("status") or trace.get("vector_stage", {}).get("status") or trace.get("semantic_stage", {}).get("status"),
"status": trace.get("context_stage", {}).get("status")
or trace.get("vector_stage", {}).get("status")
or trace.get("semantic_stage", {}).get("status"),
"is_final": memory_id in final_ids_set,
}
debug_entries.append(entry)
# 限制日志输出数量
debug_entries.sort(key=lambda item: (item.get("is_final", False), item.get("final_score") or item.get("vector_similarity") or 0.0), reverse=True)
debug_entries.sort(
key=lambda item: (
item.get("is_final", False),
item.get("final_score") or item.get("vector_similarity") or 0.0,
),
reverse=True,
)
debug_payload = {
"query": query,
"semantic_query": context.get("resolved_query_text", query),
@@ -266,7 +291,7 @@ class MultiStageRetrieval:
stage_results=stage_results,
total_processing_time=total_time,
total_filtered=total_filtered,
retrieval_stats=self.retrieval_stats.copy()
retrieval_stats=self.retrieval_stats.copy(),
)
except Exception as e:
@@ -279,7 +304,7 @@ class MultiStageRetrieval:
stage_results=stage_results,
total_processing_time=time.time() - start_time,
total_filtered=0,
retrieval_stats=self.retrieval_stats.copy()
retrieval_stats=self.retrieval_stats.copy(),
)
async def _metadata_filtering_stage(
@@ -290,7 +315,7 @@ class MultiStageRetrieval:
metadata_index,
all_memories_cache: Dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
) -> StageResult:
"""阶段1元数据过滤"""
start_time = time.time()
@@ -302,7 +327,9 @@ class MultiStageRetrieval:
memory_types = self._extract_memory_types_from_context(context)
keywords = self._extract_keywords_from_query(query, query_plan)
subjects = query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None
subjects = (
query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None
)
index_query = IndexQuery(
user_ids=None,
@@ -311,7 +338,7 @@ class MultiStageRetrieval:
keywords=keywords,
limit=self.config.metadata_filter_limit,
sort_by="last_accessed",
sort_order="desc"
sort_order="desc",
)
# 执行查询
@@ -328,19 +355,16 @@ class MultiStageRetrieval:
reverse=True,
)
if memory_types:
type_filtered = [
mid for mid in sorted_ids
if all_memories_cache[mid].memory_type in memory_types
]
type_filtered = [mid for mid in sorted_ids if all_memories_cache[mid].memory_type in memory_types]
sorted_ids = type_filtered or sorted_ids
if subjects:
subject_candidates = [s.lower() for s in subjects if isinstance(s, str) and s.strip()]
if subject_candidates:
subject_filtered = [
mid for mid in sorted_ids
mid
for mid in sorted_ids
if any(
subj.strip().lower() in subject_candidates
for subj in all_memories_cache[mid].subjects
subj.strip().lower() in subject_candidates for subj in all_memories_cache[mid].subjects
)
]
sorted_ids = subject_filtered or sorted_ids
@@ -367,12 +391,14 @@ class MultiStageRetrieval:
bool(subjects),
bool(keywords),
)
details.append({
"note": "fallback_recent",
"requested_types": [mt.value for mt in memory_types] if memory_types else [],
"subjects": subjects or [],
"keywords": keywords or [],
})
details.append(
{
"note": "fallback_recent",
"requested_types": [mt.value for mt in memory_types] if memory_types else [],
"subjects": subjects or [],
"keywords": keywords or [],
}
)
logger.debug(
"元数据过滤:候选=%d, 返回=%d",
@@ -419,7 +445,7 @@ class MultiStageRetrieval:
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
) -> StageResult:
"""阶段2向量搜索"""
start_time = time.time()
@@ -441,8 +467,7 @@ class MultiStageRetrieval:
# 执行向量搜索
search_result = await vector_storage.search_similar_memories(
query_vector=query_embedding,
limit=self.config.vector_search_limit
query_vector=query_embedding, limit=self.config.vector_search_limit
)
if not search_result:
@@ -464,16 +489,18 @@ class MultiStageRetrieval:
if in_metadata_candidates and above_threshold:
filtered_memories.append((memory_id, similarity))
raw_details.append({
"memory_id": memory_id,
"similarity": similarity,
"in_metadata": in_metadata_candidates,
"above_threshold": above_threshold,
})
raw_details.append(
{
"memory_id": memory_id,
"similarity": similarity,
"in_metadata": in_metadata_candidates,
"above_threshold": above_threshold,
}
)
# 按相似度排序
filtered_memories.sort(key=lambda x: x[1], reverse=True)
result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]]
result_ids = [memory_id for memory_id, _ in filtered_memories[: self.config.vector_search_limit]]
kept_ids = set(result_ids)
for entry in raw_details:
@@ -534,11 +561,7 @@ class MultiStageRetrieval:
)
def _create_text_search_fallback(
self,
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk],
query_text: str,
start_time: float
self, candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], query_text: str, start_time: float
) -> StageResult:
"""当向量搜索失败时,使用文本搜索作为回退策略"""
try:
@@ -561,15 +584,13 @@ class MultiStageRetrieval:
# 按匹配度排序
text_matches.sort(key=lambda x: x[1], reverse=True)
result_ids = [memory_id for memory_id, _ in text_matches[:self.config.vector_search_limit]]
result_ids = [memory_id for memory_id, _ in text_matches[: self.config.vector_search_limit]]
details = []
for memory_id, score in text_matches[:self.config.vector_search_limit]:
details.append({
"memory_id": memory_id,
"text_match_score": round(score, 4),
"status": "text_match_fallback"
})
for memory_id, score in text_matches[: self.config.vector_search_limit]:
details.append(
{"memory_id": memory_id, "text_match_score": round(score, 4), "status": "text_match_fallback"}
)
logger.debug(f"向量搜索回退到文本匹配:找到 {len(result_ids)} 条匹配记忆")
@@ -579,18 +600,18 @@ class MultiStageRetrieval:
processing_time=time.time() - start_time,
filtered_count=len(candidate_ids) - len(result_ids),
score_threshold=0.0, # 文本匹配无严格阈值
details=details
details=details,
)
except Exception as e:
logger.error(f"文本搜索回退失败: {e}")
return StageResult(
stage=RetrievalStage.VECTOR_SEARCH,
memory_ids=list(candidate_ids)[:self.config.vector_search_limit],
memory_ids=list(candidate_ids)[: self.config.vector_search_limit],
processing_time=time.time() - start_time,
filtered_count=0,
score_threshold=0.0,
details=[{"error": str(e), "note": "text_fallback_failed"}]
details=[{"error": str(e), "note": "text_fallback_failed"}],
)
async def _semantic_reranking_stage(
@@ -601,7 +622,7 @@ class MultiStageRetrieval:
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
) -> StageResult:
"""阶段3语义重排序"""
start_time = time.time()
@@ -643,7 +664,7 @@ class MultiStageRetrieval:
# 按语义相似度排序
reranked_memories.sort(key=lambda x: x[1], reverse=True)
result_ids = [memory_id for memory_id, _ in reranked_memories[:self.config.semantic_rerank_limit]]
result_ids = [memory_id for memory_id, _ in reranked_memories[: self.config.semantic_rerank_limit]]
kept_ids = set(result_ids)
filtered_count = len(candidate_ids) - len(result_ids)
@@ -688,7 +709,7 @@ class MultiStageRetrieval:
all_memories_cache: Dict[str, MemoryChunk],
limit: int,
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
) -> StageResult:
"""阶段4上下文过滤"""
start_time = time.time()
@@ -777,7 +798,7 @@ class MultiStageRetrieval:
limit: int,
*,
excluded_ids: Optional[Set[str]] = None,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
) -> StageResult:
"""回退检索阶段 - 当主检索失败时使用更宽松的策略"""
start_time = time.time()
@@ -806,13 +827,15 @@ class MultiStageRetrieval:
if not fallback_candidates:
logger.debug("关键词匹配无结果,使用时序最近策略")
recent_memories = sorted(
[(mid, mem.metadata.last_accessed or mem.metadata.created_at)
for mid, mem in all_memories_cache.items()
if mid not in excluded_ids],
[
(mid, mem.metadata.last_accessed or mem.metadata.created_at)
for mid, mem in all_memories_cache.items()
if mid not in excluded_ids
],
key=lambda x: x[1],
reverse=True
reverse=True,
)
fallback_candidates = [(mid, 0.5) for mid, _ in recent_memories[:limit*2]]
fallback_candidates = [(mid, 0.5) for mid, _ in recent_memories[: limit * 2]]
# 按分数排序
fallback_candidates.sort(key=lambda x: x[1], reverse=True)
@@ -857,7 +880,9 @@ class MultiStageRetrieval:
details=[{"error": str(e)}],
)
async def _generate_query_embedding(self, query: str, context: Dict[str, Any], vector_storage) -> Optional[List[float]]:
async def _generate_query_embedding(
self, query: str, context: Dict[str, Any], vector_storage
) -> Optional[List[float]]:
"""生成查询向量"""
try:
query_plan = context.get("query_plan")
@@ -875,15 +900,15 @@ class MultiStageRetrieval:
logger.debug(f"正在生成查询向量,文本: '{query_text[:100]}'")
embedding = await vector_storage.generate_query_embedding(query_text)
if embedding is None:
logger.warning("向量存储返回空的查询向量")
return None
if len(embedding) == 0:
logger.warning("向量存储返回空列表作为查询向量")
return None
logger.debug(f"查询向量生成成功,维度: {len(embedding)}")
return embedding
@@ -926,8 +951,8 @@ class MultiStageRetrieval:
import re
# 分词处理
query_words = list(jieba.cut(query_text)) + re.findall(r'[a-zA-Z]+', query_text)
memory_words = list(jieba.cut(memory_text)) + re.findall(r'[a-zA-Z]+', memory_text)
query_words = list(jieba.cut(query_text)) + re.findall(r"[a-zA-Z]+", query_text)
memory_words = list(jieba.cut(memory_text)) + re.findall(r"[a-zA-Z]+", memory_text)
# 清理和标准化
query_words = [w.strip().lower() for w in query_words if w.strip() and len(w.strip()) > 1]
@@ -953,8 +978,9 @@ class MultiStageRetrieval:
except ImportError:
# 如果jieba不可用使用简单分词
import re
query_words = re.findall(r'[\w\u4e00-\u9fa5]+', query_lower)
memory_words = re.findall(r'[\w\u4e00-\u9fa5]+', memory_lower)
query_words = re.findall(r"[\w\u4e00-\u9fa5]+", query_lower)
memory_words = re.findall(r"[\w\u4e00-\u9fa5]+", memory_lower)
if query_words and memory_words:
query_set = set(w for w in query_words if len(w) > 1)
@@ -971,13 +997,19 @@ class MultiStageRetrieval:
"天气": ["天气", "阳光", "", "", "", "温度", "weather", "sunny", "rain"],
"编程": ["编程", "代码", "程序", "开发", "语言", "programming", "code", "develop", "python"],
"时间": ["今天", "昨天", "明天", "现在", "时间", "today", "yesterday", "tomorrow", "time"],
"情感": ["", "", "开心", "难过", "有趣", "good", "bad", "happy", "sad", "fun"]
"情感": ["", "", "开心", "难过", "有趣", "good", "bad", "happy", "sad", "fun"],
}
query_concepts = {concept for concept, keywords in concept_groups.items()
if any(keyword in query_lower for keyword in keywords)}
memory_concepts = {concept for concept, keywords in concept_groups.items()
if any(keyword in memory_lower for keyword in keywords)}
query_concepts = {
concept
for concept, keywords in concept_groups.items()
if any(keyword in query_lower for keyword in keywords)
}
memory_concepts = {
concept
for concept, keywords in concept_groups.items()
if any(keyword in memory_lower for keyword in keywords)
}
if query_concepts and memory_concepts:
concept_overlap = query_concepts & memory_concepts
@@ -987,19 +1019,19 @@ class MultiStageRetrieval:
plan_bonus = 0.0
if query_plan:
# 主体匹配
if hasattr(query_plan, 'subjects') and query_plan.subjects:
if hasattr(query_plan, "subjects") and query_plan.subjects:
for subject in query_plan.subjects:
if subject.lower() in memory_lower:
plan_bonus += 0.15
# 对象匹配
if hasattr(query_plan, 'objects') and query_plan.objects:
if hasattr(query_plan, "objects") and query_plan.objects:
for obj in query_plan.objects:
if obj.lower() in memory_lower:
plan_bonus += 0.1
# 记忆类型匹配
if hasattr(query_plan, 'memory_types') and query_plan.memory_types:
if hasattr(query_plan, "memory_types") and query_plan.memory_types:
if memory.memory_type in query_plan.memory_types:
plan_bonus += 0.1
@@ -1059,14 +1091,22 @@ class MultiStageRetrieval:
object_keywords = getattr(query_plan, "object_includes", []) or []
if object_keywords:
display_text = (memory.display or memory.text_content or "").lower()
hits = sum(1 for kw in object_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
hits = sum(
1
for kw in object_keywords
if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text
)
if hits:
score += min(0.3, hits * 0.1)
optional_keywords = getattr(query_plan, "optional_keywords", []) or []
if optional_keywords:
display_text = (memory.display or memory.text_content or "").lower()
hits = sum(1 for kw in optional_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
hits = sum(
1
for kw in optional_keywords
if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text
)
if hits:
score += min(0.2, hits * 0.05)
@@ -1091,7 +1131,9 @@ class MultiStageRetrieval:
logger.warning(f"计算上下文相关度失败: {e}")
return 0.0
async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float:
async def _calculate_final_score(
self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float
) -> float:
"""计算最终评分"""
try:
query_plan = context.get("query_plan")
@@ -1126,10 +1168,10 @@ class MultiStageRetrieval:
context_weight += 0.05
final_score = (
semantic_score * semantic_weight +
vector_score * vector_weight +
context_score * context_weight +
recency_score * recency_weight
semantic_score * semantic_weight
+ vector_score * vector_weight
+ context_score * context_weight
+ recency_score * recency_weight
)
# 加入记忆重要性权重
@@ -1259,7 +1301,9 @@ class MultiStageRetrieval:
stage_stat["calls"] += 1
current_stage_avg = stage_stat["avg_time"]
new_stage_avg = (current_stage_avg * (stage_stat["calls"] - 1) + result.processing_time) / stage_stat["calls"]
new_stage_avg = (current_stage_avg * (stage_stat["calls"] - 1) + result.processing_time) / stage_stat[
"calls"
]
stage_stat["avg_time"] = new_stage_avg
def get_retrieval_stats(self) -> Dict[str, Any]:
@@ -1276,8 +1320,8 @@ class MultiStageRetrieval:
"vector_search": {"calls": 0, "avg_time": 0.0},
"semantic_reranking": {"calls": 0, "avg_time": 0.0},
"contextual_filtering": {"calls": 0, "avg_time": 0.0},
"enhanced_reranking": {"calls": 0, "avg_time": 0.0}
}
"enhanced_reranking": {"calls": 0, "avg_time": 0.0},
},
}
async def _enhanced_reranking_stage(
@@ -1289,7 +1333,7 @@ class MultiStageRetrieval:
all_memories_cache: Dict[str, MemoryChunk],
limit: int,
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
) -> StageResult:
"""阶段5增强重排序 - 使用多维度评分模型"""
start_time = time.time()
@@ -1326,15 +1370,12 @@ class MultiStageRetrieval:
# 使用增强重排序器
reranked_memories = self.reranker.rerank_memories(
query=query,
candidate_memories=candidate_memories,
context=context,
limit=limit
query=query, candidate_memories=candidate_memories, context=context, limit=limit
)
# 提取重排序后的记忆ID
result_ids = [memory_id for memory_id, _, _ in reranked_memories]
# 生成调试详情
details = []
for memory_id, memory, final_score in reranked_memories:
@@ -1346,7 +1387,7 @@ class MultiStageRetrieval:
"access_count": memory.metadata.access_count,
}
details.append(detail_entry)
if debug_log is not None:
stage_entry = debug_log.setdefault(memory_id, {}).setdefault("enhanced_rerank_stage", {})
stage_entry["final_score"] = round(final_score, 4)
@@ -1357,13 +1398,9 @@ class MultiStageRetrieval:
kept_ids = set(result_ids)
for memory_id in candidate_ids:
if memory_id not in kept_ids:
detail_entry = {
"memory_id": memory_id,
"status": "filtered_out",
"reason": "ranked_below_limit"
}
detail_entry = {"memory_id": memory_id, "status": "filtered_out", "reason": "ranked_below_limit"}
details.append(detail_entry)
if debug_log is not None:
stage_entry = debug_log.setdefault(memory_id, {}).setdefault("enhanced_rerank_stage", {})
stage_entry["status"] = "filtered_out"
@@ -1371,10 +1408,7 @@ class MultiStageRetrieval:
filtered_count = len(candidate_ids) - len(result_ids)
logger.debug(
f"增强重排序完成:候选={len(candidate_ids)}, 返回={len(result_ids)}, "
f"过滤={filtered_count}"
)
logger.debug(f"增强重排序完成:候选={len(candidate_ids)}, 返回={len(result_ids)}, 过滤={filtered_count}")
return StageResult(
stage=RetrievalStage.CONTEXTUAL_FILTERING, # 保持与原有枚举兼容
@@ -1394,4 +1428,4 @@ class MultiStageRetrieval:
filtered_count=0,
score_threshold=0.0,
details=[{"error": str(e)}],
)
)
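The weighted combination in `_calculate_final_score` (semantic, vector, context, recency) is the core scoring step of this retriever. A self-contained sketch using the `RetrievalConfig` default weights shown above; the final clamp to [0, 1] is an assumption, not shown in the diff:

```python
# Sketch of the weighted score combination used in _calculate_final_score.
def combine_scores(
    semantic: float,
    vector: float,
    context: float,
    recency: float,
    semantic_weight: float = 0.3,
    vector_weight: float = 0.4,
    context_weight: float = 0.2,
    recency_weight: float = 0.1,
) -> float:
    # Weighted sum of the four signals
    score = (
        semantic * semantic_weight
        + vector * vector_weight
        + context * context_weight
        + recency * recency_weight
    )
    return max(0.0, min(1.0, score))  # clamp is an assumption

print(combine_scores(0.8, 0.7, 0.5, 0.9))  # ≈ 0.71
```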

View File

@@ -4,22 +4,19 @@
为记忆系统提供高效的向量存储和语义搜索能力
"""
import os
import time
import orjson
import asyncio
from typing import Dict, List, Optional, Tuple, Set, Any
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from datetime import datetime
import threading
import numpy as np
import pandas as pd
from pathlib import Path
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.config.config import model_config
from src.common.config_helpers import resolve_embedding_dimension
from src.chat.memory_system.memory_chunk import MemoryChunk
@@ -28,6 +25,7 @@ logger = get_logger(__name__)
# 尝试导入FAISS如果不可用则使用简单替代
try:
import faiss
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
@@ -37,6 +35,7 @@ except ImportError:
@dataclass
class VectorStorageConfig:
"""向量存储配置"""
dimension: int = 1024
similarity_threshold: float = 0.8
index_type: str = "flat" # flat, ivf, hnsw
@@ -79,7 +78,7 @@ class VectorStorageManager:
"average_search_time": 0.0,
"cache_hit_rate": 0.0,
"total_searches": 0,
"cache_hits": 0
"cache_hits": 0,
}
# 线程锁
@@ -122,8 +121,7 @@ class VectorStorageManager:
"""初始化嵌入模型"""
if self.embedding_model is None:
self.embedding_model = LLMRequest(
model_set=model_config.model_task_config.embedding,
request_type="memory.embedding"
model_set=model_config.model_task_config.embedding, request_type="memory.embedding"
)
logger.info("✅ 嵌入模型初始化完成")
@@ -137,20 +135,16 @@ class VectorStorageManager:
await self.initialize_embedding_model()
logger.debug(f"开始生成查询向量,文本: '{query_text[:50]}{'...' if len(query_text) > 50 else ''}'")
embedding, _ = await self.embedding_model.get_embedding(query_text)
if not embedding:
logger.warning("嵌入模型返回空向量")
return None
logger.debug(f"生成的向量维度: {len(embedding)}, 期望维度: {self.config.dimension}")
if len(embedding) != self.config.dimension:
logger.error(
"查询向量维度不匹配: 期望 %d, 实际 %d",
self.config.dimension,
len(embedding)
)
logger.error("查询向量维度不匹配: 期望 %d, 实际 %d", self.config.dimension, len(embedding))
return None
normalized_vector = self._normalize_vector(embedding)
@@ -287,7 +281,7 @@ class VectorStorageManager:
logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc)
results[memory_id] = []
tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts)]
tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)]
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
@@ -313,12 +307,12 @@ class VectorStorageManager:
memory.set_embedding(embedding)
# 添加到向量索引
if hasattr(self.vector_index, 'add'):
if hasattr(self.vector_index, "add"):
# FAISS索引
if isinstance(embedding, np.ndarray):
vector_array = embedding.reshape(1, -1).astype('float32')
vector_array = embedding.reshape(1, -1).astype("float32")
else:
vector_array = np.array([embedding], dtype='float32')
vector_array = np.array([embedding], dtype="float32")
# 特殊处理IVF索引
if self.config.index_type == "ivf" and self.vector_index.ntotal == 0:
@@ -367,14 +361,14 @@ class VectorStorageManager:
*,
query_text: Optional[str] = None,
limit: int = 10,
scope_id: Optional[str] = None
scope_id: Optional[str] = None,
) -> List[Tuple[str, float]]:
"""搜索相似记忆"""
start_time = time.time()
try:
logger.debug(f"开始向量搜索: query_text='{query_text[:30] if query_text else 'None'}', limit={limit}")
if query_vector is None:
if not query_text:
logger.warning("查询向量和查询文本都为空")
@@ -395,34 +389,34 @@ class VectorStorageManager:
# 规范化查询向量
query_vector = self._normalize_vector(query_vector)
logger.debug(f"查询向量维度: {len(query_vector)}, 存储总向量数: {self.storage_stats['total_vectors']}")
# 检查向量索引状态
if not self.vector_index:
logger.error("向量索引未初始化")
return []
total_vectors = 0
if hasattr(self.vector_index, 'ntotal'):
if hasattr(self.vector_index, "ntotal"):
total_vectors = self.vector_index.ntotal
elif hasattr(self.vector_index, 'vectors'):
elif hasattr(self.vector_index, "vectors"):
total_vectors = len(self.vector_index.vectors)
logger.debug(f"向量索引中实际向量数: {total_vectors}")
if total_vectors == 0:
logger.warning("向量索引为空,无法执行搜索")
return []
# 执行向量搜索
with self._lock:
if hasattr(self.vector_index, 'search'):
if hasattr(self.vector_index, "search"):
# FAISS索引
if isinstance(query_vector, np.ndarray):
query_array = query_vector.reshape(1, -1).astype('float32')
query_array = query_vector.reshape(1, -1).astype("float32")
else:
query_array = np.array([query_vector], dtype='float32')
query_array = np.array([query_vector], dtype="float32")
if self.config.index_type == "ivf" and self.vector_index.ntotal > 0:
# 设置IVF搜索参数
@@ -432,11 +426,11 @@ class VectorStorageManager:
search_limit = min(limit, total_vectors)
logger.debug(f"执行FAISS搜索搜索限制: {search_limit}")
distances, indices = self.vector_index.search(query_array, search_limit)
distances = distances.flatten().tolist()
indices = indices.flatten().tolist()
logger.debug(f"FAISS搜索结果: {len(distances)} 个距离值, {len(indices)} 个索引")
else:
# 简单索引
@@ -451,8 +445,8 @@ class VectorStorageManager:
valid_results = 0
invalid_indices = 0
filtered_by_scope = 0
for distance, index in zip(distances, indices):
for distance, index in zip(distances, indices, strict=False):
if index == -1: # FAISS的无效索引标记
invalid_indices += 1
continue
@@ -462,7 +456,7 @@ class VectorStorageManager:
logger.debug(f"索引 {index} 没有对应的记忆ID")
invalid_indices += 1
continue
if scope_filter:
memory = self.memory_cache.get(memory_id)
if memory and str(memory.user_id) != scope_filter:
@@ -482,16 +476,15 @@ class VectorStorageManager:
search_time = time.time() - start_time
self.storage_stats["total_searches"] += 1
self.storage_stats["average_search_time"] = (
(self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time) /
self.storage_stats["total_searches"]
)
self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time
) / self.storage_stats["total_searches"]
final_results = results[:limit]
logger.info(
f"向量搜索完成: 查询='{query_text[:20] if query_text else 'vector'}' "
f"耗时={search_time:.3f}s, 返回={len(final_results)}个结果"
)
return final_results
except Exception as e:
@@ -520,7 +513,7 @@ class VectorStorageManager:
old_index = self.memory_id_to_index[memory_id]
# 删除旧向量(如果支持)
if hasattr(self.vector_index, 'remove_ids'):
if hasattr(self.vector_index, "remove_ids"):
try:
self.vector_index.remove_ids(np.array([old_index]))
except:
@@ -530,11 +523,11 @@ class VectorStorageManager:
new_embedding = self._normalize_vector(new_embedding)
# 添加新向量
if hasattr(self.vector_index, 'add'):
if hasattr(self.vector_index, "add"):
if isinstance(new_embedding, np.ndarray):
vector_array = new_embedding.reshape(1, -1).astype('float32')
vector_array = new_embedding.reshape(1, -1).astype("float32")
else:
vector_array = np.array([new_embedding], dtype='float32')
vector_array = np.array([new_embedding], dtype="float32")
self.vector_index.add(vector_array)
new_index = self.vector_index.ntotal - 1
@@ -569,7 +562,7 @@ class VectorStorageManager:
index = self.memory_id_to_index[memory_id]
# 从向量索引中删除(如果支持)
if hasattr(self.vector_index, 'remove_ids'):
if hasattr(self.vector_index, "remove_ids"):
try:
self.vector_index.remove_ids(np.array([index]))
except:
@@ -598,44 +591,37 @@ class VectorStorageManager:
logger.info("正在保存向量存储...")
# 保存记忆缓存
cache_data = {
memory_id: memory.to_dict()
for memory_id, memory in self.memory_cache.items()
}
cache_data = {memory_id: memory.to_dict() for memory_id, memory in self.memory_cache.items()}
cache_file = self.storage_path / "memory_cache.json"
with open(cache_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode("utf-8"))
# 保存向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
with open(vector_cache_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(vector_cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode("utf-8"))
# 保存映射关系
mapping_file = self.storage_path / "id_mapping.json"
mapping_data = {
"memory_id_to_index": {
str(memory_id): int(index)
for memory_id, index in self.memory_id_to_index.items()
str(memory_id): int(index) for memory_id, index in self.memory_id_to_index.items()
},
"index_to_memory_id": {
str(index): memory_id
for index, memory_id in self.index_to_memory_id.items()
}
"index_to_memory_id": {str(index): memory_id for index, memory_id in self.index_to_memory_id.items()},
}
with open(mapping_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(mapping_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode("utf-8"))
# 保存FAISS索引如果可用
if FAISS_AVAILABLE and hasattr(self.vector_index, 'save'):
if FAISS_AVAILABLE and hasattr(self.vector_index, "save"):
index_file = self.storage_path / "vector_index.faiss"
faiss.write_index(self.vector_index, str(index_file))
# 保存统计信息
stats_file = self.storage_path / "storage_stats.json"
with open(stats_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(stats_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode("utf-8"))
logger.info("✅ 向量存储保存完成")
@@ -650,36 +636,31 @@ class VectorStorageManager:
# 加载记忆缓存
cache_file = self.storage_path / "memory_cache.json"
if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f:
with open(cache_file, "r", encoding="utf-8") as f:
cache_data = orjson.loads(f.read())
self.memory_cache = {
memory_id: MemoryChunk.from_dict(memory_data)
for memory_id, memory_data in cache_data.items()
memory_id: MemoryChunk.from_dict(memory_data) for memory_id, memory_data in cache_data.items()
}
# 加载向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
if vector_cache_file.exists():
with open(vector_cache_file, 'r', encoding='utf-8') as f:
with open(vector_cache_file, "r", encoding="utf-8") as f:
self.vector_cache = orjson.loads(f.read())
# 加载映射关系
mapping_file = self.storage_path / "id_mapping.json"
if mapping_file.exists():
with open(mapping_file, 'r', encoding='utf-8') as f:
with open(mapping_file, "r", encoding="utf-8") as f:
mapping_data = orjson.loads(f.read())
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
self.memory_id_to_index = {
str(memory_id): int(index)
for memory_id, index in raw_memory_to_index.items()
str(memory_id): int(index) for memory_id, index in raw_memory_to_index.items()
}
raw_index_to_memory = mapping_data.get("index_to_memory_id", {})
self.index_to_memory_id = {
int(index): memory_id
for index, memory_id in raw_index_to_memory.items()
}
self.index_to_memory_id = {int(index): memory_id for index, memory_id in raw_index_to_memory.items()}
# 加载FAISS索引如果可用
index_loaded = False
@@ -699,7 +680,7 @@ class VectorStorageManager:
logger.warning(f"加载FAISS索引失败: {e},重新构建")
else:
logger.info("FAISS索引文件不存在将重新构建")
# 如果索引没有成功加载且有向量数据,则重建索引
if not index_loaded and self.vector_cache:
logger.info(f"检测到 {len(self.vector_cache)} 个向量缓存,重建索引")
@@ -708,7 +689,7 @@ class VectorStorageManager:
# 加载统计信息
stats_file = self.storage_path / "storage_stats.json"
if stats_file.exists():
with open(stats_file, 'r', encoding='utf-8') as f:
with open(stats_file, "r", encoding="utf-8") as f:
self.storage_stats = orjson.loads(f.read())
# 更新向量计数
@@ -738,7 +719,7 @@ class VectorStorageManager:
# 准备向量数据
memory_ids = []
vectors = []
for memory_id, embedding in self.vector_cache.items():
if embedding and len(embedding) == self.config.dimension:
memory_ids.append(memory_id)
@@ -753,18 +734,18 @@ class VectorStorageManager:
logger.info(f"准备重建 {len(vectors)} 个向量到索引")
# 批量添加向量到FAISS索引
if hasattr(self.vector_index, 'add'):
if hasattr(self.vector_index, "add"):
# FAISS索引
vector_array = np.array(vectors, dtype='float32')
vector_array = np.array(vectors, dtype="float32")
# 特殊处理IVF索引
if self.config.index_type == "ivf" and hasattr(self.vector_index, 'train'):
if self.config.index_type == "ivf" and hasattr(self.vector_index, "train"):
logger.info("训练IVF索引...")
self.vector_index.train(vector_array)
# 添加向量
self.vector_index.add(vector_array)
# 重建映射关系
for i, memory_id in enumerate(memory_ids):
self.memory_id_to_index[memory_id] = i
@@ -772,15 +753,15 @@ class VectorStorageManager:
else:
# 简单索引
for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors)):
for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors, strict=False)):
index_id = self.vector_index.add_vector(vector)
self.memory_id_to_index[memory_id] = index_id
self.index_to_memory_id[index_id] = memory_id
# 更新统计
self.storage_stats["total_vectors"] = len(self.memory_id_to_index)
final_count = getattr(self.vector_index, 'ntotal', len(self.memory_id_to_index))
final_count = getattr(self.vector_index, "ntotal", len(self.memory_id_to_index))
logger.info(f"✅ 向量索引重建完成,索引中向量数: {final_count}")
except Exception as e:
@@ -875,7 +856,7 @@ class SimpleVectorIndex:
def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float:
"""计算余弦相似度"""
try:
dot_product = sum(x * y for x, y in zip(v1, v2))
dot_product = sum(x * y for x, y in zip(v1, v2, strict=False))
norm1 = sum(x * x for x in v1) ** 0.5
norm2 = sum(x * x for x in v2) ** 0.5
@@ -890,4 +871,4 @@ class SimpleVectorIndex:
@property
def ntotal(self) -> int:
"""向量总数"""
return len(self.vectors)
return len(self.vectors)
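For the FAISS path above (flat index, float32 vectors, normalized before add/search), a minimal standalone sketch; `IndexFlatIP` is an assumption, since the hunks do not show which flat metric the manager builds:

```python
# Sketch of the FAISS flat-index flow: normalize, add float32 vectors, search.
import numpy as np
import faiss

dim = 1024  # matches VectorStorageConfig.dimension default
index = faiss.IndexFlatIP(dim)  # inner product ≈ cosine once vectors are normalized

vecs = np.random.rand(100, dim).astype("float32")
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)  # row-wise L2 normalization
index.add(vecs)

query = vecs[:1]  # shape (1, dim), already normalized
distances, indices = index.search(query, 5)
print(indices.flatten().tolist())
print(distances.flatten().tolist())
```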

View File

@@ -6,7 +6,6 @@
import difflib
import orjson
import time
from typing import List, Dict, Optional
from datetime import datetime
@@ -15,7 +14,7 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.memory_manager import memory_manager, MemoryResult
from src.chat.memory_system.memory_manager import MemoryResult
logger = get_logger("memory_activator")
@@ -127,8 +126,8 @@ class MemoryActivator:
for result in memory_results:
# 检查是否已存在相似内容的记忆
exists = any(
m["content"] == result.content or
difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
m["content"] == result.content
or difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
for m in self.running_memory
)
if not exists:
@@ -140,7 +139,7 @@ class MemoryActivator:
"confidence": result.confidence,
"importance": result.importance,
"source": result.source,
"relevance_score": result.relevance_score # 添加相关度评分
"relevance_score": result.relevance_score, # 添加相关度评分
}
self.running_memory.append(memory_entry)
logger.debug(f"添加新记忆: {result.memory_type} - {result.content}")
@@ -168,17 +167,14 @@ class MemoryActivator:
return []
# 构建查询上下文
context = {
"keywords": keywords,
"query_intent": "conversation_response"
}
context = {"keywords": keywords, "query_intent": "conversation_response"}
# 查询记忆
memories = await memory_system.retrieve_relevant_memories(
query_text=query_text,
user_id="global", # 使用全局作用域
context=context,
limit=5
limit=5,
)
# 转换为 MemoryResult 格式
@@ -191,7 +187,7 @@ class MemoryActivator:
importance=memory.metadata.importance.value,
timestamp=memory.metadata.created_at,
source="unified_memory",
relevance_score=memory.metadata.relevance_score
relevance_score=memory.metadata.relevance_score,
)
memory_results.append(result)
@@ -214,16 +210,10 @@ class MemoryActivator:
if not memory_system or memory_system.status.value != "ready":
return None
context = {
"query_intent": "instant_response",
"chat_id": chat_id
}
context = {"query_intent": "instant_response", "chat_id": chat_id}
memories = await memory_system.retrieve_relevant_memories(
query_text=target_message,
user_id="global",
context=context,
limit=1
query_text=target_message, user_id="global", context=context, limit=1
)
if memories:
@@ -248,4 +238,4 @@ memory_activator = MemoryActivator()
# 兼容性别名
enhanced_memory_activator = memory_activator
init_prompt()
init_prompt()
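The dedup check above (exact match, or a `difflib.SequenceMatcher` ratio of at least 0.7 against `running_memory`) can be read in isolation; a small sketch with the same threshold:

```python
# Sketch of the near-duplicate check MemoryActivator applies before appending.
import difflib

def is_duplicate(content: str, seen: list[str], threshold: float = 0.7) -> bool:
    # Exact match, or fuzzy match at the same 0.7 ratio the activator uses
    return any(
        content == s or difflib.SequenceMatcher(None, s, content).ratio() >= threshold
        for s in seen
    )

seen = ["用户喜欢喝咖啡"]
print(is_duplicate("用户喜欢喝咖啡,尤其是拿铁", seen))  # True: ratio is exactly 0.7 here
```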

View File

@@ -6,7 +6,6 @@
import difflib
import orjson
import time
from typing import List, Dict, Optional
from datetime import datetime
@@ -15,7 +14,7 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.memory_manager import memory_manager, MemoryResult
from src.chat.memory_system.memory_manager import MemoryResult
logger = get_logger("memory_activator")
@@ -127,8 +126,8 @@ class MemoryActivator:
for result in memory_results:
# 检查是否已存在相似内容的记忆
exists = any(
m["content"] == result.content or
difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
m["content"] == result.content
or difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
for m in self.running_memory
)
if not exists:
@@ -140,7 +139,7 @@ class MemoryActivator:
"confidence": result.confidence,
"importance": result.importance,
"source": result.source,
"relevance_score": result.relevance_score # 添加相关度评分
"relevance_score": result.relevance_score, # 添加相关度评分
}
self.running_memory.append(memory_entry)
logger.debug(f"添加新记忆: {result.memory_type} - {result.content}")
@@ -168,17 +167,14 @@ class MemoryActivator:
return []
# 构建查询上下文
context = {
"keywords": keywords,
"query_intent": "conversation_response"
}
context = {"keywords": keywords, "query_intent": "conversation_response"}
# 查询记忆
memories = await memory_system.retrieve_relevant_memories(
query_text=query_text,
user_id="global", # 使用全局作用域
context=context,
limit=5
limit=5,
)
# 转换为 MemoryResult 格式
@@ -191,7 +187,7 @@ class MemoryActivator:
importance=memory.metadata.importance.value,
timestamp=memory.metadata.created_at,
source="unified_memory",
relevance_score=memory.metadata.relevance_score
relevance_score=memory.metadata.relevance_score,
)
memory_results.append(result)
@@ -214,16 +210,10 @@ class MemoryActivator:
if not memory_system or memory_system.status.value != "ready":
return None
context = {
"query_intent": "instant_response",
"chat_id": chat_id
}
context = {"query_intent": "instant_response", "chat_id": chat_id}
memories = await memory_system.retrieve_relevant_memories(
query_text=target_message,
user_id="global",
context=context,
limit=1
query_text=target_message, user_id="global", context=context, limit=1
)
if memories:
@@ -246,4 +236,4 @@ class MemoryActivator:
memory_activator = MemoryActivator()
init_prompt()
init_prompt()
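A hypothetical call pattern for the retrieval API used in both activator variants; `memory_system` and the result fields are assumptions drawn only from this diff:

```python
# Hypothetical usage of retrieve_relevant_memories, mirroring MemoryActivator.
import asyncio

async def demo(memory_system) -> None:
    # Same query-context shape the activator builds
    context = {"keywords": ["咖啡"], "query_intent": "conversation_response"}
    memories = await memory_system.retrieve_relevant_memories(
        query_text="用户平时喝什么?",
        user_id="global",  # global scope, as in the activator above
        context=context,
        limit=5,
    )
    for memory in memories:
        print(memory.memory_type, memory.metadata.relevance_score)

# asyncio.run(demo(memory_system))  # memory_system: an initialized system (assumed)
```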

View File

@@ -33,7 +33,7 @@ import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Union, Type
from typing import Any, Dict, List, Optional, Union, Type
import orjson
@@ -53,14 +53,15 @@ logger = get_logger(__name__)
class ExtractionStrategy(Enum):
"""提取策略"""
LLM_BASED = "llm_based" # 基于LLM的智能提取
RULE_BASED = "rule_based" # 基于规则的提取
HYBRID = "hybrid" # 混合策略
LLM_BASED = "llm_based" # 基于LLM的智能提取
RULE_BASED = "rule_based" # 基于规则的提取
HYBRID = "hybrid" # 混合策略
@dataclass
class ExtractionResult:
"""提取结果"""
memories: List[MemoryChunk]
confidence_scores: List[float]
extraction_time: float
@@ -80,15 +81,11 @@ class MemoryBuilder:
"total_extractions": 0,
"successful_extractions": 0,
"failed_extractions": 0,
"average_confidence": 0.0
"average_confidence": 0.0,
}
async def build_memories(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: float
self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: float
) -> List[MemoryChunk]:
"""从对话中构建记忆"""
start_time = time.time()
@@ -119,19 +116,13 @@ class MemoryBuilder:
raise
async def _extract_with_llm(
self,
text: str,
context: Dict[str, Any],
user_id: str,
timestamp: float
self, text: str, context: Dict[str, Any], user_id: str, timestamp: float
) -> List[MemoryChunk]:
"""使用LLM提取记忆"""
try:
prompt = self._build_llm_extraction_prompt(text, context)
response, _ = await self.llm_model.generate_response_async(
prompt, temperature=0.3
)
response, _ = await self.llm_model.generate_response_async(prompt, temperature=0.3)
# 解析LLM响应
memories = self._parse_llm_response(response, user_id, timestamp, context)
@@ -342,16 +333,12 @@ class MemoryBuilder:
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start:end + 1].strip()
return stripped[start : end + 1].strip()
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
def _parse_llm_response(
self,
response: str,
user_id: str,
timestamp: float,
context: Dict[str, Any]
self, response: str, user_id: str, timestamp: float, context: Dict[str, Any]
) -> List[MemoryChunk]:
"""解析LLM响应"""
if not response:
@@ -366,9 +353,7 @@ class MemoryBuilder:
data = orjson.loads(json_payload)
except Exception as e:
preview = json_payload[:200]
raise MemoryExtractionError(
f"LLM响应JSON解析失败: {e}, 片段: {preview}"
) from e
raise MemoryExtractionError(f"LLM响应JSON解析失败: {e}, 片段: {preview}") from e
memory_list = data.get("memories", [])
@@ -406,17 +391,15 @@ class MemoryBuilder:
try:
# 检查是否包含模糊代称
display_text = mem_data.get("display", "")
if any(ambiguous_term in display_text for ambiguous_term in ["用户", "user", "the user", "对方", "对手"]):
if any(
ambiguous_term in display_text for ambiguous_term in ["用户", "user", "the user", "对方", "对手"]
):
logger.debug(f"拒绝构建包含模糊代称的记忆display字段: {display_text}")
continue
subject_value = mem_data.get("subject")
normalized_subject = self._normalize_subjects(
subject_value,
bot_identifiers,
system_identifiers,
default_subjects,
bot_display
subject_value, bot_identifiers, system_identifiers, default_subjects, bot_display
)
if not normalized_subject:
@@ -425,17 +408,11 @@ class MemoryBuilder:
# 创建记忆块
importance_level = self._parse_enum_value(
ImportanceLevel,
mem_data.get("importance"),
ImportanceLevel.NORMAL,
"importance"
ImportanceLevel, mem_data.get("importance"), ImportanceLevel.NORMAL, "importance"
)
confidence_level = self._parse_enum_value(
ConfidenceLevel,
mem_data.get("confidence"),
ConfidenceLevel.MEDIUM,
"confidence"
ConfidenceLevel, mem_data.get("confidence"), ConfidenceLevel.MEDIUM, "confidence"
)
predicate_value = mem_data.get("predicate", "")
@@ -457,7 +434,7 @@ class MemoryBuilder:
source_context=mem_data.get("reasoning", ""),
importance=importance_level,
confidence=confidence_level,
display=display_text
display=display_text,
)
if used_fallback_display:
@@ -483,13 +460,7 @@ class MemoryBuilder:
return memories
def _parse_enum_value(
self,
enum_cls: Type[Enum],
raw_value: Any,
default: Enum,
field_name: str
) -> Enum:
def _parse_enum_value(self, enum_cls: Type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum:
"""解析枚举值,兼容数字/字符串表示"""
if isinstance(raw_value, enum_cls):
return raw_value
@@ -533,12 +504,14 @@ class MemoryBuilder:
try:
return enum_cls(raw_value)
except Exception:
logger.debug("%s=%s 类型 %s 无法解析为 %s,使用默认值 %s",
field_name,
raw_value,
type(raw_value).__name__,
enum_cls.__name__,
default.name)
logger.debug(
"%s=%s 类型 %s 无法解析为 %s,使用默认值 %s",
field_name,
raw_value,
type(raw_value).__name__,
enum_cls.__name__,
default.name,
)
return default
def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
@@ -606,7 +579,7 @@ class MemoryBuilder:
"members",
"member_names",
"mention_users",
"audiences"
"audiences",
]
for key in candidate_keys:
@@ -727,7 +700,7 @@ class MemoryBuilder:
bot_identifiers: set[str],
system_identifiers: set[str],
default_subjects: List[str],
bot_display: Optional[str] = None
bot_display: Optional[str] = None,
) -> List[str]:
defaults = default_subjects or ["对话参与者"]
@@ -800,7 +773,9 @@ class MemoryBuilder:
return obj.strip() or None
return None
def _compose_display_text(self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]) -> str:
def _compose_display_text(
self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]
) -> str:
subject_phrase = "".join(subjects) if subjects else "对话参与者"
predicate = (predicate or "").strip()
@@ -866,11 +841,7 @@ class MemoryBuilder:
return f"{subject_phrase}{predicate}".strip()
return subject_phrase
def _validate_and_enhance_memories(
self,
memories: List[MemoryChunk],
context: Dict[str, Any]
) -> List[MemoryChunk]:
def _validate_and_enhance_memories(self, memories: List[MemoryChunk], context: Dict[str, Any]) -> List[MemoryChunk]:
"""验证和增强记忆"""
validated_memories = []
@@ -905,11 +876,7 @@ class MemoryBuilder:
return True
def _enhance_memory(
self,
memory: MemoryChunk,
context: Dict[str, Any]
) -> MemoryChunk:
def _enhance_memory(self, memory: MemoryChunk, context: Dict[str, Any]) -> MemoryChunk:
"""增强记忆块"""
# 时间规范化处理
self._normalize_time_in_memory(memory)
@@ -919,7 +886,7 @@ class MemoryBuilder:
memory.temporal_context = {
"timestamp": memory.metadata.created_at,
"timezone": context.get("timezone", "UTC"),
"day_of_week": datetime.fromtimestamp(memory.metadata.created_at).strftime("%A")
"day_of_week": datetime.fromtimestamp(memory.metadata.created_at).strftime("%A"),
}
# 添加情感上下文(如果有)
@@ -941,22 +908,22 @@ class MemoryBuilder:
# 定义相对时间映射
relative_time_patterns = {
r'今天|今日': current_time.strftime('%Y-%m-%d'),
r'昨天|昨日': (current_time - timedelta(days=1)).strftime('%Y-%m-%d'),
r'明天|明日': (current_time + timedelta(days=1)).strftime('%Y-%m-%d'),
r'后天': (current_time + timedelta(days=2)).strftime('%Y-%m-%d'),
r'大后天': (current_time + timedelta(days=3)).strftime('%Y-%m-%d'),
r'前天': (current_time - timedelta(days=2)).strftime('%Y-%m-%d'),
r'大前天': (current_time - timedelta(days=3)).strftime('%Y-%m-%d'),
r'本周|这周|这星期': current_time.strftime('%Y-%m-%d'),
r'上周|上星期': (current_time - timedelta(weeks=1)).strftime('%Y-%m-%d'),
r'下周|下星期': (current_time + timedelta(weeks=1)).strftime('%Y-%m-%d'),
r'本月|这个月': current_time.strftime('%Y-%m-01'),
r'上月|上个月': (current_time.replace(day=1) - timedelta(days=1)).strftime('%Y-%m-01'),
r'下月|下个月': (current_time.replace(day=1) + timedelta(days=32)).replace(day=1).strftime('%Y-%m-01'),
r'今年|今年': current_time.strftime('%Y'),
r'去年|上一年': str(current_time.year - 1),
r'明年|下一年': str(current_time.year + 1),
r"今天|今日": current_time.strftime("%Y-%m-%d"),
r"昨天|昨日": (current_time - timedelta(days=1)).strftime("%Y-%m-%d"),
r"明天|明日": (current_time + timedelta(days=1)).strftime("%Y-%m-%d"),
r"后天": (current_time + timedelta(days=2)).strftime("%Y-%m-%d"),
r"大后天": (current_time + timedelta(days=3)).strftime("%Y-%m-%d"),
r"前天": (current_time - timedelta(days=2)).strftime("%Y-%m-%d"),
r"大前天": (current_time - timedelta(days=3)).strftime("%Y-%m-%d"),
r"本周|这周|这星期": current_time.strftime("%Y-%m-%d"),
r"上周|上星期": (current_time - timedelta(weeks=1)).strftime("%Y-%m-%d"),
r"下周|下星期": (current_time + timedelta(weeks=1)).strftime("%Y-%m-%d"),
r"本月|这个月": current_time.strftime("%Y-%m-01"),
r"上月|上个月": (current_time.replace(day=1) - timedelta(days=1)).strftime("%Y-%m-01"),
r"下月|下个月": (current_time.replace(day=1) + timedelta(days=32)).replace(day=1).strftime("%Y-%m-01"),
r"今年|今年": current_time.strftime("%Y"),
r"去年|上一年": str(current_time.year - 1),
r"明年|下一年": str(current_time.year + 1),
}
def _normalize_value(value):
@@ -1009,10 +976,14 @@ class MemoryBuilder:
# 更新平均置信度
if self.extraction_stats["successful_extractions"] > 0:
total_confidence = self.extraction_stats["average_confidence"] * (self.extraction_stats["successful_extractions"] - success_count)
total_confidence = self.extraction_stats["average_confidence"] * (
self.extraction_stats["successful_extractions"] - success_count
)
# 假设新记忆的平均置信度为0.8
total_confidence += 0.8 * success_count
self.extraction_stats["average_confidence"] = total_confidence / self.extraction_stats["successful_extractions"]
self.extraction_stats["average_confidence"] = (
total_confidence / self.extraction_stats["successful_extractions"]
)
def get_extraction_stats(self) -> Dict[str, Any]:
"""获取提取统计信息"""
@@ -1024,5 +995,5 @@ class MemoryBuilder:
"total_extractions": 0,
"successful_extractions": 0,
"failed_extractions": 0,
"average_confidence": 0.0
}
"average_confidence": 0.0,
}
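The relative-time table in `_normalize_time_in_memory` maps regex alternations to concrete dates. A condensed, runnable sketch of the same idea (subset of patterns only):

```python
# Sketch of relative-time normalization via sequential re.sub over a pattern table.
import re
from datetime import datetime, timedelta
from typing import Optional

def normalize_relative_time(text: str, now: Optional[datetime] = None) -> str:
    # Map relative Chinese date words to concrete dates, mirroring the table above
    now = now or datetime.now()
    patterns = {
        r"今天|今日": now.strftime("%Y-%m-%d"),
        r"昨天|昨日": (now - timedelta(days=1)).strftime("%Y-%m-%d"),
        r"明天|明日": (now + timedelta(days=1)).strftime("%Y-%m-%d"),
    }
    for pattern, replacement in patterns.items():
        text = re.sub(pattern, replacement, text)
    return text

print(normalize_relative_time("我们明天见", datetime(2025, 10, 2)))  # 我们2025-10-03见
```

One caveat if the table is applied sequentially like this: longer alternations such as 大前天/大后天 must run before their substrings 前天/后天, otherwise the shorter pattern fires first inside the longer phrase.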

View File

@@ -8,8 +8,7 @@ import time
import uuid
import orjson
from typing import Dict, List, Optional, Any, Union, Iterable
from dataclasses import dataclass, field, asdict
from datetime import datetime
from dataclasses import dataclass, field
from enum import Enum
import hashlib
@@ -21,33 +20,36 @@ logger = get_logger(__name__)
class MemoryType(Enum):
"""记忆类型分类"""
PERSONAL_FACT = "personal_fact" # 个人事实(姓名、职业、住址等)
EVENT = "event" # 事件(重要经历、约会等)
PREFERENCE = "preference" # 偏好(喜好、习惯等)
OPINION = "opinion" # 观点(对事物的看法
RELATIONSHIP = "relationship" # 关系(与他人的关系
EMOTION = "emotion" # 情感状态
KNOWLEDGE = "knowledge" # 知识信息
SKILL = "skill" # 技能能力
GOAL = "goal" # 目标计划
EXPERIENCE = "experience" # 经验教训
CONTEXTUAL = "contextual" # 上下文信息
PERSONAL_FACT = "personal_fact" # 个人事实(姓名、职业、住址等)
EVENT = "event" # 事件(重要经历、约会等)
PREFERENCE = "preference" # 偏好(喜好、习惯等
OPINION = "opinion" # 观点(对事物的看法
RELATIONSHIP = "relationship" # 关系(与他人的关系)
EMOTION = "emotion" # 情感状态
KNOWLEDGE = "knowledge" # 知识信息
SKILL = "skill" # 技能能力
GOAL = "goal" # 目标计划
EXPERIENCE = "experience" # 经验教训
CONTEXTUAL = "contextual" # 上下文信息
class ConfidenceLevel(Enum):
"""置信度等级"""
LOW = 1 # 低置信度,可能不准确
MEDIUM = 2 # 中等置信度,有一定依据
HIGH = 3 # 高置信度,有明确来源
VERIFIED = 4 # 已验证,非常可靠
LOW = 1 # 低置信度,可能不准确
MEDIUM = 2 # 中等置信度,有一定依据
HIGH = 3 # 高置信度,有明确来源
VERIFIED = 4 # 已验证,非常可靠
class ImportanceLevel(Enum):
"""重要性等级"""
LOW = 1 # 低重要性,普通信息
NORMAL = 2 # 一般重要性,日常信息
HIGH = 3 # 高重要性,重要信息
CRITICAL = 4 # 关键重要性,核心信息
LOW = 1 # 低重要性,普通信息
NORMAL = 2 # 一般重要性,日常信息
HIGH = 3 # 高重要性,重要信息
CRITICAL = 4 # 关键重要性,核心信息
@dataclass
@@ -61,12 +63,7 @@ class ContentStructure:
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"subject": self.subject,
"predicate": self.predicate,
"object": self.object,
"display": self.display
}
return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure":
@@ -75,7 +72,7 @@ class ContentStructure:
subject=data.get("subject", ""),
predicate=data.get("predicate", ""),
object=data.get("object", ""),
display=data.get("display", "")
display=data.get("display", ""),
)
def to_subject_list(self) -> List[str]:
@@ -98,24 +95,25 @@ class ContentStructure:
@dataclass
class MemoryMetadata:
"""记忆元数据 - 简化版本"""
# 基础信息
memory_id: str # 唯一标识符
user_id: str # 用户ID
chat_id: Optional[str] = None # 聊天ID(群聊或私聊)
memory_id: str # 唯一标识符
user_id: str # 用户ID
chat_id: Optional[str] = None # 聊天ID(群聊或私聊)
# 时间信息
created_at: float = 0.0 # 创建时间戳
last_accessed: float = 0.0 # 最后访问时间
last_modified: float = 0.0 # 最后修改时间
created_at: float = 0.0 # 创建时间戳
last_accessed: float = 0.0 # 最后访问时间
last_modified: float = 0.0 # 最后修改时间
# 激活频率管理
last_activation_time: float = 0.0 # 最后激活时间
activation_frequency: int = 0 # 激活频率(单位时间内的激活次数)
total_activations: int = 0 # 总激活次数
last_activation_time: float = 0.0 # 最后激活时间
activation_frequency: int = 0 # 激活频率(单位时间内的激活次数)
total_activations: int = 0 # 总激活次数
# 统计信息
access_count: int = 0 # 访问次数
relevance_score: float = 0.0 # 相关度评分
access_count: int = 0 # 访问次数
relevance_score: float = 0.0 # 相关度评分
# 信心和重要性(核心字段)
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM
@@ -123,10 +121,10 @@ class MemoryMetadata:
# 遗忘机制相关
forgetting_threshold: float = 0.0 # 遗忘阈值(动态计算)
last_forgetting_check: float = 0.0 # 上次遗忘检查时间
last_forgetting_check: float = 0.0 # 上次遗忘检查时间
# 来源信息
source_context: Optional[str] = None # 来源上下文片段
source_context: Optional[str] = None # 来源上下文片段
# 兼容旧字段: 一些代码或旧版本可能直接访问 metadata.source
source: Optional[str] = None
@@ -153,13 +151,13 @@ class MemoryMetadata:
self.last_forgetting_check = current_time
# 兼容性:如果旧字段 source 被使用,保证 source 与 source_context 同步
if not getattr(self, 'source', None) and getattr(self, 'source_context', None):
if not getattr(self, "source", None) and getattr(self, "source_context", None):
try:
self.source = str(self.source_context)
except Exception:
self.source = None
# 如果有 source 字段但 source_context 为空,也同步回去
if not getattr(self, 'source_context', None) and getattr(self, 'source', None):
if not getattr(self, "source_context", None) and getattr(self, "source", None):
try:
self.source_context = str(self.source)
except Exception:
@@ -177,7 +175,6 @@ class MemoryMetadata:
def _update_activation_frequency(self, current_time: float):
"""更新激活频率24小时内的激活次数"""
from datetime import datetime, timedelta
# 如果超过24小时重置激活频率
if current_time - self.last_activation_time > 86400: # 24小时 = 86400秒
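
A minimal standalone sketch of this 24-hour rolling window (field names mirror the metadata fields above; the reset-then-increment order is an assumption consistent with the comment, since the rest of the method body is elided by the diff):

    import time

    DAY_SECONDS = 86400  # 24h = 86400s, matching the comment above

    def update_activation_frequency(last_activation_time: float, activation_frequency: int) -> int:
        # If more than 24h passed since the last activation, the window has
        # expired and the per-window counter restarts before counting this hit.
        if time.time() - last_activation_time > DAY_SECONDS:
            activation_frequency = 0
        return activation_frequency + 1

    print(update_activation_frequency(0.0, 5))  # stale window -> resets to 1
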
@@ -251,7 +248,7 @@ class MemoryMetadata:
"importance": self.importance.value,
"forgetting_threshold": self.forgetting_threshold,
"last_forgetting_check": self.last_forgetting_check,
"source_context": self.source_context
"source_context": self.source_context,
}
@classmethod
@@ -273,7 +270,7 @@ class MemoryMetadata:
importance=ImportanceLevel(data.get("importance", ImportanceLevel.NORMAL.value)),
forgetting_threshold=data.get("forgetting_threshold", 0.0),
last_forgetting_check=data.get("last_forgetting_check", 0),
source_context=data.get("source_context")
source_context=data.get("source_context"),
)
@@ -285,21 +282,21 @@ class MemoryChunk:
metadata: MemoryMetadata
# 内容结构
content: ContentStructure # 主谓宾结构
memory_type: MemoryType # 记忆类型
content: ContentStructure # 主谓宾结构
memory_type: MemoryType # 记忆类型
# 扩展信息
keywords: List[str] = field(default_factory=list) # 关键词列表
tags: List[str] = field(default_factory=list) # 标签列表
categories: List[str] = field(default_factory=list) # 分类列表
keywords: List[str] = field(default_factory=list) # 关键词列表
tags: List[str] = field(default_factory=list) # 标签列表
categories: List[str] = field(default_factory=list) # 分类列表
# 语义信息
embedding: Optional[List[float]] = None # 语义向量
semantic_hash: Optional[str] = None # 语义哈希值
embedding: Optional[List[float]] = None # 语义向量
semantic_hash: Optional[str] = None # 语义哈希值
# 关联信息
related_memories: List[str] = field(default_factory=list) # 关联记忆ID列表
temporal_context: Optional[Dict[str, Any]] = None # 时间上下文
temporal_context: Optional[Dict[str, Any]] = None # 时间上下文
def __post_init__(self):
"""后初始化处理"""
@@ -317,7 +314,7 @@ class MemoryChunk:
embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding]))
hash_input = f"{content_str}|{embedding_str}"
hash_object = hashlib.sha256(hash_input.encode('utf-8'))
hash_object = hashlib.sha256(hash_input.encode("utf-8"))
self.semantic_hash = hash_object.hexdigest()[:16]
except Exception as e:
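
For reference, the hashing recipe in this hunk boils down to a few lines; a standalone sketch using the same ingredients shown above (components rounded to 6 decimals, SHA-256, first 16 hex characters kept):

    import hashlib

    def semantic_hash(content_str: str, embedding: list[float]) -> str:
        # Rounding keeps numerically near-identical vectors hashing to the same key
        embedding_str = ",".join(str(round(x, 6)) for x in embedding)
        digest = hashlib.sha256(f"{content_str}|{embedding_str}".encode("utf-8")).hexdigest()
        return digest[:16]  # a 64-bit prefix is plenty for a dedup key

    print(semantic_hash("user|likes|coffee", [0.1234567, -0.25]))
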
@@ -430,7 +427,7 @@ class MemoryChunk:
"embedding": self.embedding,
"semantic_hash": self.semantic_hash,
"related_memories": self.related_memories,
"temporal_context": self.temporal_context
"temporal_context": self.temporal_context,
}
@classmethod
@@ -449,14 +446,14 @@ class MemoryChunk:
embedding=data.get("embedding"),
semantic_hash=data.get("semantic_hash"),
related_memories=data.get("related_memories", []),
temporal_context=data.get("temporal_context")
temporal_context=data.get("temporal_context"),
)
return chunk
def to_json(self) -> str:
"""转换为JSON字符串"""
return orjson.dumps(self.to_dict(), ensure_ascii=False).decode('utf-8')
return orjson.dumps(self.to_dict(), ensure_ascii=False).decode("utf-8")
@classmethod
def from_json(cls, json_str: str) -> "MemoryChunk":
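
One substantive issue survives the reformatting here: `orjson.dumps()` takes no `ensure_ascii` argument (it always returns UTF-8 bytes without ASCII-escaping), so both the old and new forms of `to_json` would raise `TypeError` when called. A sketch of the round-trip as orjson actually exposes it:

    import orjson

    data = {"subject": "用户", "predicate": "likes", "object": "coffee"}
    # orjson returns bytes and never ASCII-escapes, hence no ensure_ascii flag
    json_str = orjson.dumps(data).decode("utf-8")
    assert orjson.loads(json_str) == data
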
@@ -530,7 +527,7 @@ class MemoryChunk:
MemoryType.SKILL: "🛠️",
MemoryType.GOAL: "🎯",
MemoryType.EXPERIENCE: "💡",
MemoryType.CONTEXTUAL: "📝"
MemoryType.CONTEXTUAL: "📝",
}
emoji = type_emoji.get(self.memory_type, "📝")
@@ -581,7 +578,7 @@ def create_memory_chunk(
importance: ImportanceLevel = ImportanceLevel.NORMAL,
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
display: Optional[str] = None,
**kwargs
**kwargs,
) -> MemoryChunk:
"""便捷的内存块创建函数"""
metadata = MemoryMetadata(
@@ -593,7 +590,7 @@ def create_memory_chunk(
last_modified=0,
confidence=confidence,
importance=importance,
source_context=source_context
source_context=source_context,
)
subjects: List[str]
@@ -607,18 +604,8 @@ def create_memory_chunk(
display_text = display or _build_display_text(subjects, predicate, obj)
content = ContentStructure(
subject=subject_payload,
predicate=predicate,
object=obj,
display=display_text
)
content = ContentStructure(subject=subject_payload, predicate=predicate, object=obj, display=display_text)
chunk = MemoryChunk(
metadata=metadata,
content=content,
memory_type=memory_type,
**kwargs
)
chunk = MemoryChunk(metadata=metadata, content=content, memory_type=memory_type, **kwargs)
return chunk
return chunk

View File

@@ -6,8 +6,8 @@
import time
import asyncio
from typing import List, Dict, Optional, Set, Tuple
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Tuple
from datetime import datetime
from dataclasses import dataclass
from src.common.logger import get_logger
@@ -19,6 +19,7 @@ logger = get_logger(__name__)
@dataclass
class ForgettingStats:
"""遗忘统计信息"""
total_checked: int = 0
marked_for_forgetting: int = 0
actually_forgotten: int = 0
@@ -30,34 +31,35 @@ class ForgettingStats:
@dataclass
class ForgettingConfig:
"""遗忘引擎配置"""
# 检查频率配置
check_interval_hours: int = 24 # 定期检查间隔(小时)
batch_size: int = 100 # 批处理大小
check_interval_hours: int = 24 # 定期检查间隔(小时)
batch_size: int = 100 # 批处理大小
# 遗忘阈值配置
base_forgetting_days: float = 30.0 # 基础遗忘天数
min_forgetting_days: float = 7.0 # 最小遗忘天数
max_forgetting_days: float = 365.0 # 最大遗忘天数
base_forgetting_days: float = 30.0 # 基础遗忘天数
min_forgetting_days: float = 7.0 # 最小遗忘天数
max_forgetting_days: float = 365.0 # 最大遗忘天数
# 重要程度权重
critical_importance_bonus: float = 45.0 # 关键重要性额外天数
high_importance_bonus: float = 30.0 # 高重要性额外天数
normal_importance_bonus: float = 15.0 # 一般重要性额外天数
low_importance_bonus: float = 0.0 # 低重要性额外天数
high_importance_bonus: float = 30.0 # 高重要性额外天数
normal_importance_bonus: float = 15.0 # 一般重要性额外天数
low_importance_bonus: float = 0.0 # 低重要性额外天数
# 置信度权重
verified_confidence_bonus: float = 30.0 # 已验证置信度额外天数
high_confidence_bonus: float = 20.0 # 高置信度额外天数
medium_confidence_bonus: float = 10.0 # 中等置信度额外天数
low_confidence_bonus: float = 0.0 # 低置信度额外天数
high_confidence_bonus: float = 20.0 # 高置信度额外天数
medium_confidence_bonus: float = 10.0 # 中等置信度额外天数
low_confidence_bonus: float = 0.0 # 低置信度额外天数
# 激活频率权重
activation_frequency_weight: float = 0.5 # 每次激活增加的天数权重
max_frequency_bonus: float = 10.0 # 最大激活频率奖励天数
max_frequency_bonus: float = 10.0 # 最大激活频率奖励天数
# 休眠配置
dormant_threshold_days: int = 90 # 休眠状态判定天数
force_forget_dormant_days: int = 180 # 强制遗忘休眠记忆的天数
dormant_threshold_days: int = 90 # 休眠状态判定天数
force_forget_dormant_days: int = 180 # 强制遗忘休眠记忆的天数
class MemoryForgettingEngine:
@@ -107,13 +109,12 @@ class MemoryForgettingEngine:
# 激活频率权重
frequency_bonus = min(
memory.metadata.activation_frequency * self.config.activation_frequency_weight,
self.config.max_frequency_bonus
self.config.max_frequency_bonus,
)
threshold += frequency_bonus
# 确保在合理范围内
return max(self.config.min_forgetting_days,
min(threshold, self.config.max_forgetting_days))
return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days))
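
Read together with the ForgettingConfig defaults above, the dynamic threshold is purely additive and then clamped; a minimal sketch using those default numbers:

    def forgetting_threshold_days(
        importance_bonus: float = 15.0,    # NORMAL tier
        confidence_bonus: float = 10.0,    # MEDIUM tier
        activation_frequency: int = 0,
        base: float = 30.0,
        frequency_weight: float = 0.5,
        max_frequency_bonus: float = 10.0,
        min_days: float = 7.0,
        max_days: float = 365.0,
    ) -> float:
        threshold = base + importance_bonus + confidence_bonus
        threshold += min(activation_frequency * frequency_weight, max_frequency_bonus)
        return max(min_days, min(threshold, max_days))

    # NORMAL importance, MEDIUM confidence, 8 activations: 30 + 15 + 10 + 4 = 59 days
    print(forgetting_threshold_days(activation_frequency=8))
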
def should_forget_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
"""
@@ -265,8 +266,8 @@ class MemoryForgettingEngine:
"actually_forgotten": self.stats.actually_forgotten,
"dormant_memories": self.stats.dormant_memories,
"check_duration": self.stats.check_duration,
"last_check_time": self.stats.last_check_time
}
"last_check_time": self.stats.last_check_time,
},
}
def is_forgetting_check_needed(self) -> bool:
@@ -302,7 +303,9 @@ class MemoryForgettingEngine:
# 如果启用自动清理,执行实际的遗忘操作
if enable_auto_cleanup and (result["normal_forgetting"] or result["force_forgetting"]):
logger.info(f"检测到 {len(result['normal_forgetting'])} 条普通遗忘和 {len(result['force_forgetting'])} 条强制遗忘记忆")
logger.info(
f"检测到 {len(result['normal_forgetting'])} 条普通遗忘和 {len(result['force_forgetting'])} 条强制遗忘记忆"
)
# 这里可以调用实际的删除逻辑
# await self.cleanup_forgotten_memories(result["normal_forgetting"] + result["force_forgetting"])
@@ -318,14 +321,16 @@ class MemoryForgettingEngine:
"marked_for_forgetting": self.stats.marked_for_forgetting,
"actually_forgotten": self.stats.actually_forgotten,
"dormant_memories": self.stats.dormant_memories,
"last_check_time": datetime.fromtimestamp(self.stats.last_check_time).isoformat() if self.stats.last_check_time else None,
"last_check_time": datetime.fromtimestamp(self.stats.last_check_time).isoformat()
if self.stats.last_check_time
else None,
"last_check_duration": self.stats.check_duration,
"config": {
"check_interval_hours": self.config.check_interval_hours,
"base_forgetting_days": self.config.base_forgetting_days,
"min_forgetting_days": self.config.min_forgetting_days,
"max_forgetting_days": self.config.max_forgetting_days
}
"max_forgetting_days": self.config.max_forgetting_days,
},
}
def reset_stats(self):
@@ -349,4 +354,4 @@ memory_forgetting_engine = MemoryForgettingEngine()
def get_memory_forgetting_engine() -> MemoryForgettingEngine:
"""获取全局遗忘引擎实例"""
return memory_forgetting_engine
return memory_forgetting_engine

View File

@@ -5,14 +5,12 @@
"""
import time
from typing import Dict, List, Optional, Tuple, Set, Any
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import (
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
)
from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel
logger = get_logger(__name__)
@@ -20,6 +18,7 @@ logger = get_logger(__name__)
@dataclass
class FusionResult:
"""融合结果"""
original_count: int
fused_count: int
removed_duplicates: int
@@ -31,6 +30,7 @@ class FusionResult:
@dataclass
class DuplicateGroup:
"""重复记忆组"""
group_id: str
memories: List[MemoryChunk]
similarity_matrix: List[List[float]]
@@ -46,22 +46,20 @@ class MemoryFusionEngine:
"total_fusions": 0,
"memories_fused": 0,
"duplicates_removed": 0,
"average_similarity": 0.0
"average_similarity": 0.0,
}
# 融合策略配置
self.fusion_strategies = {
"semantic_similarity": True, # 语义相似性融合
"temporal_proximity": True, # 时间接近性融合
"logical_consistency": True, # 逻辑一致性融合
"confidence_boosting": True, # 置信度提升
"importance_preservation": True # 重要性保持
"semantic_similarity": True, # 语义相似性融合
"temporal_proximity": True, # 时间接近性融合
"logical_consistency": True, # 逻辑一致性融合
"confidence_boosting": True, # 置信度提升
"importance_preservation": True, # 重要性保持
}
async def fuse_memories(
self,
new_memories: List[MemoryChunk],
existing_memories: Optional[List[MemoryChunk]] = None
self, new_memories: List[MemoryChunk], existing_memories: Optional[List[MemoryChunk]] = None
) -> List[MemoryChunk]:
"""融合记忆列表"""
start_time = time.time()
@@ -73,9 +71,7 @@ class MemoryFusionEngine:
logger.info(f"开始记忆融合,新记忆: {len(new_memories)},现有记忆: {len(existing_memories or [])}")
# 1. 检测重复记忆组
duplicate_groups = await self._detect_duplicate_groups(
new_memories, existing_memories or []
)
duplicate_groups = await self._detect_duplicate_groups(new_memories, existing_memories or [])
if not duplicate_groups:
fusion_time = time.time() - start_time
@@ -110,9 +106,7 @@ class MemoryFusionEngine:
return new_memories # 失败时返回原始记忆
async def _detect_duplicate_groups(
self,
new_memories: List[MemoryChunk],
existing_memories: List[MemoryChunk]
self, new_memories: List[MemoryChunk], existing_memories: List[MemoryChunk]
) -> List[DuplicateGroup]:
"""检测重复记忆组"""
all_memories = new_memories + existing_memories
@@ -125,16 +119,12 @@ class MemoryFusionEngine:
continue
# 创建新的重复组
group = DuplicateGroup(
group_id=f"group_{len(groups)}",
memories=[memory1],
similarity_matrix=[[1.0]]
)
group = DuplicateGroup(group_id=f"group_{len(groups)}", memories=[memory1], similarity_matrix=[[1.0]])
processed_ids.add(memory1.memory_id)
# 寻找相似记忆
for j, memory2 in enumerate(all_memories[i+1:], i+1):
for j, memory2 in enumerate(all_memories[i + 1 :], i + 1):
if memory2.memory_id in processed_ids:
continue
@@ -182,9 +172,7 @@ class MemoryFusionEngine:
# 5. 时间接近性
if self.fusion_strategies["temporal_proximity"]:
temporal_sim = self._calculate_temporal_similarity(
mem1.metadata.created_at, mem2.metadata.created_at
)
temporal_sim = self._calculate_temporal_similarity(mem1.metadata.created_at, mem2.metadata.created_at)
similarity_scores.append(("temporal", temporal_sim))
# 6. 逻辑一致性
@@ -193,14 +181,7 @@ class MemoryFusionEngine:
similarity_scores.append(("logical", logical_sim))
# 计算加权平均相似度
weights = {
"semantic": 0.35,
"text": 0.25,
"keyword": 0.15,
"type": 0.10,
"temporal": 0.10,
"logical": 0.05
}
weights = {"semantic": 0.35, "text": 0.25, "keyword": 0.15, "type": 0.10, "temporal": 0.10, "logical": 0.05}
weighted_sum = 0.0
total_weight = 0.0
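
The compacted weights dict feeds a weighted mean over whichever similarity components were actually computed, so the weights effectively re-normalize when a component is missing; a minimal sketch of that aggregation:

    WEIGHTS = {"semantic": 0.35, "text": 0.25, "keyword": 0.15, "type": 0.10, "temporal": 0.10, "logical": 0.05}

    def combined_similarity(scores: list[tuple[str, float]]) -> float:
        weighted_sum = sum(WEIGHTS.get(name, 0.0) * s for name, s in scores)
        total_weight = sum(WEIGHTS.get(name, 0.0) for name, _ in scores)
        return weighted_sum / total_weight if total_weight > 0 else 0.0

    # Only three components available -> weights 0.35/0.25/0.10 are re-normalized
    print(round(combined_similarity([("semantic", 0.9), ("text", 0.7), ("type", 1.0)]), 3))
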
@@ -276,9 +257,7 @@ class MemoryFusionEngine:
# 宾语相似性
if isinstance(mem1.content.object, str) and isinstance(mem2.content.object, str):
object_sim = self._calculate_text_similarity(
str(mem1.content.object), str(mem2.content.object)
)
object_sim = self._calculate_text_similarity(str(mem1.content.object), str(mem2.content.object))
consistency_score += object_sim * 0.3
return consistency_score
@@ -349,11 +328,7 @@ class MemoryFusionEngine:
# 返回置信度最高的记忆
return max(group.memories, key=lambda m: m.metadata.confidence.value)
async def _merge_memory_attributes(
self,
base_memory: MemoryChunk,
memories: List[MemoryChunk]
) -> MemoryChunk:
async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: List[MemoryChunk]) -> MemoryChunk:
"""合并记忆属性"""
# 创建基础记忆的深拷贝
fused_memory = MemoryChunk.from_dict(base_memory.to_dict())
@@ -436,7 +411,7 @@ class MemoryFusionEngine:
"earliest_timestamp": earliest_time,
"latest_timestamp": latest_time,
"time_span_hours": (latest_time - earliest_time) / 3600,
"source_memories": len(memories)
"source_memories": len(memories),
}
# 合并其他上下文信息
@@ -451,9 +426,7 @@ class MemoryFusionEngine:
return merged_context
async def incremental_fusion(
self,
new_memory: MemoryChunk,
existing_memories: List[MemoryChunk]
self, new_memory: MemoryChunk, existing_memories: List[MemoryChunk]
) -> Tuple[MemoryChunk, List[MemoryChunk]]:
"""增量融合(单个新记忆与现有记忆融合)"""
# 寻找相似记忆
@@ -478,7 +451,7 @@ class MemoryFusionEngine:
group = DuplicateGroup(
group_id=f"incremental_{int(time.time())}",
memories=[new_memory, best_match],
similarity_matrix=[[1.0, similarity], [similarity, 1.0]]
similarity_matrix=[[1.0, similarity], [similarity, 1.0]],
)
# 执行融合
@@ -530,5 +503,5 @@ class MemoryFusionEngine:
"total_fusions": 0,
"memories_fused": 0,
"duplicates_removed": 0,
"average_similarity": 0.0
}
"average_similarity": 0.0,
}

View File

@@ -9,12 +9,9 @@ from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_system import (
initialize_memory_system
)
from src.chat.memory_system.memory_system import initialize_memory_system
logger = get_logger(__name__)
@@ -22,6 +19,7 @@ logger = get_logger(__name__)
@dataclass
class MemoryResult:
"""记忆查询结果"""
content: str
memory_type: str
confidence: float
@@ -67,6 +65,7 @@ class MemoryManager:
# 获取LLM模型
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory")
# 初始化记忆系统
@@ -121,7 +120,7 @@ class MemoryManager:
max_memory_num: int = 3,
max_memory_length: int = 2,
time_weight: float = 1.0,
keyword_weight: float = 1.0
keyword_weight: float = 1.0,
) -> List[Tuple[str, str]]:
"""从文本获取相关记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:
@@ -131,14 +130,11 @@ class MemoryManager:
# 使用增强记忆系统检索
context = {
"chat_id": chat_id,
"expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE]
"expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE],
}
relevant_memories = await self.memory_system.retrieve_relevant_memories(
query=text,
user_id=user_id,
context=context,
limit=max_memory_num
query=text, user_id=user_id, context=context, limit=max_memory_num
)
# 转换为原有格式 (topic, content)
@@ -156,11 +152,7 @@ class MemoryManager:
return []
async def get_memory_from_topic(
self,
valid_keywords: List[str],
max_memory_num: int = 3,
max_memory_length: int = 2,
max_depth: int = 3
self, valid_keywords: List[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
) -> List[Tuple[str, str]]:
"""从关键词获取记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:
@@ -177,15 +169,15 @@ class MemoryManager:
MemoryType.PERSONAL_FACT,
MemoryType.EVENT,
MemoryType.PREFERENCE,
MemoryType.OPINION
]
MemoryType.OPINION,
],
}
relevant_memories = await self.memory_system.retrieve_relevant_memories(
query_text=query_text,
user_id="default_user", # 可以根据实际需要传递
context=context,
limit=max_memory_num
limit=max_memory_num,
)
# 转换为原有格式 (topic, content)
@@ -216,11 +208,7 @@ class MemoryManager:
return []
async def process_conversation(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: Optional[float] = None
) -> List[MemoryChunk]:
"""处理对话并构建记忆 - 新增功能"""
if not self.is_initialized or not self.memory_system:
@@ -247,11 +235,7 @@ class MemoryManager:
return []
async def get_enhanced_memory_context(
self,
query_text: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
limit: int = 5
self, query_text: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 5
) -> List[MemoryResult]:
"""获取增强记忆上下文 - 新增功能"""
if not self.is_initialized or not self.memory_system:
@@ -259,10 +243,7 @@ class MemoryManager:
try:
relevant_memories = await self.memory_system.retrieve_relevant_memories(
query=query_text,
user_id=None,
context=context or {},
limit=limit
query=query_text, user_id=None, context=context or {}, limit=limit
)
results = []
@@ -276,7 +257,7 @@ class MemoryManager:
timestamp=memory.metadata.created_at,
source="enhanced_memory",
relevance_score=memory.metadata.relevance_score,
structure=structure
structure=structure,
)
results.append(result)
@@ -342,7 +323,9 @@ class MemoryManager:
return None
return f"{subject}的职业是{profession}"
if predicate == "lives_in":
location = self._extract_from_object(obj_value, ["location", "city", "place"]) or self._format_object(obj_value)
location = self._extract_from_object(obj_value, ["location", "city", "place"]) or self._format_object(
obj_value
)
location = self._clean_text(location)
if not location:
return None
@@ -385,7 +368,9 @@ class MemoryManager:
return None
return f"{subject}最喜欢{favorite}"
if predicate == "mentioned_event":
event_text = self._extract_from_object(obj_value, ["event_text", "description"]) or self._format_object(obj_value)
event_text = self._extract_from_object(obj_value, ["event_text", "description"]) or self._format_object(
obj_value
)
event_text = self._clean_text(self._truncate(event_text))
if not event_text:
return None
@@ -494,4 +479,4 @@ class MemoryManager:
# 全局记忆管理器实例
memory_manager = MemoryManager()
memory_manager = MemoryManager()

View File

@@ -12,7 +12,6 @@ from dataclasses import dataclass, asdict
from datetime import datetime
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryType, ImportanceLevel, ConfidenceLevel
logger = get_logger(__name__)
@@ -20,22 +19,23 @@ logger = get_logger(__name__)
@dataclass
class MemoryMetadataIndexEntry:
"""元数据索引条目(轻量级,只用于快速过滤)"""
memory_id: str
user_id: str
# 分类信息
memory_type: str # MemoryType.value
subjects: List[str] # 主语列表
objects: List[str] # 宾语列表
keywords: List[str] # 关键词列表
tags: List[str] # 标签列表
# 数值字段(用于范围过滤)
importance: int # ImportanceLevel.value (1-4)
confidence: int # ConfidenceLevel.value (1-4)
created_at: float # 创建时间戳
access_count: int # 访问次数
# 可选字段
chat_id: Optional[str] = None
content_preview: Optional[str] = None # 内容预览前100字符
@@ -43,152 +43,152 @@ class MemoryMetadataIndexEntry:
class MemoryMetadataIndex:
"""记忆元数据索引管理器"""
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
self.index_file = Path(index_file)
self.index: Dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
# 倒排索引(用于快速查找)
self.type_index: Dict[str, Set[str]] = {} # type -> {memory_ids}
self.subject_index: Dict[str, Set[str]] = {} # subject -> {memory_ids}
self.keyword_index: Dict[str, Set[str]] = {} # keyword -> {memory_ids}
self.tag_index: Dict[str, Set[str]] = {} # tag -> {memory_ids}
self.lock = threading.RLock()
# 加载已有索引
self._load_index()
def _load_index(self):
"""从文件加载索引"""
if not self.index_file.exists():
logger.info(f"元数据索引文件不存在,将创建新索引: {self.index_file}")
return
try:
with open(self.index_file, 'rb') as f:
with open(self.index_file, "rb") as f:
data = orjson.loads(f.read())
# 重建内存索引
for entry_data in data.get('entries', []):
for entry_data in data.get("entries", []):
entry = MemoryMetadataIndexEntry(**entry_data)
self.index[entry.memory_id] = entry
self._update_inverted_indices(entry)
logger.info(f"✅ 加载元数据索引: {len(self.index)} 条记录")
except Exception as e:
logger.error(f"加载元数据索引失败: {e}", exc_info=True)
def _save_index(self):
"""保存索引到文件"""
try:
# 确保目录存在
self.index_file.parent.mkdir(parents=True, exist_ok=True)
# 序列化所有条目
entries = [asdict(entry) for entry in self.index.values()]
data = {
'version': '1.0',
'count': len(entries),
'last_updated': datetime.now().isoformat(),
'entries': entries
"version": "1.0",
"count": len(entries),
"last_updated": datetime.now().isoformat(),
"entries": entries,
}
# 写入文件(使用临时文件 + 原子重命名)
temp_file = self.index_file.with_suffix('.tmp')
with open(temp_file, 'wb') as f:
temp_file = self.index_file.with_suffix(".tmp")
with open(temp_file, "wb") as f:
f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
temp_file.replace(self.index_file)
logger.debug(f"元数据索引已保存: {len(entries)} 条记录")
except Exception as e:
logger.error(f"保存元数据索引失败: {e}", exc_info=True)
def _update_inverted_indices(self, entry: MemoryMetadataIndexEntry):
"""更新倒排索引"""
memory_id = entry.memory_id
# 类型索引
self.type_index.setdefault(entry.memory_type, set()).add(memory_id)
# 主语索引
for subject in entry.subjects:
subject_norm = subject.strip().lower()
if subject_norm:
self.subject_index.setdefault(subject_norm, set()).add(memory_id)
# 关键词索引
for keyword in entry.keywords:
keyword_norm = keyword.strip().lower()
if keyword_norm:
self.keyword_index.setdefault(keyword_norm, set()).add(memory_id)
# 标签索引
for tag in entry.tags:
tag_norm = tag.strip().lower()
if tag_norm:
self.tag_index.setdefault(tag_norm, set()).add(memory_id)
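
Each `setdefault(..., set()).add(...)` line above maintains a posting set per normalized term; answering a conjunctive query is then just set intersection, as in this sketch:

    keyword_index: dict[str, set[str]] = {}
    for mem_id, kws in [("m1", ["coffee"]), ("m2", ["coffee", "morning"]), ("m3", ["morning"])]:
        for kw in kws:
            keyword_index.setdefault(kw.strip().lower(), set()).add(mem_id)

    def match_all(keywords: list[str]) -> set[str]:
        # AND semantics: intersect the posting sets of every queried keyword
        postings = [keyword_index.get(k.strip().lower(), set()) for k in keywords]
        return set.intersection(*postings) if postings else set()

    print(match_all(["coffee", "morning"]))  # {'m2'}
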
def add_or_update(self, entry: MemoryMetadataIndexEntry):
"""添加或更新索引条目"""
with self.lock:
# 如果已存在,先从倒排索引中移除旧记录
if entry.memory_id in self.index:
self._remove_from_inverted_indices(entry.memory_id)
# 添加新记录
self.index[entry.memory_id] = entry
self._update_inverted_indices(entry)
def _remove_from_inverted_indices(self, memory_id: str):
"""从倒排索引中移除记录"""
if memory_id not in self.index:
return
entry = self.index[memory_id]
# 从类型索引移除
if entry.memory_type in self.type_index:
self.type_index[entry.memory_type].discard(memory_id)
# 从主语索引移除
for subject in entry.subjects:
subject_norm = subject.strip().lower()
if subject_norm in self.subject_index:
self.subject_index[subject_norm].discard(memory_id)
# 从关键词索引移除
for keyword in entry.keywords:
keyword_norm = keyword.strip().lower()
if keyword_norm in self.keyword_index:
self.keyword_index[keyword_norm].discard(memory_id)
# 从标签索引移除
for tag in entry.tags:
tag_norm = tag.strip().lower()
if tag_norm in self.tag_index:
self.tag_index[tag_norm].discard(memory_id)
def remove(self, memory_id: str):
"""移除索引条目"""
with self.lock:
if memory_id in self.index:
self._remove_from_inverted_indices(memory_id)
del self.index[memory_id]
def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]):
"""批量添加或更新"""
with self.lock:
for entry in entries:
self.add_or_update(entry)
def save(self):
"""保存索引到磁盘"""
with self.lock:
self._save_index()
def search(
self,
memory_types: Optional[List[str]] = None,
@@ -201,11 +201,11 @@ class MemoryMetadataIndex:
created_before: Optional[float] = None,
user_id: Optional[str] = None,
limit: Optional[int] = None,
flexible_mode: bool = True # 新增:灵活匹配模式
flexible_mode: bool = True, # 新增:灵活匹配模式
) -> List[str]:
"""
搜索符合条件的记忆ID列表(支持模糊匹配)
Returns:
List[str]: 符合条件的 memory_id 列表
"""
@@ -219,7 +219,7 @@ class MemoryMetadataIndex:
created_after=created_after,
created_before=created_before,
user_id=user_id,
limit=limit
limit=limit,
)
else:
return self._search_strict(
@@ -232,7 +232,7 @@ class MemoryMetadataIndex:
created_after=created_after,
created_before=created_before,
user_id=user_id,
limit=limit
limit=limit,
)
def _search_flexible(
@@ -243,7 +243,7 @@ class MemoryMetadataIndex:
created_before: Optional[float] = None,
user_id: Optional[str] = None,
limit: Optional[int] = None,
**kwargs # 接受但不使用的参数
**kwargs, # 接受但不使用的参数
) -> List[str]:
"""
灵活搜索模式(2/4项匹配即可,支持部分匹配)
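
Under the 2-of-4 reading of that docstring, a candidate survives if at least two of the four filter families (type, subject, keyword, time) match. A hypothetical sketch of such a gate — illustrative only, since the method's actual scoring body is mostly elided by this diff:

    def flexible_match(entry: dict, filters: dict, min_hits: int = 2) -> bool:
        hits = 0
        if filters.get("memory_types") and entry["memory_type"] in filters["memory_types"]:
            hits += 1
        if filters.get("subjects") and set(entry["subjects"]) & set(filters["subjects"]):
            hits += 1
        if filters.get("keywords") and set(entry["keywords"]) & set(filters["keywords"]):
            hits += 1
        after = filters.get("created_after")
        if after is not None and entry["created_at"] >= after:
            hits += 1
        return hits >= min_hits

    entry = {"memory_type": "preference", "subjects": ["user"], "keywords": ["coffee"], "created_at": 1.7e9}
    print(flexible_match(entry, {"memory_types": ["preference"], "keywords": ["coffee"]}))  # True
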
@@ -258,10 +258,7 @@ class MemoryMetadataIndex:
"""
# 用户过滤(必选)
if user_id:
base_candidates = {
mid for mid, entry in self.index.items()
if entry.user_id == user_id
}
base_candidates = {mid for mid, entry in self.index.items() if entry.user_id == user_id}
else:
base_candidates = set(self.index.keys())
@@ -386,7 +383,7 @@ class MemoryMetadataIndex:
created_after: Optional[float] = None,
created_before: Optional[float] = None,
user_id: Optional[str] = None,
limit: Optional[int] = None
limit: Optional[int] = None,
) -> List[str]:
"""严格搜索模式(原有逻辑)"""
# 初始候选集(所有记忆)
@@ -394,10 +391,7 @@ class MemoryMetadataIndex:
# 用户过滤(必选)
if user_id:
candidate_ids = {
mid for mid, entry in self.index.items()
if entry.user_id == user_id
}
candidate_ids = {mid for mid, entry in self.index.items() if entry.user_id == user_id}
else:
candidate_ids = set(self.index.keys())
@@ -447,7 +441,8 @@ class MemoryMetadataIndex:
# 重要性过滤
if importance_min is not None or importance_max is not None:
importance_ids = {
mid for mid in candidate_ids
mid
for mid in candidate_ids
if (importance_min is None or self.index[mid].importance >= importance_min)
and (importance_max is None or self.index[mid].importance <= importance_max)
}
@@ -456,41 +451,37 @@ class MemoryMetadataIndex:
# 时间范围过滤
if created_after is not None or created_before is not None:
time_ids = {
mid for mid in candidate_ids
mid
for mid in candidate_ids
if (created_after is None or self.index[mid].created_at >= created_after)
and (created_before is None or self.index[mid].created_at <= created_before)
}
candidate_ids &= time_ids
# 转换为列表并排序(按创建时间倒序)
result_ids = sorted(
candidate_ids,
key=lambda mid: self.index[mid].created_at,
reverse=True
)
result_ids = sorted(candidate_ids, key=lambda mid: self.index[mid].created_at, reverse=True)
# 限制数量
if limit:
result_ids = result_ids[:limit]
logger.debug(
f"[严格搜索] types={memory_types}, subjects={subjects}, "
f"keywords={keywords}, 返回={len(result_ids)}"
f"[严格搜索] types={memory_types}, subjects={subjects}, keywords={keywords}, 返回={len(result_ids)}"
)
return result_ids
def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]:
"""获取单个索引条目"""
return self.index.get(memory_id)
def get_stats(self) -> Dict[str, Any]:
"""获取索引统计信息"""
with self.lock:
return {
'total_memories': len(self.index),
'types': {mtype: len(ids) for mtype, ids in self.type_index.items()},
'subjects_count': len(self.subject_index),
'keywords_count': len(self.keyword_index),
'tags_count': len(self.tag_index),
"total_memories": len(self.index),
"types": {mtype: len(ids) for mtype, ids in self.type_index.items()},
"subjects_count": len(self.subject_index),
"keywords_count": len(self.keyword_index),
"tags_count": len(self.tag_index),
}

View File

@@ -80,10 +80,7 @@ class MemoryQueryPlanner:
return self._default_plan(query_text)
def _default_plan(self, query_text: str) -> MemoryQueryPlan:
return MemoryQueryPlan(
semantic_query=query_text,
limit=self.default_limit
)
return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit)
def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query
@@ -122,7 +119,7 @@ class MemoryQueryPlanner:
recency_preference=self._safe_str(data.get("recency")) or "any",
limit=self._safe_int(data.get("limit"), self.default_limit),
emphasis=self._safe_str(data.get("emphasis")) or "balanced",
raw_plan=data
raw_plan=data,
)
return plan
@@ -154,18 +151,18 @@ class MemoryQueryPlanner:
context_section = f"""
## 📋 未读消息上下文 (共{unread_context.get('total_count', 0)}条未读消息)
## 📋 未读消息上下文 (共{unread_context.get("total_count", 0)}条未读消息)
### 最近消息预览:
{chr(10).join(message_previews)}
### 上下文关键词:
{', '.join(unread_keywords[:15]) if unread_keywords else ''}
{", ".join(unread_keywords[:15]) if unread_keywords else ""}
### 对话参与者:
{', '.join(unread_participants) if unread_participants else ''}
{", ".join(unread_participants) if unread_participants else ""}
### 上下文摘要:
{context_summary[:300] if context_summary else ''}
{context_summary[:300] if context_summary else ""}
"""
else:
context_section = """
@@ -223,7 +220,7 @@ class MemoryQueryPlanner:
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start:end + 1]
return stripped[start : end + 1]
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
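
Both this planner and the memory system later in the diff carve a JSON object out of a chatty LLM reply by slicing from the first "{" to the last "}"; the same helper standalone, with a usage example:

    import orjson

    def extract_json_block(text: str) -> str | None:
        stripped = text.strip()
        start, end = stripped.find("{"), stripped.rfind("}")
        if start != -1 and end != -1 and end > start:
            return stripped[start : end + 1]
        return stripped if stripped.startswith("{") and stripped.endswith("}") else None

    raw = 'Here is the plan:\n{"semantic_query": "user preferences", "limit": 5}\nHope it helps.'
    block = extract_json_block(raw)
    print(orjson.loads(block) if block else None)
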
@@ -243,4 +240,4 @@ class MemoryQueryPlanner:
return default
return number
except (TypeError, ValueError):
return default
return default

View File

@@ -35,6 +35,7 @@ GLOBAL_MEMORY_SCOPE = "global"
class MemorySystemStatus(Enum):
"""记忆系统状态"""
INITIALIZING = "initializing"
READY = "ready"
BUILDING = "building"
@@ -45,6 +46,7 @@ class MemorySystemStatus(Enum):
@dataclass
class MemorySystemConfig:
"""记忆系统配置"""
# 记忆构建配置
min_memory_length: int = 10
max_memory_length: int = 500
@@ -97,11 +99,9 @@ class MemorySystemConfig:
max_memory_length=global_config.memory.max_memory_length,
memory_value_threshold=global_config.memory.memory_value_threshold,
min_build_interval_seconds=getattr(global_config.memory, "memory_build_interval", 300.0),
# 向量存储配置
vector_dimension=int(embedding_dimension),
similarity_threshold=global_config.memory.vector_similarity_threshold,
# 召回配置
coarse_recall_limit=global_config.memory.metadata_filter_limit,
fine_recall_limit=global_config.memory.vector_search_limit,
@@ -112,21 +112,16 @@ class MemorySystemConfig:
semantic_weight=global_config.memory.semantic_weight,
context_weight=global_config.memory.context_weight,
recency_weight=global_config.memory.recency_weight,
# 融合配置
fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold,
deduplication_window=timedelta(hours=global_config.memory.deduplication_window_hours)
deduplication_window=timedelta(hours=global_config.memory.deduplication_window_hours),
)
class MemorySystem:
"""精准记忆系统核心类"""
def __init__(
self,
llm_model: Optional[LLMRequest] = None,
config: Optional[MemorySystemConfig] = None
):
def __init__(self, llm_model: Optional[LLMRequest] = None, config: Optional[MemorySystemConfig] = None):
self.config = config or MemorySystemConfig.from_global_config()
self.llm_model = llm_model
self.status = MemorySystemStatus.INITIALIZING
@@ -175,16 +170,16 @@ class MemorySystem:
extraction_task_config = value_task_config or fallback_task
if value_task_config is None or extraction_task_config is None:
raise RuntimeError("无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。")
raise RuntimeError(
"无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。"
)
self.value_assessment_model = LLMRequest(
model_set=value_task_config,
request_type="memory.value_assessment"
model_set=value_task_config, request_type="memory.value_assessment"
)
self.memory_extraction_model = LLMRequest(
model_set=extraction_task_config,
request_type="memory.extraction"
model_set=extraction_task_config, request_type="memory.extraction"
)
# 初始化核心组件(简化版)
@@ -198,13 +193,13 @@ class MemorySystem:
memory_collection="unified_memory_v2",
metadata_collection="memory_metadata_v2",
similarity_threshold=self.config.similarity_threshold,
search_limit=getattr(global_config.memory, 'unified_storage_search_limit', 20),
batch_size=getattr(global_config.memory, 'unified_storage_batch_size', 100),
enable_caching=getattr(global_config.memory, 'unified_storage_enable_caching', True),
cache_size_limit=getattr(global_config.memory, 'unified_storage_cache_limit', 1000),
auto_cleanup_interval=getattr(global_config.memory, 'unified_storage_auto_cleanup_interval', 3600),
enable_forgetting=getattr(global_config.memory, 'enable_memory_forgetting', True),
retention_hours=getattr(global_config.memory, 'memory_retention_hours', 720) # 30天
search_limit=getattr(global_config.memory, "unified_storage_search_limit", 20),
batch_size=getattr(global_config.memory, "unified_storage_batch_size", 100),
enable_caching=getattr(global_config.memory, "unified_storage_enable_caching", True),
cache_size_limit=getattr(global_config.memory, "unified_storage_cache_limit", 1000),
auto_cleanup_interval=getattr(global_config.memory, "unified_storage_auto_cleanup_interval", 3600),
enable_forgetting=getattr(global_config.memory, "enable_memory_forgetting", True),
retention_hours=getattr(global_config.memory, "memory_retention_hours", 720), # 30天
)
try:
@@ -220,32 +215,27 @@ class MemorySystem:
# 从全局配置创建遗忘引擎配置
forgetting_config = ForgettingConfig(
# 检查频率配置
check_interval_hours=getattr(global_config.memory, 'forgetting_check_interval_hours', 24),
check_interval_hours=getattr(global_config.memory, "forgetting_check_interval_hours", 24),
batch_size=100, # 固定值,暂不配置
# 遗忘阈值配置
base_forgetting_days=getattr(global_config.memory, 'base_forgetting_days', 30.0),
min_forgetting_days=getattr(global_config.memory, 'min_forgetting_days', 7.0),
max_forgetting_days=getattr(global_config.memory, 'max_forgetting_days', 365.0),
base_forgetting_days=getattr(global_config.memory, "base_forgetting_days", 30.0),
min_forgetting_days=getattr(global_config.memory, "min_forgetting_days", 7.0),
max_forgetting_days=getattr(global_config.memory, "max_forgetting_days", 365.0),
# 重要程度权重
critical_importance_bonus=getattr(global_config.memory, 'critical_importance_bonus', 45.0),
high_importance_bonus=getattr(global_config.memory, 'high_importance_bonus', 30.0),
normal_importance_bonus=getattr(global_config.memory, 'normal_importance_bonus', 15.0),
low_importance_bonus=getattr(global_config.memory, 'low_importance_bonus', 0.0),
critical_importance_bonus=getattr(global_config.memory, "critical_importance_bonus", 45.0),
high_importance_bonus=getattr(global_config.memory, "high_importance_bonus", 30.0),
normal_importance_bonus=getattr(global_config.memory, "normal_importance_bonus", 15.0),
low_importance_bonus=getattr(global_config.memory, "low_importance_bonus", 0.0),
# 置信度权重
verified_confidence_bonus=getattr(global_config.memory, 'verified_confidence_bonus', 30.0),
high_confidence_bonus=getattr(global_config.memory, 'high_confidence_bonus', 20.0),
medium_confidence_bonus=getattr(global_config.memory, 'medium_confidence_bonus', 10.0),
low_confidence_bonus=getattr(global_config.memory, 'low_confidence_bonus', 0.0),
verified_confidence_bonus=getattr(global_config.memory, "verified_confidence_bonus", 30.0),
high_confidence_bonus=getattr(global_config.memory, "high_confidence_bonus", 20.0),
medium_confidence_bonus=getattr(global_config.memory, "medium_confidence_bonus", 10.0),
low_confidence_bonus=getattr(global_config.memory, "low_confidence_bonus", 0.0),
# 激活频率权重
activation_frequency_weight=getattr(global_config.memory, 'activation_frequency_weight', 0.5),
max_frequency_bonus=getattr(global_config.memory, 'max_frequency_bonus', 10.0),
activation_frequency_weight=getattr(global_config.memory, "activation_frequency_weight", 0.5),
max_frequency_bonus=getattr(global_config.memory, "max_frequency_bonus", 10.0),
# 休眠配置
dormant_threshold_days=getattr(global_config.memory, 'dormant_threshold_days', 90)
dormant_threshold_days=getattr(global_config.memory, "dormant_threshold_days", 90),
)
self.forgetting_engine = MemoryForgettingEngine(forgetting_config)
@@ -253,17 +243,11 @@ class MemorySystem:
planner_task_config = getattr(model_config.model_task_config, "utils_small", None)
planner_model: Optional[LLMRequest] = None
try:
planner_model = LLMRequest(
model_set=planner_task_config,
request_type="memory.query_planner"
)
planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner")
except Exception as planner_exc:
logger.warning("查询规划模型初始化失败,将使用默认规划策略: %s", planner_exc, exc_info=True)
self.query_planner = MemoryQueryPlanner(
planner_model,
default_limit=self.config.final_recall_limit
)
self.query_planner = MemoryQueryPlanner(planner_model, default_limit=self.config.final_recall_limit)
# 统一存储已经自动加载数据,无需额外加载
logger.info("✅ 简化版记忆系统初始化完成")
@@ -277,11 +261,7 @@ class MemorySystem:
raise
async def retrieve_memories_for_building(
self,
query_text: str,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: int = 5
self, query_text: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5
) -> List[MemoryChunk]:
"""在构建记忆时检索相关记忆,使用统一存储系统
@@ -304,9 +284,7 @@ class MemorySystem:
try:
# 使用统一存储检索相似记忆
search_results = await self.unified_storage.search_similar_memories(
query_text=query_text,
limit=limit,
scope_id=user_id
query_text=query_text, limit=limit, scope_id=user_id
)
# 转换为记忆对象
@@ -324,10 +302,7 @@ class MemorySystem:
return []
async def build_memory_from_conversation(
self,
conversation_text: str,
context: Dict[str, Any],
timestamp: Optional[float] = None
self, conversation_text: str, context: Dict[str, Any], timestamp: Optional[float] = None
) -> List[MemoryChunk]:
"""从对话中构建记忆
@@ -383,7 +358,7 @@ class MemorySystem:
conversation_text,
normalized_context,
GLOBAL_MEMORY_SCOPE, # 强制使用 global不区分用户
timestamp or time.time()
timestamp or time.time(),
)
if not memory_chunks:
@@ -393,10 +368,7 @@ class MemorySystem:
# 3. 记忆融合与去重(包含与历史记忆的融合)
existing_candidates = await self._collect_fusion_candidates(memory_chunks)
fused_chunks = await self.fusion_engine.fuse_memories(
memory_chunks,
existing_candidates
)
fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks, existing_candidates)
# 4. 存储记忆到统一存储
stored_count = await self._store_memories_unified(fused_chunks)
@@ -459,11 +431,7 @@ class MemorySystem:
return []
candidate_ids: Set[str] = set()
new_memory_ids = {
memory.memory_id
for memory in new_memories
if memory and getattr(memory, "memory_id", None)
}
new_memory_ids = {memory.memory_id for memory in new_memories if memory and getattr(memory, "memory_id", None)}
# 基于指纹的直接匹配
for memory in new_memories:
@@ -501,9 +469,7 @@ class MemorySystem:
continue
search_tasks.append(
self.unified_storage.search_similar_memories(
query_text=display_text,
limit=8,
scope_id=GLOBAL_MEMORY_SCOPE
query_text=display_text, limit=8, scope_id=GLOBAL_MEMORY_SCOPE
)
)
@@ -545,10 +511,7 @@ class MemorySystem:
return existing_candidates
async def process_conversation_memory(
self,
context: Dict[str, Any]
) -> Dict[str, Any]:
async def process_conversation_memory(self, context: Dict[str, Any]) -> Dict[str, Any]:
"""对外暴露的对话记忆处理接口,仅依赖上下文信息"""
start_time = time.time()
@@ -563,7 +526,9 @@ class MemorySystem:
or ""
)
conversation_text = conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate)
conversation_text = (
conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate)
)
timestamp = context.get("timestamp")
if timestamp is None:
@@ -573,9 +538,7 @@ class MemorySystem:
normalized_context.setdefault("conversation_text", conversation_text)
memories = await self.build_memory_from_conversation(
conversation_text=conversation_text,
context=normalized_context,
timestamp=timestamp
conversation_text=conversation_text, context=normalized_context, timestamp=timestamp
)
processing_time = time.time() - start_time
@@ -586,18 +549,13 @@ class MemorySystem:
"created_memories": memories,
"memory_count": memory_count,
"processing_time": processing_time,
"status": self.status.value
"status": self.status.value,
}
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"对话记忆处理失败: {e}", exc_info=True)
return {
"success": False,
"error": str(e),
"processing_time": processing_time,
"status": self.status.value
}
return {"success": False, "error": str(e), "processing_time": processing_time, "status": self.status.value}
async def retrieve_relevant_memories(
self,
@@ -605,7 +563,7 @@ class MemorySystem:
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: int = 5,
**kwargs
**kwargs,
) -> List[MemoryChunk]:
"""检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)"""
raw_query = query_text or kwargs.get("query")
@@ -617,7 +575,7 @@ class MemorySystem:
return []
context = context or {}
# 所有记忆完全共享,统一使用 global 作用域,不区分用户
resolved_user_id = GLOBAL_MEMORY_SCOPE
@@ -627,7 +585,7 @@ class MemorySystem:
try:
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None)
effective_limit = self.config.final_recall_limit
# === 阶段一:元数据粗筛(软性过滤) ===
coarse_filters = {
"user_id": GLOBAL_MEMORY_SCOPE, # 必选:确保作用域正确
@@ -642,119 +600,126 @@ class MemorySystem:
# 构建包含未读消息的增强上下文
enhanced_context = await self._build_enhanced_query_context(raw_query, normalized_context)
query_plan = await self.query_planner.plan_query(raw_query, enhanced_context)
# 使用LLM优化后的查询语句更精确的语义表达
if getattr(query_plan, "semantic_query", None):
optimized_query = query_plan.semantic_query
# 构建JSON元数据过滤条件用于阶段一粗筛
# 将查询规划的结果转换为元数据过滤条件
if getattr(query_plan, "memory_types", None):
metadata_filters['memory_types'] = [mt.value for mt in query_plan.memory_types]
metadata_filters["memory_types"] = [mt.value for mt in query_plan.memory_types]
if getattr(query_plan, "subject_includes", None):
metadata_filters['subjects'] = query_plan.subject_includes
metadata_filters["subjects"] = query_plan.subject_includes
if getattr(query_plan, "required_keywords", None):
metadata_filters['keywords'] = query_plan.required_keywords
metadata_filters["keywords"] = query_plan.required_keywords
# 时间范围过滤
recency = getattr(query_plan, "recency_preference", "any")
current_time = time.time()
if recency == "recent":
# 最近7天
metadata_filters['created_after'] = current_time - (7 * 24 * 3600)
metadata_filters["created_after"] = current_time - (7 * 24 * 3600)
elif recency == "historical":
# 30天以前
metadata_filters['created_before'] = current_time - (30 * 24 * 3600)
metadata_filters["created_before"] = current_time - (30 * 24 * 3600)
# 添加用户ID到元数据过滤
metadata_filters['user_id'] = GLOBAL_MEMORY_SCOPE
metadata_filters["user_id"] = GLOBAL_MEMORY_SCOPE
logger.debug(f"[阶段一] 查询优化: '{raw_query}''{optimized_query}'")
logger.debug(f"[阶段一] 元数据过滤条件: {metadata_filters}")
except Exception as plan_exc:
logger.warning("查询规划失败,使用原始查询: %s", plan_exc, exc_info=True)
# 即使查询规划失败也保留基本的user_id过滤
metadata_filters = {'user_id': GLOBAL_MEMORY_SCOPE}
metadata_filters = {"user_id": GLOBAL_MEMORY_SCOPE}
# === 阶段二:向量精筛 ===
coarse_limit = self.config.coarse_recall_limit # 粗筛阶段返回更多候选
logger.debug(f"[阶段二] 开始向量搜索: query='{optimized_query[:60]}...', limit={coarse_limit}")
search_results = await self.unified_storage.search_similar_memories(
query_text=optimized_query,
limit=coarse_limit,
filters=coarse_filters, # ChromaDB where条件保留兼容
metadata_filters=metadata_filters # JSON元数据索引过滤
metadata_filters=metadata_filters, # JSON元数据索引过滤
)
logger.info(f"[阶段二] 向量搜索完成: 返回 {len(search_results)} 条候选")
# === 阶段三:综合重排 ===
scored_memories = []
current_time = time.time()
for memory, vector_similarity in search_results:
# 1. 向量相似度得分(已归一化到 0-1
vector_score = vector_similarity
# 2. 时效性得分指数衰减30天半衰期
age_seconds = current_time - memory.metadata.created_at
age_days = age_seconds / (24 * 3600)
# 使用 math.exp 而非 np.exp避免依赖numpy
import math
recency_score = math.exp(-age_days / 30)
# 3. 重要性得分(枚举值转换为归一化得分 0-1
# ImportanceLevel: LOW=1, NORMAL=2, HIGH=3, CRITICAL=4
importance_enum = memory.metadata.importance
if hasattr(importance_enum, 'value'):
if hasattr(importance_enum, "value"):
# 枚举类型转换为0-1范围(value - 1) / 3
importance_score = (importance_enum.value - 1) / 3.0
else:
# 如果已经是数值,直接使用
importance_score = float(importance_enum) if importance_enum else 0.5
# 4. 访问频率得分归一化访问10次以上得满分
access_count = memory.metadata.access_count
frequency_score = min(access_count / 10.0, 1.0)
# 综合得分(加权平均)
final_score = (
self.config.vector_weight * vector_score +
self.config.recency_weight * recency_score +
self.config.context_weight * importance_score +
0.1 * frequency_score # 访问频率权重固定10%
self.config.vector_weight * vector_score
+ self.config.recency_weight * recency_score
+ self.config.context_weight * importance_score
+ 0.1 * frequency_score # 访问频率权重固定10%
)
scored_memories.append((memory, final_score, {
"vector": vector_score,
"recency": recency_score,
"importance": importance_score,
"frequency": frequency_score,
"final": final_score
}))
scored_memories.append(
(
memory,
final_score,
{
"vector": vector_score,
"recency": recency_score,
"importance": importance_score,
"frequency": frequency_score,
"final": final_score,
},
)
)
# 更新访问记录
memory.update_access()
# 按综合得分排序
scored_memories.sort(key=lambda x: x[1], reverse=True)
# 返回 Top-K
final_memories = [mem for mem, score, details in scored_memories[:effective_limit]]
retrieval_time = time.time() - start_time
# 详细日志
if scored_memories:
logger.info(f"[阶段三] 综合重排完成: Top 3 得分详情")
logger.info("[阶段三] 综合重排完成: Top 3 得分详情")
for i, (mem, score, details) in enumerate(scored_memories[:3], 1):
try:
summary = mem.content[:60] if hasattr(mem, 'content') and mem.content else ""
summary = mem.content[:60] if hasattr(mem, "content") and mem.content else ""
except:
summary = ""
logger.info(
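
Collecting the pieces of this hunk: the stage-three score is a weighted sum of vector similarity, exponential recency decay, normalized importance, and saturating access frequency. A sketch under stated assumptions — the three weight values below are placeholders (the real ones come from MemorySystemConfig), and note that exp(-days/30) actually halves at 30·ln 2 ≈ 20.8 days rather than at the 30-day mark the comment suggests:

    import math
    import time

    def rerank_score(vector_sim: float, created_at: float, importance_value: int, access_count: int,
                     vector_weight: float = 0.6, recency_weight: float = 0.2, context_weight: float = 0.1) -> float:
        age_days = (time.time() - created_at) / 86400
        recency = math.exp(-age_days / 30)            # exponential decay in age
        importance = (importance_value - 1) / 3.0     # map enum values 1-4 onto 0-1
        frequency = min(access_count / 10.0, 1.0)     # saturates at 10 accesses
        return (vector_weight * vector_sim + recency_weight * recency
                + context_weight * importance + 0.1 * frequency)  # fixed 10% frequency weight

    print(round(rerank_score(0.82, time.time() - 5 * 86400, importance_value=3, access_count=4), 3))
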
@@ -803,15 +768,12 @@ class MemorySystem:
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start:end + 1].strip()
return stripped[start : end + 1].strip()
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
def _normalize_context(
self,
raw_context: Optional[Dict[str, Any]],
user_id: Optional[str],
timestamp: Optional[float]
self, raw_context: Optional[Dict[str, Any]], user_id: Optional[str], timestamp: Optional[float]
) -> Dict[str, Any]:
"""标准化上下文,确保必备字段存在且格式正确"""
context: Dict[str, Any] = {}
@@ -850,9 +812,7 @@ class MemorySystem:
# 历史窗口配置
window_candidate = (
context.get("history_limit")
or context.get("history_window")
or context.get("memory_history_limit")
context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit")
)
if window_candidate is not None:
try:
@@ -888,7 +848,9 @@ class MemorySystem:
enhanced_context["unread_messages_context"] = unread_messages_summary
enhanced_context["has_unread_context"] = True
logger.debug(f"为查询规划构建了增强上下文,包含 {len(unread_messages_summary.get('messages', []))} 条未读消息")
logger.debug(
f"为查询规划构建了增强上下文,包含 {len(unread_messages_summary.get('messages', []))} 条未读消息"
)
else:
enhanced_context["has_unread_context"] = False
logger.debug("未找到未读消息,使用基础上下文进行查询规划")
@@ -934,26 +896,30 @@ class MemorySystem:
for msg in unread_messages[:10]: # 限制处理最近10条未读消息
try:
# 提取消息内容
content = (getattr(msg, "processed_plain_text", None) or
getattr(msg, "display_message", None) or "")
content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None) or ""
if not content:
continue
# 提取发送者信息
sender_name = "未知用户"
if hasattr(msg, "user_info") and msg.user_info:
sender_name = (getattr(msg.user_info, "user_nickname", None) or
getattr(msg.user_info, "user_cardname", None) or
getattr(msg.user_info, "user_id", None) or "未知用户")
sender_name = (
getattr(msg.user_info, "user_nickname", None)
or getattr(msg.user_info, "user_cardname", None)
or getattr(msg.user_info, "user_id", None)
or "未知用户"
)
participant_names.add(sender_name)
# 添加到消息摘要
messages_summary.append({
"sender": sender_name,
"content": content[:200], # 限制长度避免过长
"timestamp": getattr(msg, "time", None)
})
messages_summary.append(
{
"sender": sender_name,
"content": content[:200], # 限制长度避免过长
"timestamp": getattr(msg, "time", None),
}
)
# 提取关键词(简单实现)
content_lower = content.lower()
@@ -975,10 +941,12 @@ class MemorySystem:
"processed_count": len(messages_summary),
"keywords": list(all_keywords)[:20], # 最多20个关键词
"participants": list(participant_names),
"context_summary": self._build_unread_context_summary(messages_summary)
"context_summary": self._build_unread_context_summary(messages_summary),
}
logger.debug(f"收集到未读消息上下文: {len(messages_summary)}条消息,{len(all_keywords)}个关键词,{len(participant_names)}个参与者")
logger.debug(
f"收集到未读消息上下文: {len(messages_summary)}条消息,{len(all_keywords)}个关键词,{len(participant_names)}个参与者"
)
return unread_context
except Exception as e:
@@ -1051,10 +1019,7 @@ class MemorySystem:
if user_id and fallback_text:
try:
relevant_memories = await self.retrieve_memories_for_building(
query_text=fallback_text,
user_id=user_id,
context=context,
limit=3
query_text=fallback_text, user_id=user_id, context=context, limit=3
)
if relevant_memories:
@@ -1068,9 +1033,7 @@ class MemorySystem:
memory_transcript = f"{memory_transcript}\n[当前消息] {cleaned_fallback}"
logger.debug(
"使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s",
len(relevant_memories),
user_id
"使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s", len(relevant_memories), user_id
)
return memory_transcript
@@ -1087,11 +1050,7 @@ class MemorySystem:
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
"""确定历史消息获取数量限制在30-50之间"""
default_limit = 40
candidate = (
context.get("history_limit")
or context.get("history_window")
or context.get("memory_history_limit")
)
candidate = context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit")
if isinstance(candidate, str):
try:
@@ -1186,9 +1145,9 @@ class MemorySystem:
{text}
上下文信息:
- 用户ID: {context.get('user_id', 'unknown')}
- 消息类型: {context.get('message_type', 'unknown')}
- 时间: {datetime.fromtimestamp(context.get('timestamp', time.time()))}
- 用户ID: {context.get("user_id", "unknown")}
- 消息类型: {context.get("message_type", "unknown")}
- 时间: {datetime.fromtimestamp(context.get("timestamp", time.time()))}
## 📋 评估要求:
@@ -1214,9 +1173,7 @@ class MemorySystem:
}}
"""
response, _ = await self.value_assessment_model.generate_response_async(
prompt, temperature=0.3
)
response, _ = await self.value_assessment_model.generate_response_async(prompt, temperature=0.3)
# 解析响应
try:
@@ -1236,7 +1193,7 @@ class MemorySystem:
return max(0.0, min(1.0, value_score))
except (orjson.JSONDecodeError, ValueError) as e:
preview = response[:200].replace('\n', ' ')
preview = response[:200].replace("\n", " ")
logger.warning(f"解析价值评估响应失败: {e}, 响应片段: {preview}")
return 0.5 # 默认中等价值
@@ -1331,13 +1288,15 @@ class MemorySystem:
else:
obj_part = str(obj).strip()
base = "|".join([
str(memory.user_id or "unknown"),
memory.memory_type.value,
subject_part,
predicate_part,
obj_part,
])
base = "|".join(
[
str(memory.user_id or "unknown"),
memory.memory_type.value,
subject_part,
predicate_part,
obj_part,
]
)
return hashlib.sha256(base.encode("utf-8")).hexdigest()
@@ -1352,7 +1311,7 @@ class MemorySystem:
"total_memories": self.total_memories,
"last_build_time": self.last_build_time,
"last_retrieval_time": self.last_retrieval_time,
"config": asdict(self.config)
"config": asdict(self.config),
}
def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
@@ -1369,7 +1328,9 @@ class MemorySystem:
keyword_overlap = 0.0
if context_keywords:
memory_keywords = set(k.lower() for k in memory.keywords)
keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max(len(context_keywords), 1)
keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max(
len(context_keywords), 1
)
importance_boost = (memory.metadata.importance.value - 1) / 3 * 0.1
confidence_boost = (memory.metadata.confidence.value - 1) / 3 * 0.05
@@ -1429,7 +1390,7 @@ class MemorySystem:
"""重建向量存储(如果需要)"""
try:
# 检查是否有记忆缓存数据
if not hasattr(self.unified_storage, 'memory_cache') or not self.unified_storage.memory_cache:
if not hasattr(self.unified_storage, "memory_cache") or not self.unified_storage.memory_cache:
logger.info("无记忆缓存数据,跳过向量存储重建")
return
@@ -1443,19 +1404,19 @@ class MemorySystem:
memories_to_rebuild.append(memory)
elif memory.text_content and memory.text_content.strip():
memories_to_rebuild.append(memory)
if not memories_to_rebuild:
logger.warning("没有找到可重建向量的记忆")
return
logger.info(f"准备为 {len(memories_to_rebuild)} 条记忆重建向量")
# 批量重建向量
batch_size = 10
rebuild_count = 0
for i in range(0, len(memories_to_rebuild), batch_size):
batch = memories_to_rebuild[i:i + batch_size]
batch = memories_to_rebuild[i : i + batch_size]
try:
await self.unified_storage.store_memories(batch)
rebuild_count += len(batch)
@@ -1472,7 +1433,7 @@ class MemorySystem:
final_count = self.unified_storage.storage_stats.get("total_vectors", 0)
logger.info(f"✅ 向量存储重建完成,最终向量数量: {final_count}")
except Exception as e:
logger.error(f"❌ 向量存储重建失败: {e}", exc_info=True)
@@ -1495,4 +1456,4 @@ async def initialize_memory_system(llm_model: Optional[LLMRequest] = None):
if memory_system is None:
memory_system = MemorySystem(llm_model=llm_model)
await memory_system.initialize()
return memory_system
return memory_system

File diff suppressed because it is too large

View File

@@ -5,15 +5,12 @@
from .message_manager import MessageManager, message_manager
from .context_manager import SingleStreamContextManager
from .distribution_manager import (
StreamLoopManager,
stream_loop_manager
)
from .distribution_manager import StreamLoopManager, stream_loop_manager
__all__ = [
"MessageManager",
"message_manager",
"SingleStreamContextManager",
"StreamLoopManager",
"stream_loop_manager"
]
"stream_loop_manager",
]

View File

@@ -230,12 +230,14 @@ class SingleStreamContextManager:
异步计算消息的兴趣度。
此方法通过检查当前是否存在正在运行的 asyncio 事件循环来兼容同步和异步调用。
"""
# 内部异步函数,封装实际的计算逻辑
async def _get_score():
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system,
)
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
message=message, bot_nickname=global_config.bot.nickname
)

View File

@@ -34,17 +34,13 @@ class StreamLoopManager:
}
# 配置参数
self.max_concurrent_streams = (
max_concurrent_streams or global_config.chat.max_concurrent_distributions
)
self.max_concurrent_streams = max_concurrent_streams or global_config.chat.max_concurrent_distributions
# 强制分发策略
self.force_dispatch_unread_threshold: Optional[int] = getattr(
global_config.chat, "force_dispatch_unread_threshold", 20
)
self.force_dispatch_min_interval: float = getattr(
global_config.chat, "force_dispatch_min_interval", 0.1
)
self.force_dispatch_min_interval: float = getattr(global_config.chat, "force_dispatch_min_interval", 0.1)
# Chatter管理器
self.chatter_manager: Optional[ChatterManager] = None
@@ -108,7 +104,9 @@ class StreamLoopManager:
if force and len(self.stream_loops) >= self.max_concurrent_streams:
logger.warning(
"%s 未读消息积压严重(>%s),突破并发限制强制启动分发", stream_id, self.force_dispatch_unread_threshold
"%s 未读消息积压严重(>%s),突破并发限制强制启动分发",
stream_id,
self.force_dispatch_unread_threshold,
)
# 创建流循环任务
@@ -168,9 +166,7 @@ class StreamLoopManager:
if has_messages:
if force_dispatch:
logger.info(
"%s 未读消息 %d 条,触发强制分发", stream_id, unread_count
)
logger.info("%s 未读消息 %d 条,触发强制分发", stream_id, unread_count)
# 3. 激活chatter处理
success = await self._process_stream_messages(stream_id, context)

View File

@@ -11,7 +11,7 @@ from typing import Dict, Optional, Any, TYPE_CHECKING, List
from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats
from src.chat.chatter_manager import ChatterManager
from src.chat.planner_actions.action_manager import ChatterActionManager
from .sleep_manager.sleep_manager import SleepManager
@@ -21,7 +21,7 @@ from src.plugin_system.apis.chat_api import get_chat_manager
from .distribution_manager import stream_loop_manager
if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext
pass
logger = get_logger("message_manager")
@@ -63,7 +63,7 @@ class MessageManager:
stream_loop_manager.set_chatter_manager(self.chatter_manager)
logger.info("🚀 消息管理器已启动 | 流循环管理器已启动")
async def stop(self):
"""停止消息管理器"""
if not self.is_running:
@@ -88,7 +88,9 @@ class MessageManager:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return
await self._check_and_handle_interruption(chat_stream)
chat_stream.context_manager.context.processing_task = asyncio.create_task(chat_stream.context_manager.add_message(message))
chat_stream.context_manager.context.processing_task = asyncio.create_task(
chat_stream.context_manager.add_message(message)
)
except Exception as e:
logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}")
@@ -141,11 +143,7 @@ class MessageManager:
if not message_id:
continue
payload = {
key: value
for key, value in item.items()
if key != "message_id" and value is not None
}
payload = {key: value for key, value in item.items() if key != "message_id" and value is not None}
if not payload:
continue
@@ -169,9 +167,7 @@ class MessageManager:
if not chat_stream:
logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在")
return
success = await chat_stream.context_manager.update_message(
message_id, {"actions": [action]}
)
success = await chat_stream.context_manager.update_message(message_id, {"actions": [action]})
if success:
logger.debug(f"为消息 {message_id} 添加动作 {action} 成功")
else:
@@ -193,7 +189,7 @@ class MessageManager:
context.is_active = False
# 取消处理任务
if hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done():
if hasattr(context, "processing_task") and context.processing_task and not context.processing_task.done():
context.processing_task.cancel()
logger.info(f"停用聊天流: {stream_id}")
@@ -236,7 +232,11 @@ class MessageManager:
unread_count=unread_count,
history_count=len(context.history_messages),
last_check_time=context.last_check_time,
has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()),
has_active_task=bool(
hasattr(context, "processing_task")
and context.processing_task
and not context.processing_task.done()
),
)
except Exception as e:
@@ -284,7 +284,10 @@ class MessageManager:
return
# 检查是否有正在进行的处理任务
if chat_stream.context_manager.context.processing_task and not chat_stream.context_manager.context.processing_task.done():
if (
chat_stream.context_manager.context.processing_task
and not chat_stream.context_manager.context.processing_task.done()
):
# 计算打断概率
interruption_probability = chat_stream.context_manager.context.calculate_interruption_probability(
global_config.chat.interruption_max_limit, global_config.chat.interruption_probability_factor
@@ -310,7 +313,9 @@ class MessageManager:
# 增加打断计数并应用afc阈值降低
chat_stream.context_manager.context.increment_interruption_count()
chat_stream.context_manager.context.apply_interruption_afc_reduction(global_config.chat.interruption_afc_reduction)
chat_stream.context_manager.context.apply_interruption_afc_reduction(
global_config.chat.interruption_afc_reduction
)
# 检查是否已达到最大次数
if chat_stream.context_manager.context.interruption_count >= global_config.chat.interruption_max_limit:
@@ -364,7 +369,7 @@ class MessageManager:
return
context = chat_stream.context_manager.context
if hasattr(context, 'unread_messages') and context.unread_messages:
if hasattr(context, "unread_messages") and context.unread_messages:
logger.debug(f"正在为流 {stream_id} 清除 {len(context.unread_messages)} 条未读消息")
context.unread_messages.clear()
else:
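Earlier hunks in this file call StreamContext.calculate_interruption_probability(interruption_max_limit, interruption_probability_factor) and apply an AFC-threshold reduction per interruption. The formula is defined in StreamContext, outside this diff; the sketch below is only a plausible shape to make the two parameters concrete:

def calculate_interruption_probability(interruption_count: int,
                                       max_limit: int,
                                       probability_factor: float) -> float:
    # Hypothetical: the chance to interrupt decays as the count nears the cap.
    if interruption_count >= max_limit:
        return 0.0
    remaining = 1.0 - interruption_count / max_limit
    return max(0.0, min(1.0, probability_factor * remaining))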


@@ -1,33 +1,33 @@
from src.common.logger import get_logger
#from ..hfc_context import HfcContext
# from ..hfc_context import HfcContext
logger = get_logger("notification_sender")
class NotificationSender:
@staticmethod
async def send_goodnight_notification(context): # type: ignore
async def send_goodnight_notification(context): # type: ignore
"""发送晚安通知"""
#try:
#from ..proactive.events import ProactiveTriggerEvent
#from ..proactive.proactive_thinker import ProactiveThinker
#event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight")
#proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
#await proactive_thinker.think(event)
#except Exception as e:
#logger.error(f"发送晚安通知失败: {e}")
# try:
# from ..proactive.events import ProactiveTriggerEvent
# from ..proactive.proactive_thinker import ProactiveThinker
# event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight")
# proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
# await proactive_thinker.think(event)
# except Exception as e:
# logger.error(f"发送晚安通知失败: {e}")
@staticmethod
async def send_insomnia_notification(context, reason: str): # type: ignore
async def send_insomnia_notification(context, reason: str): # type: ignore
"""发送失眠通知"""
#try:
#from ..proactive.events import ProactiveTriggerEvent
#from ..proactive.proactive_thinker import ProactiveThinker
# try:
# from ..proactive.events import ProactiveTriggerEvent
# from ..proactive.proactive_thinker import ProactiveThinker
#event = ProactiveTriggerEvent(source="sleep_manager", reason=reason)
#proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
#await proactive_thinker.think(event)
#except Exception as e:
#logger.error(f"发送失眠通知失败: {e}")
# event = ProactiveTriggerEvent(source="sleep_manager", reason=reason)
# proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
# await proactive_thinker.think(event)
# except Exception as e:
# logger.error(f"发送失眠通知失败: {e}")


@@ -1,6 +1,6 @@
import asyncio
import random
from datetime import datetime, timedelta, date
from datetime import datetime, timedelta
from typing import Optional, TYPE_CHECKING
from src.common.logger import get_logger
@@ -21,6 +21,7 @@ class SleepManager:
它实现了一个状态机,根据预设的时间表、睡眠压力和随机因素,
在不同的睡眠状态(如清醒、准备入睡、睡眠、失眠)之间进行切换。
"""
def __init__(self):
"""
初始化睡眠管理器。
@@ -97,7 +98,7 @@ class SleepManager:
logger.info(f"进入理论休眠时间 '{activity}',开始进行睡眠决策...")
else:
logger.info("进入理论休眠时间,开始进行睡眠决策...")
if global_config.sleep_system.enable_flexible_sleep:
# --- 新的弹性睡眠逻辑 ---
if wakeup_manager:
@@ -112,7 +113,7 @@ class SleepManager:
pressure_diff = (pressure_threshold - sleep_pressure) / pressure_threshold
# 延迟分钟数,压力越低,延迟越长
delay_minutes = int(pressure_diff * max_delay_minutes)
# 确保总延迟不超过当日最大值
remaining_delay = max_delay_minutes - self.context.total_delayed_minutes_today
delay_minutes = min(delay_minutes, remaining_delay)
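A worked example of the delay computation above, with assumed values (threshold 100, current pressure 40, 60-minute daily cap, 30 minutes already delayed today):

pressure_threshold, sleep_pressure = 100.0, 40.0
max_delay_minutes, total_delayed_today = 60, 30

pressure_diff = (pressure_threshold - sleep_pressure) / pressure_threshold  # 0.6
delay_minutes = int(pressure_diff * max_delay_minutes)                      # 36
remaining_delay = max_delay_minutes - total_delayed_today                   # 30
delay_minutes = min(delay_minutes, remaining_delay)                         # capped at 30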
@@ -151,9 +152,10 @@ class SleepManager:
if wakeup_manager and global_config.sleep_system.enable_pre_sleep_notification:
asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context))
self.context.current_state = SleepState.SLEEPING
def _handle_preparing_sleep(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]):
def _handle_preparing_sleep(
self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]
):
"""处理“准备入睡”状态下的逻辑。"""
# 如果在准备期间离开了理论睡眠时间,则取消入睡
if not is_in_theoretical_sleep:
@@ -166,16 +168,22 @@ class SleepManager:
logger.info("睡眠缓冲期结束,正式进入休眠状态。")
self.context.current_state = SleepState.SLEEPING
self._last_fully_slept_log_time = now.timestamp()
# 设置一个随机的延迟,用于触发“睡后失眠”检查
delay_minutes_range = global_config.sleep_system.insomnia_trigger_delay_minutes
delay_minutes = random.randint(delay_minutes_range[0], delay_minutes_range[1])
self.context.sleep_buffer_end_time = now + timedelta(minutes=delay_minutes)
logger.info(f"已设置睡后失眠检查,将在 {delay_minutes} 分钟后触发。")
self.context.save()
def _handle_sleeping(self, now: datetime, is_in_theoretical_sleep: bool, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
def _handle_sleeping(
self,
now: datetime,
is_in_theoretical_sleep: bool,
activity: Optional[str],
wakeup_manager: Optional["WakeUpManager"],
):
"""处理“正在睡觉”状态下的逻辑。"""
# 如果理论睡眠时间结束,则自然醒来
if not is_in_theoretical_sleep:
@@ -198,14 +206,16 @@ class SleepManager:
if insomnia_reason:
self.context.current_state = SleepState.INSOMNIA
# 设置失眠的持续时间
duration_minutes_range = global_config.sleep_system.insomnia_duration_minutes
duration_minutes = random.randint(*duration_minutes_range)
self.context.sleep_buffer_end_time = now + timedelta(minutes=duration_minutes)
# 发送失眠通知
asyncio.create_task(NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason))
asyncio.create_task(
NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason)
)
logger.info(f"进入失眠状态 (原因: {insomnia_reason}),将持续 {duration_minutes} 分钟。")
else:
# 睡眠压力正常,不触发失眠,清除检查时间点


@@ -25,6 +25,7 @@ class SleepContext:
"""
睡眠上下文,负责封装和管理所有与睡眠相关的状态,并处理其持久化。
"""
def __init__(self):
"""初始化睡眠上下文,并从本地存储加载初始状态。"""
self.current_state: SleepState = SleepState.AWAKE
@@ -83,4 +84,4 @@ class SleepContext:
logger.info(f"成功从本地存储加载睡眠上下文: {state}")
except Exception as e:
logger.warning(f"加载睡眠上下文失败,将使用默认值: {e}")
logger.warning(f"加载睡眠上下文失败,将使用默认值: {e}")


@@ -15,23 +15,25 @@ class TimeChecker:
self._daily_sleep_offset: int = 0
self._daily_wake_offset: int = 0
self._offset_date = None
def _get_daily_offsets(self):
"""获取当天的睡眠和起床时间偏移量,每天生成一次"""
today = datetime.now().date()
# 如果是新的一天,重新生成偏移量
if self._offset_date != today:
sleep_offset_range = global_config.sleep_system.sleep_time_offset_minutes
wake_offset_range = global_config.sleep_system.wake_up_time_offset_minutes
# 生成 ±offset_range 范围内的随机偏移量
self._daily_sleep_offset = random.randint(-sleep_offset_range, sleep_offset_range)
self._daily_wake_offset = random.randint(-wake_offset_range, wake_offset_range)
self._offset_date = today
logger.debug(f"生成新的每日偏移量 - 睡觉时间偏移: {self._daily_sleep_offset}分钟, 起床时间偏移: {self._daily_wake_offset}分钟")
logger.debug(
f"生成新的每日偏移量 - 睡觉时间偏移: {self._daily_sleep_offset}分钟, 起床时间偏移: {self._daily_wake_offset}分钟"
)
return self._daily_sleep_offset, self._daily_wake_offset
@staticmethod
@@ -82,28 +84,36 @@ class TimeChecker:
try:
start_time_str = global_config.sleep_system.fixed_sleep_time
end_time_str = global_config.sleep_system.fixed_wake_up_time
# 获取当天的偏移量
sleep_offset, wake_offset = self._get_daily_offsets()
# 解析基础时间
base_start_time = datetime.strptime(start_time_str, "%H:%M")
base_end_time = datetime.strptime(end_time_str, "%H:%M")
# 应用偏移量
actual_start_time = (base_start_time + timedelta(minutes=sleep_offset)).time()
actual_end_time = (base_end_time + timedelta(minutes=wake_offset)).time()
logger.debug(f"固定睡眠时间检查 - 基础时间: {start_time_str}-{end_time_str}, "
f"偏移后时间: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')}, "
f"当前时间: {now_time.strftime('%H:%M')}")
logger.debug(
f"固定睡眠时间检查 - 基础时间: {start_time_str}-{end_time_str}, "
f"偏移后时间: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')}, "
f"当前时间: {now_time.strftime('%H:%M')}"
)
if actual_start_time <= actual_end_time:
if actual_start_time <= now_time < actual_end_time:
return True, f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})"
return (
True,
f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})",
)
else:
if now_time >= actual_start_time or now_time < actual_end_time:
return True, f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})"
return (
True,
f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})",
)
except ValueError as e:
logger.error(f"固定的睡眠时间格式不正确,请使用 HH:MM 格式: {e}")
return False, None
return False, None
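The two branches above handle a sleep window that may wrap past midnight (for example 23:30–07:00). Extracted into a standalone helper — a sketch only; the real method also returns the reason string and applies the daily offsets:

from datetime import time

def in_sleep_window(now: time, start: time, end: time) -> bool:
    if start <= end:                  # same-day window, e.g. 13:00-15:00
        return start <= now < end
    return now >= start or now < end  # wraps midnight, e.g. 23:30-07:00

assert in_sleep_window(time(1, 0), time(23, 30), time(7, 0))
assert not in_sleep_window(time(12, 0), time(23, 30), time(7, 0))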


@@ -1,4 +1,3 @@
import time
from src.common.logger import get_logger
from src.manager.local_store_manager import local_storage
@@ -9,6 +8,7 @@ class WakeUpContext:
"""
唤醒上下文,负责封装和管理所有与唤醒相关的状态,并处理其持久化。
"""
def __init__(self):
"""初始化唤醒上下文,并从本地存储加载初始状态。"""
self.wakeup_value: float = 0.0
@@ -42,4 +42,4 @@ class WakeUpContext:
"sleep_pressure": self.sleep_pressure,
}
local_storage[self._get_storage_key()] = state
logger.debug(f"已将唤醒上下文保存到本地存储: {state}")
logger.debug(f"已将唤醒上下文保存到本地存储: {state}")


@@ -3,7 +3,6 @@ import time
from typing import Optional, TYPE_CHECKING
from src.common.logger import get_logger
from src.config.config import global_config
from src.manager.local_store_manager import local_storage
from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext
if TYPE_CHECKING:
@@ -51,7 +50,7 @@ class WakeUpManager:
if not self.enabled:
logger.info("唤醒度系统已禁用,跳过启动")
return
self.is_running = True
if not self._decay_task or self._decay_task.done():
self._decay_task = asyncio.create_task(self._decay_loop())
@@ -88,6 +87,7 @@ class WakeUpManager:
self.context.is_angry = False
# 通知情绪管理系统清除愤怒状态
from src.mood.mood_manager import mood_manager
if self.angry_chat_id:
mood_manager.clear_angry_from_wakeup(self.angry_chat_id)
self.angry_chat_id = None
@@ -104,7 +104,9 @@ class WakeUpManager:
logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}")
self.context.save()
def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None) -> bool:
def add_wakeup_value(
self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None
) -> bool:
"""
增加唤醒度值
@@ -173,6 +175,7 @@ class WakeUpManager:
# 通知情绪管理系统进入愤怒状态
from src.mood.mood_manager import mood_manager
mood_manager.set_angry_from_wakeup(chat_id)
# 通知SleepManager重置睡眠状态
@@ -194,6 +197,7 @@ class WakeUpManager:
self.context.is_angry = False
# 通知情绪管理系统清除愤怒状态
from src.mood.mood_manager import mood_manager
if self.angry_chat_id:
mood_manager.clear_angry_from_wakeup(self.angry_chat_id)
self.angry_chat_id = None
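The hunks above show the decay loop's logging and the angry-state bookkeeping; the decay rule itself is not part of this diff. A self-contained sketch assuming a simple linear decay toward zero (interval, step, and cycles are made-up parameters):

import asyncio
from dataclasses import dataclass

@dataclass
class Ctx:
    wakeup_value: float = 3.0

async def decay_loop(ctx: Ctx, interval: float = 0.01, step: float = 1.0, cycles: int = 3):
    # Hypothetical: bleed wakeup_value toward 0 once per interval.
    for _ in range(cycles):
        await asyncio.sleep(interval)
        ctx.wakeup_value = max(0.0, ctx.wakeup_value - step)

asyncio.run(decay_loop(Ctx()))  # wakeup_value: 3.0 -> 0.0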


@@ -191,7 +191,7 @@ class ChatBot:
try:
# 检查聊天类型限制
if not plus_command_instance.is_chat_type_allowed():
is_group = message.message_info.group_info
is_group = message.message_info.group_info
logger.info(
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
)
@@ -424,7 +424,9 @@ class ChatBot:
await message.process()
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
logger.info(f"\u001b[38;5;118m{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m")
logger.info(
f"\u001b[38;5;118m{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m"
)
# 过滤检查
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
@@ -456,7 +458,7 @@ class ChatBot:
result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message)
if not result.all_continue_process():
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
# TODO:暂不可用
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:
@@ -473,14 +475,14 @@ class ChatBot:
async def preprocess():
# 存储消息到数据库
from .storage import MessageStorage
try:
await MessageStorage.store_message(message, message.chat_stream)
logger.debug(f"消息已存储到数据库: {message.message_info.message_id}")
except Exception as e:
logger.error(f"存储消息到数据库失败: {e}")
traceback.print_exc()
# 使用消息管理器处理消息(保持原有功能)
from src.common.data_models.database_data_model import DatabaseMessages


@@ -89,7 +89,7 @@ class ChatStream:
# 复制 stream_context但跳过 processing_task
new_stream.stream_context = copy.deepcopy(self.stream_context, memo)
if hasattr(new_stream.stream_context, 'processing_task'):
if hasattr(new_stream.stream_context, "processing_task"):
new_stream.stream_context.processing_task = None
# 复制 context_manager
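Context for the __deepcopy__ hunk above: a copied stream must not share the original's in-flight asyncio.Task, so the copy nulls processing_task after duplicating the context. The same shape in miniature (simplified classes):

import copy

class StreamContext:
    def __init__(self):
        self.history: list = []
        self.processing_task = None  # may hold an asyncio.Task at runtime

class Stream:
    def __init__(self):
        self.stream_context = StreamContext()

    def __deepcopy__(self, memo):
        new = Stream()
        new.stream_context = copy.deepcopy(self.stream_context, memo)
        # The clone must not resume or cancel the original's task.
        if hasattr(new.stream_context, "processing_task"):
            new.stream_context.processing_task = None
        return new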
@@ -377,6 +377,7 @@ class ChatStream:
# 默认基础分
return 0.3
class ChatManager:
"""聊天管理器,管理所有聊天流"""
@@ -563,9 +564,8 @@ class ChatManager:
if not hasattr(stream, "context_manager"):
# 创建新的单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager
stream.context_manager = SingleStreamContextManager(
stream_id=stream_id, context=stream.stream_context
)
stream.context_manager = SingleStreamContextManager(stream_id=stream_id, context=stream.stream_context)
# 保存到内存和数据库
self.streams[stream_id] = stream
@@ -721,6 +721,7 @@ class ChatManager:
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):
from src.chat.message_manager.context_manager import SingleStreamContextManager
stream.context_manager = SingleStreamContextManager(
stream_id=stream.stream_id, context=stream.stream_context
)


@@ -108,7 +108,7 @@ class MessageRecv(Message):
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message")
self.chat_stream = None
self.reply = None
self.processed_plain_text = message_dict.get("processed_plain_text", "")


@@ -53,7 +53,11 @@ class MessageStorage:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
else:
# 如果没有设置display_message使用processed_plain_text作为显示消息
filtered_display_message = re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL) if message.processed_plain_text else ""
filtered_display_message = (
re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL)
if message.processed_plain_text
else ""
)
interest_value = 0
is_mentioned = False
reply_to = message.reply_to
@@ -168,9 +172,11 @@ class MessageStorage:
from src.common.database.sqlalchemy_models import get_db_session
async with get_db_session() as session:
matched_message = (await session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
)).scalar()
matched_message = (
await session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
)
).scalar()
if matched_message:
await session.execute(
@@ -204,9 +210,11 @@ class MessageStorage:
from src.common.database.sqlalchemy_models import get_db_session
async with get_db_session() as session:
image_record = (await session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
)).scalar()
image_record = (
await session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
)
).scalar()
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
except Exception:
return match.group(0)
@@ -287,15 +295,19 @@ class MessageStorage:
from src.common.database.sqlalchemy_models import Messages
# 查找需要修复的记录interest_value为0、null或很小的值
query = select(Messages).where(
(Messages.chat_id == chat_id) &
(Messages.time >= since_time) &
(
(Messages.interest_value == 0) |
(Messages.interest_value.is_(None)) |
(Messages.interest_value < 0.1)
query = (
select(Messages)
.where(
(Messages.chat_id == chat_id)
& (Messages.time >= since_time)
& (
(Messages.interest_value == 0)
| (Messages.interest_value.is_(None))
| (Messages.interest_value < 0.1)
)
)
).limit(50) # 限制每次修复的数量,避免性能问题
.limit(50)
) # 限制每次修复的数量,避免性能问题
result = await session.execute(query)
messages_to_fix = result.scalars().all()
@@ -307,7 +319,7 @@ class MessageStorage:
default_interest = 0.3 # 默认中等兴趣度
# 如果消息内容较长,可能是重要消息,兴趣度稍高
if hasattr(msg, 'processed_plain_text') and msg.processed_plain_text:
if hasattr(msg, "processed_plain_text") and msg.processed_plain_text:
text_length = len(msg.processed_plain_text)
if text_length > 50: # 长消息
default_interest = 0.4
@@ -315,13 +327,15 @@ class MessageStorage:
default_interest = 0.35
# 如果是被@的消息,兴趣度更高
if getattr(msg, 'is_mentioned', False):
if getattr(msg, "is_mentioned", False):
default_interest = min(default_interest + 0.2, 0.8)
# 执行更新
update_stmt = update(Messages).where(
Messages.message_id == msg.message_id
).values(interest_value=default_interest)
update_stmt = (
update(Messages)
.where(Messages.message_id == msg.message_id)
.values(interest_value=default_interest)
)
result = await session.execute(update_stmt)
if result.rowcount > 0:
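A quick trace of the repair heuristic above under assumed inputs; note the branch between 20 and 50 characters is elided at the hunk boundary, so the 0.35 condition here is a guess:

def repaired_interest(text_len: int, mentioned: bool) -> float:
    interest = 0.3                   # default for messages missing interest_value
    if text_len > 50:
        interest = 0.4               # long message
    elif text_len > 20:              # assumed condition (elided in the diff)
        interest = 0.35
    if mentioned:
        interest = min(interest + 0.2, 0.8)
    return interest

print(repaired_interest(10, False))  # 0.3
print(repaired_interest(60, True))   # ~0.6 (0.4 + 0.2, capped at 0.8)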


@@ -40,7 +40,7 @@ class ChatterActionManager:
@staticmethod
def create_action(
action_name: str,
action_name: str,
action_data: dict,
reasoning: str,
cycle_timers: dict,
@@ -162,7 +162,7 @@ class ChatterActionManager:
Returns:
执行结果
"""
from src.chat.message_manager.message_manager import message_manager
try:
logger.debug(f"🎯 [ActionManager] execute_action接收到 target_message: {target_message}")
# 通过chat_id获取chat_stream
@@ -309,9 +309,7 @@ class ChatterActionManager:
# 通过message_manager更新消息的动作记录并刷新focus_energy
await message_manager.add_action(
stream_id=chat_stream.stream_id,
message_id=target_message_id,
action=action_name
stream_id=chat_stream.stream_id, message_id=target_message_id, action=action_name
)
logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy")
@@ -321,9 +319,10 @@ class ChatterActionManager:
async def _reset_interruption_count_after_action(self, stream_id: str):
"""在动作执行成功后重置打断计数"""
from src.chat.message_manager.message_manager import message_manager
try:
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if chat_stream:
@@ -332,7 +331,9 @@ class ChatterActionManager:
old_count = context.context.interruption_count
old_afc_adjustment = context.context.get_afc_threshold_adjustment()
context.context.reset_interruption_count()
logger.debug(f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0")
logger.debug(
f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0"
)
except Exception as e:
logger.warning(f"重置打断计数时出错: {e}")
@@ -531,7 +532,7 @@ class ChatterActionManager:
# 根据新消息数量决定是否需要引用回复
reply_text = ""
is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True
logger.debug(f"[send_response] message_data: {message_data}")
first_replied = False
@@ -558,7 +559,9 @@ class ChatterActionManager:
# 发送第一段回复
if not first_replied:
set_reply_flag = bool(message_data)
logger.debug(f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}")
logger.debug(
f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}"
)
await send_api.text_to_stream(
text=data,
stream_id=chat_stream.stream_id,
@@ -577,4 +580,4 @@ class ChatterActionManager:
typing=True,
)
return reply_text
return reply_text


@@ -29,6 +29,7 @@ from src.chat.utils.chat_message_builder import (
replace_user_references_sync,
)
from src.chat.express.expression_selector import expression_selector
# 旧记忆系统已被移除
# 旧记忆系统已被移除
from src.mood.mood_manager import mood_manager
@@ -580,7 +581,9 @@ class DefaultReplyer:
memory_context["user_aliases"] = memory_aliases
if group_info_obj is not None:
group_name = getattr(group_info_obj, "group_name", None) or getattr(group_info_obj, "group_nickname", None)
group_name = getattr(group_info_obj, "group_name", None) or getattr(
group_info_obj, "group_nickname", None
)
if group_name:
memory_context["group_name"] = str(group_name)
group_id = getattr(group_info_obj, "group_id", None)
@@ -594,11 +597,7 @@ class DefaultReplyer:
# 检索相关记忆
enhanced_memories = await memory_system.retrieve_relevant_memories(
query=target,
user_id=memory_user_id,
scope_id=stream.stream_id,
context=memory_context,
limit=10
query=target, user_id=memory_user_id, scope_id=stream.stream_id, context=memory_context, limit=10
)
# 注意:记忆存储已迁移到回复生成完成后进行,不在查询阶段执行
@@ -609,23 +608,27 @@ class DefaultReplyer:
logger.debug(f"[记忆转换] 收到 {len(enhanced_memories)} 条原始记忆")
for idx, memory_chunk in enumerate(enhanced_memories, 1):
# 获取结构化内容的字符串表示
structure_display = str(memory_chunk.content) if hasattr(memory_chunk, 'content') else "unknown"
structure_display = str(memory_chunk.content) if hasattr(memory_chunk, "content") else "unknown"
# 获取记忆内容优先使用display
content = memory_chunk.display or memory_chunk.text_content or ""
# 调试:记录每条记忆的内容获取情况
logger.debug(f"[记忆转换] 第{idx}条: display={repr(memory_chunk.display)[:80]}, text_content={repr(memory_chunk.text_content)[:80]}, final_content={repr(content)[:80]}")
running_memories.append({
"content": content,
"memory_type": memory_chunk.memory_type.value,
"confidence": memory_chunk.metadata.confidence.value,
"importance": memory_chunk.metadata.importance.value,
"relevance": getattr(memory_chunk.metadata, 'relevance_score', 0.5),
"source": memory_chunk.metadata.source,
"structure": structure_display,
})
logger.debug(
f"[记忆转换] 第{idx}条: display={repr(memory_chunk.display)[:80]}, text_content={repr(memory_chunk.text_content)[:80]}, final_content={repr(content)[:80]}"
)
running_memories.append(
{
"content": content,
"memory_type": memory_chunk.memory_type.value,
"confidence": memory_chunk.metadata.confidence.value,
"importance": memory_chunk.metadata.importance.value,
"relevance": getattr(memory_chunk.metadata, "relevance_score", 0.5),
"source": memory_chunk.metadata.source,
"structure": structure_display,
}
)
# 构建瞬时记忆字符串
if running_memories:
@@ -633,7 +636,9 @@ class DefaultReplyer:
if top_memory:
instant_memory = top_memory[0].get("content", "")
logger.info(f"增强记忆系统检索到 {len(enhanced_memories)} 条原始记忆,转换为 {len(running_memories)} 条可用记忆")
logger.info(
f"增强记忆系统检索到 {len(enhanced_memories)} 条原始记忆,转换为 {len(running_memories)} 条可用记忆"
)
except Exception as e:
logger.warning(f"增强记忆系统检索失败: {e}")
@@ -650,17 +655,17 @@ class DefaultReplyer:
memory_parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
# 按相关度排序,并记录相关度信息用于调试
sorted_memories = sorted(running_memories, key=lambda x: x.get('relevance', 0.0), reverse=True)
sorted_memories = sorted(running_memories, key=lambda x: x.get("relevance", 0.0), reverse=True)
# 调试相关度信息
relevance_info = [(m.get('memory_type', 'unknown'), m.get('relevance', 0.0)) for m in sorted_memories]
relevance_info = [(m.get("memory_type", "unknown"), m.get("relevance", 0.0)) for m in sorted_memories]
logger.debug(f"记忆相关度信息: {relevance_info}")
logger.debug(f"[记忆构建] 准备将 {len(sorted_memories)} 条记忆添加到提示词")
for idx, running_memory in enumerate(sorted_memories, 1):
content = running_memory.get('content', '')
memory_type = running_memory.get('memory_type', 'unknown')
content = running_memory.get("content", "")
memory_type = running_memory.get("memory_type", "unknown")
# 跳过空内容
if not content or not content.strip():
logger.warning(f"[记忆构建] 跳过第 {idx} 条记忆:内容为空 (type={memory_type})")
@@ -822,10 +827,10 @@ class DefaultReplyer:
"""
try:
# 从message_manager获取真实的已读/未读消息
from src.chat.message_manager.message_manager import message_manager
# 获取聊天流的上下文
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(chat_id)
if chat_stream:
@@ -1021,7 +1026,9 @@ class DefaultReplyer:
interest_scores = {}
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system as interest_scoring_system
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system as interest_scoring_system,
)
from src.common.data_models.database_data_model import DatabaseMessages
# 转换消息格式
@@ -1204,7 +1211,7 @@ class DefaultReplyer:
platform, # type: ignore
reply_message.get("user_id"), # type: ignore
reply_message.get("user_nickname"),
reply_message.get("user_cardname")
reply_message.get("user_cardname"),
)
# 检查是否是bot自己的名字如果是则替换为"(你)"
@@ -1763,6 +1770,7 @@ class DefaultReplyer:
# 创建关系追踪器实例
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
relationship_tracker = ChatterRelationshipTracker(chatter_interest_scoring_system)
if relationship_tracker:
# 获取用户信息以获取真实的user_id
@@ -1805,7 +1813,7 @@ class DefaultReplyer:
async def _store_chat_memory_async(self, reply_to: str, reply_message: Optional[Dict[str, Any]] = None):
"""
异步存储聊天记忆从build_memory_block迁移而来
Args:
reply_to: 回复对象
reply_message: 回复的原始消息
@@ -1874,9 +1882,7 @@ class DefaultReplyer:
memory_aliases.append(stripped)
alias_values = (
user_info_dict.get("aliases")
or user_info_dict.get("alias_names")
or user_info_dict.get("alias")
user_info_dict.get("aliases") or user_info_dict.get("alias_names") or user_info_dict.get("alias")
)
if isinstance(alias_values, (list, tuple, set)):
for alias in alias_values:
@@ -1900,7 +1906,9 @@ class DefaultReplyer:
memory_context["user_aliases"] = memory_aliases
if group_info_obj is not None:
group_name = getattr(group_info_obj, "group_name", None) or getattr(group_info_obj, "group_nickname", None)
group_name = getattr(group_info_obj, "group_name", None) or getattr(
group_info_obj, "group_nickname", None
)
if group_name:
memory_context["group_name"] = str(group_name)
group_id = getattr(group_info_obj, "group_id", None)
@@ -1932,11 +1940,11 @@ class DefaultReplyer:
"conversation_text": chat_history,
"user_id": memory_user_id,
"scope_id": stream.stream_id,
**memory_context
**memory_context,
}
)
)
logger.debug(f"已启动记忆存储任务,用户: {memory_user_display or memory_user_id}")
except Exception as e:


@@ -13,6 +13,7 @@ from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_m
from src.common.database.sqlalchemy_database_api import get_db_session
from sqlalchemy import select, and_
from src.common.logger import get_logger
logger = get_logger("chat_message_builder")
install(extra_lines=3)
@@ -274,21 +275,52 @@ async def get_actions_by_timestamp_with_chat(
async with get_db_session() as session:
if limit > 0:
result = await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end,
)
)
.order_by(ActionRecords.time.desc())
.limit(limit)
)
actions = list(result.scalars())
actions_result = []
for action in reversed(actions):
action_dict = {
"id": action.id,
"action_id": action.action_id,
"time": action.time,
"action_name": action.action_name,
"action_data": action.action_data,
"action_done": action.action_done,
"action_build_into_prompt": action.action_build_into_prompt,
"action_prompt_display": action.action_prompt_display,
"chat_id": action.chat_id,
"chat_info_stream_id": action.chat_info_stream_id,
"chat_info_platform": action.chat_info_platform,
}
actions_result.append(action_dict)
actions_result.append(action_dict)
else: # earliest
result = await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end,
)
)
.order_by(ActionRecords.time.desc())
.order_by(ActionRecords.time.asc())
.limit(limit)
)
actions = list(result.scalars())
actions_result = []
for action in reversed(actions):
for action in actions:
action_dict = {
"id": action.id,
"action_id": action.action_id,
@@ -303,37 +335,6 @@ async def get_actions_by_timestamp_with_chat(
"chat_info_platform": action.chat_info_platform,
}
actions_result.append(action_dict)
actions_result.append(action_dict)
else: # earliest
result = await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end,
)
)
.order_by(ActionRecords.time.asc())
.limit(limit)
)
actions = list(result.scalars())
actions_result = []
for action in actions:
action_dict = {
"id": action.id,
"action_id": action.action_id,
"time": action.time,
"action_name": action.action_name,
"action_data": action.action_data,
"action_done": action.action_done,
"action_build_into_prompt": action.action_build_into_prompt,
"action_prompt_display": action.action_prompt_display,
"chat_id": action.chat_id,
"chat_info_stream_id": action.chat_info_stream_id,
"chat_info_platform": action.chat_info_platform,
}
actions_result.append(action_dict)
else:
result = await session.execute(
select(ActionRecords)
@@ -457,7 +458,9 @@ async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp_with_chat(
chat_id: str, timestamp: float, limit: int = 0
) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
@@ -466,7 +469,9 @@ async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float,
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
async def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp_with_users(
timestamp: float, person_ids: list, limit: int = 0
) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
@@ -475,7 +480,9 @@ async def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids:
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
async def num_new_messages_since(
chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None
) -> int:
"""
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。
@@ -830,7 +837,7 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
async with get_db_session() as session:
result = await session.execute(select(Images).where(Images.image_id == pic_id))
image = result.scalar_one_or_none()
if image and hasattr(image, 'description') and image.description:
if image and hasattr(image, "description") and image.description:
description = image.description
except Exception as e:
# 如果查询失败,保持默认描述
@@ -1017,24 +1024,29 @@ async def build_readable_messages(
async with get_db_session() as session:
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = (await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
actions_in_range = (
await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.time >= min_time,
ActionRecords.time <= max_time,
ActionRecords.chat_id == chat_id,
)
)
.order_by(ActionRecords.time)
)
.order_by(ActionRecords.time)
)).scalars()
).scalars()
# 获取最新消息之后的第一个动作记录
action_after_latest = (await session.execute(
select(ActionRecords)
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
)).scalars()
action_after_latest = (
await session.execute(
select(ActionRecords)
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
)
).scalars()
# 合并两部分动作记录,并转为 dict避免 DetachedInstanceError
actions = [
@@ -1225,9 +1237,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
except Exception:
return "?"
content = await replace_user_references_async(
content, platform, anon_name_resolver, replace_bot_name=False
)
content = await replace_user_references_async(content, platform, anon_name_resolver, replace_bot_name=False)
header = f"{anon_name}"
output_lines.append(header)


@@ -17,7 +17,7 @@ MEMORY_TYPE_CHINESE_MAPPING = {
"goal": "目标计划",
"experience": "经验教训",
"contextual": "上下文信息",
"unknown": "未知"
"unknown": "未知",
}
# 置信度等级到中文标签的映射表
@@ -30,7 +30,7 @@ CONFIDENCE_LEVEL_CHINESE_MAPPING = {
"MEDIUM": "中等置信度",
"HIGH": "高置信度",
"VERIFIED": "已验证",
"unknown": "未知"
"unknown": "未知",
}
# 重要性等级到中文标签的映射表
@@ -43,7 +43,7 @@ IMPORTANCE_LEVEL_CHINESE_MAPPING = {
"NORMAL": "一般重要性",
"HIGH": "高重要性",
"CRITICAL": "关键重要性",
"unknown": "未知"
"unknown": "未知",
}
@@ -69,7 +69,7 @@ def get_confidence_level_chinese_label(level) -> str:
str: 对应的中文标签,如果找不到则返回"未知"
"""
# 处理枚举实例
if hasattr(level, 'value'):
if hasattr(level, "value"):
level = level.value
# 处理数字
@@ -94,7 +94,7 @@ def get_importance_level_chinese_label(level) -> str:
str: 对应的中文标签,如果找不到则返回"未知"
"""
# 处理枚举实例
if hasattr(level, 'value'):
if hasattr(level, "value"):
level = level.value
# 处理数字
@@ -106,4 +106,4 @@ def get_importance_level_chinese_label(level) -> str:
level_upper = level.upper()
return IMPORTANCE_LEVEL_CHINESE_MAPPING.get(level_upper, "未知")
return "未知"
return "未知"


@@ -381,12 +381,12 @@ class Prompt:
# 性能优化 - 为不同任务设置不同的超时时间
task_timeouts = {
"memory_block": 15.0, # 记忆系统 - 降低超时时间,鼓励预构建
"tool_info": 15.0, # 工具信息
"relation_info": 10.0, # 关系信息
"knowledge_info": 10.0, # 知识库查询
"cross_context": 10.0, # 上下文处理
"expression_habits": 10.0, # 表达习惯
"memory_block": 15.0, # 记忆系统 - 降低超时时间,鼓励预构建
"tool_info": 15.0, # 工具信息
"relation_info": 10.0, # 关系信息
"knowledge_info": 10.0, # 知识库查询
"cross_context": 10.0, # 上下文处理
"expression_habits": 10.0, # 表达习惯
}
# 分别处理每个任务,避免慢任务影响快任务
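The timeout table above is applied per task, so one slow subsystem (say, the memory block) cannot starve the fast ones. A minimal sketch of the pattern with asyncio — names are illustrative, not the class's real helpers:

import asyncio
from typing import Any, Awaitable, Dict

async def gather_with_timeouts(tasks: Dict[str, Awaitable[Any]],
                               timeouts: Dict[str, float]) -> Dict[str, Any]:
    async def guarded(name: str, coro: Awaitable[Any]):
        try:
            return name, await asyncio.wait_for(coro, timeouts.get(name, 10.0))
        except asyncio.TimeoutError:
            return name, None  # a slow task degrades to an empty block
    return dict(await asyncio.gather(*(guarded(n, c) for n, c in tasks.items())))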
@@ -563,7 +563,7 @@ class Prompt:
),
enhanced_memory_activator.get_instant_memory(
target_message=self.parameters.target, chat_id=self.parameters.chat_id
)
),
]
try:
@@ -606,26 +606,27 @@ class Prompt:
"opinion": "opinion",
"personal_fact": "personal_fact",
"preference": "preference",
"event": "event"
"event": "event",
}
mapped_type = memory_type_mapping.get(topic, "personal_fact")
formatted_memories.append({
"display": display_text,
"memory_type": mapped_type,
"metadata": {
"confidence": memory.get("confidence", "未知"),
"importance": memory.get("importance", "一般"),
"timestamp": memory.get("timestamp", ""),
"source": memory.get("source", "unknown"),
"relevance_score": memory.get("relevance_score", 0.0)
formatted_memories.append(
{
"display": display_text,
"memory_type": mapped_type,
"metadata": {
"confidence": memory.get("confidence", "未知"),
"importance": memory.get("importance", "一般"),
"timestamp": memory.get("timestamp", ""),
"source": memory.get("source", "unknown"),
"relevance_score": memory.get("relevance_score", 0.0),
},
}
})
)
# 使用方括号格式格式化记忆
memory_block = format_memories_bracket_style(
formatted_memories,
query_context=self.parameters.target
formatted_memories, query_context=self.parameters.target
)
except Exception as e:
logger.warning(f"记忆格式化失败,使用简化格式: {e}")
@@ -833,7 +834,8 @@ class Prompt:
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
"chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
"chat_scene": self.parameters.chat_scene
or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
}
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
@@ -860,7 +862,8 @@ class Prompt:
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
"chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
"chat_scene": self.parameters.chat_scene
or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
}
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:


@@ -305,11 +305,14 @@ class StatisticOutputTask(AsyncTask):
# 以最早的时间戳为起始时间获取记录
query_start_time = collect_period[-1][1]
records = await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp",
) or []
records = (
await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp",
)
or []
)
for record in records:
if not isinstance(record, dict):
@@ -401,7 +404,9 @@ class StatisticOutputTask(AsyncTask):
return stats
@staticmethod
async def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
async def _collect_online_time_for_period(
collect_period: List[Tuple[str, datetime]], now: datetime
) -> Dict[str, Any]:
"""
收集指定时间段的在线时间统计数据
@@ -420,11 +425,14 @@ class StatisticOutputTask(AsyncTask):
}
query_start_time = collect_period[-1][1]
records = await db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": query_start_time}},
order_by="-end_timestamp",
) or []
records = (
await db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": query_start_time}},
order_by="-end_timestamp",
)
or []
)
for record in records:
if not isinstance(record, dict):
@@ -476,11 +484,14 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
records = await db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time",
) or []
records = (
await db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time",
)
or []
)
for message in records:
if not isinstance(message, dict):
@@ -1038,11 +1049,14 @@ class StatisticOutputTask(AsyncTask):
interval_seconds = interval_minutes * 60
# 单次查询 LLMUsage
llm_records = await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": start_time}},
order_by="-timestamp",
) or []
llm_records = (
await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": start_time}},
order_by="-timestamp",
)
or []
)
for record in llm_records:
if not isinstance(record, dict) or not record.get("timestamp"):
continue
@@ -1068,11 +1082,14 @@ class StatisticOutputTask(AsyncTask):
cost_by_module[module_name][idx] += cost
# 单次查询 Messages
msg_records = await db_get(
model_class=Messages,
filters={"time": {"$gte": start_time.timestamp()}},
order_by="-time",
) or []
msg_records = (
await db_get(
model_class=Messages,
filters={"time": {"$gte": start_time.timestamp()}},
order_by="-time",
)
or []
)
for msg in msg_records:
if not isinstance(msg, dict) or not msg.get("time"):
continue
@@ -1375,4 +1392,4 @@ class StatisticOutputTask(AsyncTask):
}});
</script>
</div>
"""
"""


@@ -675,7 +675,6 @@ async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Di
if loop.is_running():
# 如果事件循环在运行,从其他线程提交并等待结果
try:
from concurrent.futures import TimeoutError
fut = asyncio.run_coroutine_threadsafe(
person_info_manager.get_value(person_id, "person_name"), loop


@@ -81,14 +81,16 @@ class ImageManager:
"""
try:
async with get_db_session() as session:
record = (await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
record = (
await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
)
)
)
)).scalar()
).scalar()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
@@ -107,14 +109,16 @@ class ImageManager:
current_timestamp = time.time()
async with get_db_session() as session:
# 查找现有记录
existing = (await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
existing = (
await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
)
)
)
)).scalar()
).scalar()
if existing:
# 更新现有记录
@@ -262,9 +266,11 @@ class ImageManager:
from src.common.database.sqlalchemy_models import get_db_session
async with get_db_session() as session:
existing_img = (await session.execute(
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
)).scalar()
existing_img = (
await session.execute(
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
)
).scalar()
if existing_img:
existing_img.path = file_path
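The first two hunks in this file form a small hash-keyed cache: fetch a description by (image_hash, type), and on save either update the existing row or insert a fresh one. The upsert half as a sketch — model and session follow the diff; the timestamp column is assumed:

import time
from sqlalchemy import select

async def save_description(session, ImageDescriptions, image_hash: str,
                           description_type: str, description: str) -> None:
    existing = (
        await session.execute(
            select(ImageDescriptions).where(
                ImageDescriptions.image_description_hash == image_hash,
                ImageDescriptions.type == description_type,
            )
        )
    ).scalar()
    if existing:
        existing.description = description  # update in place
    else:
        session.add(ImageDescriptions(
            image_description_hash=image_hash,
            type=description_type,
            description=description,
            timestamp=time.time(),  # assumed column
        ))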


@@ -35,7 +35,7 @@ logger = get_logger("utils_video")
# Rust模块可用性检测
RUST_VIDEO_AVAILABLE = False
try:
import rust_video # pyright: ignore[reportMissingImports]
import rust_video # pyright: ignore[reportMissingImports]
RUST_VIDEO_AVAILABLE = True
logger.info("✅ Rust 视频处理模块加载成功")
@@ -222,7 +222,7 @@ class VideoAnalyzer:
return None
async def _store_video_result(
self, video_hash: str, description: str, metadata: Optional[Dict] = None
self, video_hash: str, description: str, metadata: Optional[Dict] = None
) -> Optional[Videos]:
"""存储视频分析结果到数据库"""
# 检查描述是否为错误信息,如果是则不保存