From ecb02cae316b8f2381a1b9593478ed2fba2ad34e Mon Sep 17 00:00:00 2001 From: John Richard Date: Thu, 2 Oct 2025 19:38:39 +0800 Subject: [PATCH] =?UTF-8?q?style:=20=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 3 +- scripts/lpmm_learning_tool.py | 10 +- scripts/rebuild_metadata_index.py | 64 +-- scripts/run_multi_stage_smoke.py | 5 +- .../management/statistics.py | 16 +- src/chat/chatter_manager.py | 12 +- src/chat/emoji_system/emoji_history.py | 1 - src/chat/emoji_system/emoji_manager.py | 9 +- src/chat/energy_system/__init__.py | 6 +- src/chat/energy_system/energy_manager.py | 51 +- src/chat/express/expression_learner.py | 9 +- .../interest_system/bot_interest_manager.py | 52 +- src/chat/knowledge/embedding_store.py | 2 +- src/chat/knowledge/qa_manager.py | 27 +- src/chat/memory_system/__init__.py | 44 +- .../enhanced_memory_adapter.py | 72 +-- .../enhanced_memory_hooks.py | 23 +- .../enhanced_memory_integration.py | 61 +-- .../deprecated_backup/enhanced_reranker.py | 218 +++++--- .../deprecated_backup/integration_layer.py | 29 +- .../memory_integration_hooks.py | 67 +-- .../deprecated_backup/metadata_index.py | 126 ++--- .../multi_stage_retrieval.py | 314 ++++++----- .../deprecated_backup/vector_storage.py | 155 +++--- .../enhanced_memory_activator.py | 30 +- .../memory_system/memory_activator_new.py | 30 +- src/chat/memory_system/memory_builder.py | 149 +++--- src/chat/memory_system/memory_chunk.py | 135 +++-- .../memory_system/memory_forgetting_engine.py | 57 +- src/chat/memory_system/memory_fusion.py | 75 +-- src/chat/memory_system/memory_manager.py | 57 +- .../memory_system/memory_metadata_index.py | 137 +++-- .../memory_system/memory_query_planner.py | 19 +- src/chat/memory_system/memory_system.py | 343 ++++++------ .../memory_system/vector_memory_storage_v2.py | 489 +++++++++--------- src/chat/message_manager/__init__.py | 9 +- src/chat/message_manager/context_manager.py | 2 + .../message_manager/distribution_manager.py | 16 +- src/chat/message_manager/message_manager.py | 39 +- .../sleep_manager/notification_sender.py | 40 +- .../sleep_manager/sleep_manager.py | 32 +- .../sleep_manager/sleep_state.py | 3 +- .../sleep_manager/time_checker.py | 42 +- .../sleep_manager/wakeup_context.py | 4 +- .../sleep_manager/wakeup_manager.py | 10 +- src/chat/message_receive/bot.py | 12 +- src/chat/message_receive/chat_stream.py | 9 +- src/chat/message_receive/message.py | 2 +- src/chat/message_receive/storage.py | 54 +- src/chat/planner_actions/action_manager.py | 23 +- src/chat/replyer/default_generator.py | 80 +-- src/chat/utils/chat_message_builder.py | 121 +++-- src/chat/utils/memory_mappings.py | 12 +- src/chat/utils/prompt.py | 47 +- src/chat/utils/statistic.py | 71 ++- src/chat/utils/utils.py | 1 - src/chat/utils/utils_image.py | 36 +- src/chat/utils/utils_video.py | 4 +- src/common/cache_manager.py | 5 +- src/common/database/db_migration.py | 15 +- .../database/sqlalchemy_database_api.py | 2 +- src/common/database/sqlalchemy_models.py | 1 + src/common/logger.py | 8 +- src/common/vector_db/chromadb_impl.py | 22 +- src/config/config.py | 2 +- src/config/official_configs.py | 77 ++- src/individuality/individuality.py | 6 +- .../model_client/aiohttp_gemini_client.py | 4 +- src/llm_models/payload_content/message.py | 2 +- src/llm_models/utils_model.py | 257 +++++---- src/main.py | 3 +- src/mais4u/mais4u_chat/s4u_msg_processor.py | 4 +- src/mais4u/mais4u_chat/s4u_prompt.py | 7 +- src/mood/mood_manager.py | 2 +- src/person_info/person_info.py | 13 +- src/person_info/relationship_builder.py | 20 +- src/person_info/relationship_fetcher.py | 1 - src/plugin_system/apis/__init__.py | 2 +- src/plugin_system/apis/message_api.py | 10 +- src/plugin_system/apis/permission_api.py | 3 +- src/plugin_system/apis/schedule_api.py | 3 +- src/plugin_system/apis/send_api.py | 4 +- src/plugin_system/base/base_action.py | 18 +- src/plugin_system/base/base_chatter.py | 9 +- src/plugin_system/base/base_tool.py | 36 +- src/plugin_system/base/component_types.py | 2 +- src/plugin_system/base/plugin_base.py | 4 +- src/plugin_system/core/component_registry.py | 110 ++-- src/plugin_system/core/permission_manager.py | 27 +- src/plugin_system/core/plugin_manager.py | 7 +- src/plugin_system/core/tool_use.py | 22 +- .../utils/permission_decorators.py | 8 +- .../affinity_flow_chatter/interest_scoring.py | 4 +- .../affinity_flow_chatter/plan_executor.py | 16 +- .../affinity_flow_chatter/plan_filter.py | 97 ++-- .../built_in/affinity_flow_chatter/planner.py | 10 +- .../relationship_tracker.py | 2 +- src/plugins/built_in/core_actions/emoji.py | 4 +- .../built_in/knowledge/lpmm_get_knowledge.py | 4 +- .../built_in/maizone_refactored/plugin.py | 1 + .../services/cookie_service.py | 4 +- .../services/qzone_service.py | 19 +- .../src/recv_handler/message_handler.py | 8 +- .../napcat_adapter_plugin/src/send_handler.py | 5 +- .../napcat_adapter_plugin/src/utils.py | 9 +- .../built_in/plugin_management/plugin.py | 11 +- .../built_in/proactive_thinker/plugin.py | 12 +- .../proacive_thinker_event.py | 12 +- .../proactive_thinker_executor.py | 138 ++--- .../built_in/social_toolkit_plugin/plugin.py | 2 - src/schedule/database.py | 20 +- 111 files changed, 2344 insertions(+), 2316 deletions(-) diff --git a/bot.py b/bot.py index 0399b0d19..798247c96 100644 --- a/bot.py +++ b/bot.py @@ -21,7 +21,7 @@ initialize_logging() from src.main import MainSystem # noqa from src import BaseMain # noqa from src.manager.async_task_manager import async_task_manager # noqa -from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge # noqa +from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge # noqa from src.config.config import global_config # noqa from src.common.database.database import initialize_sql_database # noqa from src.common.database.sqlalchemy_models import initialize_database as init_db # noqa @@ -269,4 +269,3 @@ if __name__ == "__main__": # 在程序退出前暂停,让你有机会看到输出 # input("按 Enter 键退出...") # <--- 添加这行 sys.exit(exit_code) # <--- 使用记录的退出码 - \ No newline at end of file diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index f0888d552..9caafc7fd 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -40,6 +40,7 @@ file_lock = Lock() # --- 缓存清理 --- + def clear_cache(): """清理 lpmm_learning_tool.py 生成的缓存文件""" logger.info("--- 开始清理缓存 ---") @@ -53,6 +54,7 @@ def clear_cache(): logger.info("缓存目录不存在,无需清理。") logger.info("--- 缓存清理完成 ---") + # --- 模块一:数据预处理 --- @@ -108,7 +110,7 @@ def _parse_and_repair_json(json_string: str) -> Optional[dict]: cleaned_string = cleaned_string[7:].strip() elif cleaned_string.startswith("```"): cleaned_string = cleaned_string[3:].strip() - + if cleaned_string.endswith("```"): cleaned_string = cleaned_string[:-3].strip() @@ -117,7 +119,7 @@ def _parse_and_repair_json(json_string: str) -> Optional[dict]: return orjson.loads(cleaned_string) except orjson.JSONDecodeError: logger.warning("直接解析JSON失败,将尝试修复...") - + # 3. 修复与最终解析 repaired_json_str = "" try: @@ -164,10 +166,10 @@ async def extract_info_async(pg_hash, paragraph, llm_api): content = None try: content, (_, _, _) = await llm_api.generate_response_async(prompt) - + # 改进点:调用封装好的函数处理JSON解析和修复 extracted_data = _parse_and_repair_json(content) - + if extracted_data is None: # 如果解析失败,抛出异常以触发统一的错误处理逻辑 raise ValueError("无法从LLM输出中解析有效的JSON数据") diff --git a/scripts/rebuild_metadata_index.py b/scripts/rebuild_metadata_index.py index f5cf652ab..d1990fecc 100644 --- a/scripts/rebuild_metadata_index.py +++ b/scripts/rebuild_metadata_index.py @@ -3,6 +3,7 @@ """ 从现有ChromaDB数据重建JSON元数据索引 """ + import asyncio import sys import os @@ -15,53 +16,53 @@ from src.common.logger import get_logger logger = get_logger(__name__) + async def rebuild_metadata_index(): """从ChromaDB重建元数据索引""" - print("="*80) + print("=" * 80) print("重建JSON元数据索引") - print("="*80) - + print("=" * 80) + # 初始化记忆系统 print("\n🔧 初始化记忆系统...") ms = MemorySystem() await ms.initialize() print("✅ 记忆系统已初始化") - - if not hasattr(ms.unified_storage, 'metadata_index'): + + if not hasattr(ms.unified_storage, "metadata_index"): print("❌ 元数据索引管理器未初始化") return - + # 获取所有记忆 print("\n📥 从ChromaDB获取所有记忆...") from src.common.vector_db import vector_db_service - + try: # 获取集合中的所有记忆ID collection_name = ms.unified_storage.config.memory_collection result = vector_db_service.get( - collection_name=collection_name, - include=["documents", "metadatas", "embeddings"] + collection_name=collection_name, include=["documents", "metadatas", "embeddings"] ) - + if not result or not result.get("ids"): print("❌ ChromaDB中没有找到记忆数据") return - + ids = result["ids"] metadatas = result.get("metadatas", []) - + print(f"✅ 找到 {len(ids)} 条记忆") - + # 重建元数据索引 print("\n🔨 开始重建元数据索引...") entries = [] success_count = 0 - - for i, (memory_id, metadata) in enumerate(zip(ids, metadatas), 1): + + for i, (memory_id, metadata) in enumerate(zip(ids, metadatas, strict=False), 1): try: # 从ChromaDB元数据重建索引条目 import orjson - + entry = MemoryMetadataIndexEntry( memory_id=memory_id, user_id=metadata.get("user_id", "unknown"), @@ -75,9 +76,9 @@ async def rebuild_metadata_index(): created_at=metadata.get("created_at", 0.0), access_count=metadata.get("access_count", 0), chat_id=metadata.get("chat_id"), - content_preview=None + content_preview=None, ) - + # 尝试解析importance和confidence的枚举名称 if "importance" in metadata: imp_str = metadata["importance"] @@ -89,7 +90,7 @@ async def rebuild_metadata_index(): entry.importance = 3 elif imp_str == "CRITICAL": entry.importance = 4 - + if "confidence" in metadata: conf_str = metadata["confidence"] if conf_str == "LOW": @@ -100,40 +101,41 @@ async def rebuild_metadata_index(): entry.confidence = 3 elif conf_str == "VERIFIED": entry.confidence = 4 - + entries.append(entry) success_count += 1 - + if i % 100 == 0: print(f" 处理进度: {i}/{len(ids)} ({success_count} 成功)") - + except Exception as e: logger.warning(f"处理记忆 {memory_id} 失败: {e}") continue - + print(f"\n✅ 成功解析 {success_count}/{len(ids)} 条记忆元数据") - + # 批量更新索引 print("\n💾 保存元数据索引...") ms.unified_storage.metadata_index.batch_add_or_update(entries) ms.unified_storage.metadata_index.save() - + # 显示统计信息 stats = ms.unified_storage.metadata_index.get_stats() - print(f"\n📊 重建后的索引统计:") + print("\n📊 重建后的索引统计:") print(f" - 总记忆数: {stats['total_memories']}") print(f" - 主语数量: {stats['subjects_count']}") print(f" - 关键词数量: {stats['keywords_count']}") print(f" - 标签数量: {stats['tags_count']}") - print(f" - 类型分布:") - for mtype, count in stats['types'].items(): + print(" - 类型分布:") + for mtype, count in stats["types"].items(): print(f" - {mtype}: {count}") - + print("\n✅ 元数据索引重建完成!") - + except Exception as e: logger.error(f"重建索引失败: {e}", exc_info=True) print(f"❌ 重建索引失败: {e}") -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(rebuild_metadata_index()) diff --git a/scripts/run_multi_stage_smoke.py b/scripts/run_multi_stage_smoke.py index bfb3c417f..000336244 100644 --- a/scripts/run_multi_stage_smoke.py +++ b/scripts/run_multi_stage_smoke.py @@ -3,6 +3,7 @@ """ 轻量烟雾测试:初始化 MemorySystem 并运行一次检索,验证 MemoryMetadata.source 访问不再报错 """ + import asyncio import sys import os @@ -11,6 +12,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.chat.memory_system.memory_system import MemorySystem + async def main(): ms = MemorySystem() await ms.initialize() @@ -19,5 +21,6 @@ async def main(): for i, m in enumerate(results, 1): print(f"{i}. id={m.metadata.memory_id} source={getattr(m.metadata, 'source', None)}") -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index 4df22f152..9d44faa78 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -31,9 +31,11 @@ class AntiInjectionStatistics: try: async with get_db_session() as session: # 获取最新的统计记录,如果没有则创建 - stats = (await session.execute( - select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()) - )).scalars().first() + stats = ( + (await session.execute(select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()))) + .scalars() + .first() + ) if not stats: stats = AntiInjectionStats() session.add(stats) @@ -49,9 +51,11 @@ class AntiInjectionStatistics: """更新统计数据""" try: async with get_db_session() as session: - stats = (await session.execute( - select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()) - )).scalars().first() + stats = ( + (await session.execute(select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()))) + .scalars() + .first() + ) if not stats: stats = AntiInjectionStats() session.add(stats) diff --git a/src/chat/chatter_manager.py b/src/chat/chatter_manager.py index 7aee10562..d22d39440 100644 --- a/src/chat/chatter_manager.py +++ b/src/chat/chatter_manager.py @@ -2,13 +2,13 @@ from typing import Dict, List, Optional, Any import time from src.plugin_system.base.base_chatter import BaseChatter from src.common.data_models.message_manager_data_model import StreamContext -from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner from src.chat.planner_actions.action_manager import ChatterActionManager -from src.plugin_system.base.component_types import ChatType, ComponentType +from src.plugin_system.base.component_types import ChatType from src.common.logger import get_logger logger = get_logger("chatter_manager") + class ChatterManager: def __init__(self, action_manager: ChatterActionManager): self.action_manager = action_manager @@ -27,6 +27,7 @@ class ChatterManager: """从组件注册表自动注册已注册的chatter组件""" try: from src.plugin_system.core.component_registry import component_registry + # 获取所有CHATTER类型的组件 chatter_components = component_registry.get_enabled_chatter_registry() for chatter_name, chatter_class in chatter_components.items(): @@ -70,7 +71,7 @@ class ChatterManager: inactive_streams = [] for stream_id, instance in self.instances.items(): - if hasattr(instance, 'get_activity_time'): + if hasattr(instance, "get_activity_time"): activity_time = instance.get_activity_time() if (current_time - activity_time) > max_inactive_seconds: inactive_streams.append(stream_id) @@ -91,6 +92,7 @@ class ChatterManager: if not chatter_class: # 如果没有找到精确匹配,尝试查找支持ALL类型的chatter from src.plugin_system.base.component_types import ChatType + all_chatter_class = self.get_chatter_class(ChatType.ALL) if all_chatter_class: chatter_class = all_chatter_class @@ -110,6 +112,7 @@ class ChatterManager: # 从 mood_manager 获取最新的 chat_stream 并同步回 StreamContext try: from src.mood.mood_manager import mood_manager + mood = mood_manager.get_mood_by_chat_id(stream_id) if mood and mood.chat_stream: context.chat_stream = mood.chat_stream @@ -125,6 +128,7 @@ class ChatterManager: # 在处理完成后,清除该流的未读消息 try: from src.chat.message_manager.message_manager import message_manager + await message_manager.clear_stream_unread_messages(stream_id) except Exception as clear_e: logger.error(f"清除流 {stream_id} 未读消息时发生错误: {clear_e}") @@ -149,4 +153,4 @@ class ChatterManager: "streams_processed": 0, "successful_executions": 0, "failed_executions": 0, - } \ No newline at end of file + } diff --git a/src/chat/emoji_system/emoji_history.py b/src/chat/emoji_system/emoji_history.py index 804f61e0a..dadd152a1 100644 --- a/src/chat/emoji_system/emoji_history.py +++ b/src/chat/emoji_system/emoji_history.py @@ -3,7 +3,6 @@ 表情包发送历史记录模块 """ -import os from typing import List, Dict from collections import deque diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 9e4829a56..cd472ec0c 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -204,9 +204,7 @@ class MaiEmoji: # 2. 删除数据库记录 try: async with get_db_session() as session: - result = await session.execute( - select(Emoji).where(Emoji.emoji_hash == self.hash) - ) + result = await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)) will_delete_emoji = result.scalar_one_or_none() if will_delete_emoji is None: logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") @@ -946,10 +944,7 @@ class EmojiManager: existing_description = None try: async with get_db_session() as session: - stmt = select(Images).where( - Images.emoji_hash == image_hash, - Images.type == "emoji" - ) + stmt = select(Images).where(Images.emoji_hash == image_hash, Images.type == "emoji") result = await session.execute(stmt) existing_image = result.scalar_one_or_none() if existing_image and existing_image.description: diff --git a/src/chat/energy_system/__init__.py b/src/chat/energy_system/__init__.py index 0addfd070..6cdf96da5 100644 --- a/src/chat/energy_system/__init__.py +++ b/src/chat/energy_system/__init__.py @@ -12,7 +12,7 @@ from .energy_manager import ( ActivityEnergyCalculator, RecencyEnergyCalculator, RelationshipEnergyCalculator, - energy_manager + energy_manager, ) __all__ = [ @@ -24,5 +24,5 @@ __all__ = [ "ActivityEnergyCalculator", "RecencyEnergyCalculator", "RelationshipEnergyCalculator", - "energy_manager" -] \ No newline at end of file + "energy_manager", +] diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py index b07222b4b..4a92349bf 100644 --- a/src/chat/energy_system/energy_manager.py +++ b/src/chat/energy_system/energy_manager.py @@ -17,16 +17,18 @@ logger = get_logger("energy_system") class EnergyLevel(Enum): """能量等级""" + VERY_LOW = 0.1 # 非常低 - LOW = 0.3 # 低 - NORMAL = 0.5 # 正常 - HIGH = 0.7 # 高 - VERY_HIGH = 0.9 # 非常高 + LOW = 0.3 # 低 + NORMAL = 0.5 # 正常 + HIGH = 0.7 # 高 + VERY_HIGH = 0.9 # 非常高 @dataclass class EnergyComponent: """能量组件""" + name: str value: float weight: float = 1.0 @@ -47,6 +49,7 @@ class EnergyComponent: class EnergyContext(TypedDict): """能量计算上下文""" + stream_id: str messages: List[Any] user_id: Optional[str] @@ -54,6 +57,7 @@ class EnergyContext(TypedDict): class EnergyResult(TypedDict): """能量计算结果""" + energy: float level: EnergyLevel distribution_interval: float @@ -114,12 +118,7 @@ class ActivityEnergyCalculator(EnergyCalculator): """活跃度能量计算器""" def __init__(self): - self.action_weights = { - "reply": 0.4, - "react": 0.3, - "mention": 0.2, - "other": 0.1 - } + self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1} def calculate(self, context: Dict[str, Any]) -> float: """基于活跃度计算能量""" @@ -188,7 +187,7 @@ class RecencyEnergyCalculator(EnergyCalculator): else: recency_score = 0.1 - logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age/3600:.1f}小时)") + logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age / 3600:.1f}小时)") return recency_score def get_weight(self) -> float: @@ -236,11 +235,7 @@ class EnergyManager: self.cache_ttl: int = 60 # 1分钟缓存 # AFC阈值配置 - self.thresholds: Dict[str, float] = { - "high_match": 0.8, - "reply": 0.4, - "non_reply": 0.2 - } + self.thresholds: Dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2} # 统计信息 self.stats: Dict[str, Union[int, float, str]] = { @@ -260,9 +255,13 @@ class EnergyManager: """从配置加载AFC阈值""" try: if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None: - self.thresholds["high_match"] = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8) + self.thresholds["high_match"] = getattr( + global_config.affinity_flow, "high_match_interest_threshold", 0.8 + ) self.thresholds["reply"] = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4) - self.thresholds["non_reply"] = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2) + self.thresholds["non_reply"] = getattr( + global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2 + ) # 确保阈值关系合理 self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1) @@ -306,6 +305,7 @@ class EnergyManager: # 支持同步和异步计算器 if callable(calculator.calculate): import inspect + if inspect.iscoroutinefunction(calculator.calculate): score = await calculator.calculate(context) else: @@ -347,11 +347,12 @@ class EnergyManager: calculation_time = time.time() - start_time total_calculations = self.stats["total_calculations"] self.stats["average_calculation_time"] = ( - (self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time) - / total_calculations - ) + self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time + ) / total_calculations - logger.debug(f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)") + logger.debug( + f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)" + ) return final_energy def _apply_threshold_adjustment(self, energy: float) -> float: @@ -405,6 +406,7 @@ class EnergyManager: # 添加随机扰动避免同步 import random + jitter = random.uniform(0.8, 1.2) final_interval = base_interval * jitter @@ -424,7 +426,8 @@ class EnergyManager: """清理过期缓存""" current_time = time.time() expired_keys = [ - stream_id for stream_id, (_, timestamp) in self.energy_cache.items() + stream_id + for stream_id, (_, timestamp) in self.energy_cache.items() if current_time - timestamp > self.cache_ttl ] @@ -479,4 +482,4 @@ class EnergyManager: # 全局能量管理器实例 -energy_manager = EnergyManager() \ No newline at end of file +energy_manager = EnergyManager() diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index bb663a1ad..596322ebd 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -382,7 +382,7 @@ class ExpressionLearner: create_date=current_time, # 手动设置创建日期 ) session.add(new_expression) - + # 限制最大数量 exprs_result = await session.execute( select(Expression) @@ -492,11 +492,10 @@ class ExpressionLearnerManager: self._ensure_expression_directories() - async def get_expression_learner(self, chat_id: str) -> ExpressionLearner: await self._auto_migrate_json_to_db() await self._migrate_old_data_create_date() - + if chat_id not in self.expression_learners: self.expression_learners[chat_id] = ExpressionLearner(chat_id) return self.expression_learners[chat_id] @@ -644,7 +643,9 @@ class ExpressionLearnerManager: try: async with get_db_session() as session: # 查找所有create_date为空的表达方式 - old_expressions_result = await session.execute(select(Expression).where(Expression.create_date.is_(None))) + old_expressions_result = await session.execute( + select(Expression).where(Expression.create_date.is_(None)) + ) old_expressions = old_expressions_result.scalars().all() updated_count = 0 diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index 37a315197..8fee48d1c 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -131,7 +131,9 @@ class BotInterestManager: self.current_interests = generated_interests active_count = len(generated_interests.get_active_tags()) logger.info(f"成功生成 {active_count} 个新兴趣标签。") - tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in generated_interests.get_active_tags()] + tags_info = [ + f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in generated_interests.get_active_tags() + ] tags_str = "\n".join(tags_info) logger.info(f"当前兴趣标签:\n{tags_str}") @@ -639,11 +641,19 @@ class BotInterestManager: async with get_db_session() as session: # 查询最新的兴趣标签配置 - db_interests = (await session.execute( - select(DBBotPersonalityInterests) - .where(DBBotPersonalityInterests.personality_id == personality_id) - .order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc()) - )).scalars().first() + db_interests = ( + ( + await session.execute( + select(DBBotPersonalityInterests) + .where(DBBotPersonalityInterests.personality_id == personality_id) + .order_by( + DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc() + ) + ) + ) + .scalars() + .first() + ) if db_interests: logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}") @@ -728,10 +738,17 @@ class BotInterestManager: async with get_db_session() as session: # 检查是否已存在相同personality_id的记录 - existing_record = (await session.execute( - select(DBBotPersonalityInterests) - .where(DBBotPersonalityInterests.personality_id == interests.personality_id) - )).scalars().first() + existing_record = ( + ( + await session.execute( + select(DBBotPersonalityInterests).where( + DBBotPersonalityInterests.personality_id == interests.personality_id + ) + ) + ) + .scalars() + .first() + ) if existing_record: # 更新现有记录 @@ -763,10 +780,17 @@ class BotInterestManager: # 验证保存是否成功 async with get_db_session() as session: - saved_record = (await session.execute( - select(DBBotPersonalityInterests) - .where(DBBotPersonalityInterests.personality_id == interests.personality_id) - )).scalars().first() + saved_record = ( + ( + await session.execute( + select(DBBotPersonalityInterests).where( + DBBotPersonalityInterests.personality_id == interests.personality_id + ) + ) + ) + .scalars() + .first() + ) if saved_record: logger.info(f"✅ 验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录") logger.info(f" 版本: {saved_record.version}") diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 162a00b7f..f6fae8d6c 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -161,7 +161,7 @@ class EmbeddingStore: @staticmethod def _get_embeddings_batch_threaded( - strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None + strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None ) -> List[Tuple[str, List[float]]]: """使用多线程批量获取嵌入向量 diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index f539659fb..c340fc30e 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -101,7 +101,7 @@ class QAManager: async def get_knowledge(self, question: str) -> Optional[Dict[str, Any]]: """ 获取知识,返回结构化字典 - + Args: question: 用户提出的问题 @@ -114,30 +114,27 @@ class QAManager: return None query_res = processed_result[0] - + knowledge_items = [] for res_hash, relevance, *_ in query_res: if store_item := self.embed_manager.paragraphs_embedding_store.store.get(res_hash): - knowledge_items.append({ - "content": store_item.str, - "source": "内部知识库", - "relevance": f"{relevance:.4f}" - }) + knowledge_items.append( + {"content": store_item.str, "source": "内部知识库", "relevance": f"{relevance:.4f}"} + ) if not knowledge_items: return None - + # 使用LLM生成总结 - knowledge_text_for_summary = "\n\n".join([item['content'] for item in knowledge_items[:5]]) # 最多总结前5条 - summary_prompt = f"根据以下信息,为问题 '{question}' 生成一个简洁的、不超过50字的摘要:\n\n{knowledge_text_for_summary}" - + knowledge_text_for_summary = "\n\n".join([item["content"] for item in knowledge_items[:5]]) # 最多总结前5条 + summary_prompt = ( + f"根据以下信息,为问题 '{question}' 生成一个简洁的、不超过50字的摘要:\n\n{knowledge_text_for_summary}" + ) + try: summary, (_, _, _) = await self.qa_model.generate_response_async(summary_prompt) except Exception as e: logger.error(f"生成知识摘要失败: {e}") summary = "无法生成摘要。" - return { - "knowledge_items": knowledge_items, - "summary": summary.strip() if summary else "没有可用的摘要。" - } + return {"knowledge_items": knowledge_items, "summary": summary.strip() if summary else "没有可用的摘要。"} diff --git a/src/chat/memory_system/__init__.py b/src/chat/memory_system/__init__.py index 75daf0fb2..a1c176a10 100644 --- a/src/chat/memory_system/__init__.py +++ b/src/chat/memory_system/__init__.py @@ -12,44 +12,23 @@ from .memory_chunk import ( MemoryType, ImportanceLevel, ConfidenceLevel, - create_memory_chunk + create_memory_chunk, ) # 遗忘引擎 -from .memory_forgetting_engine import ( - MemoryForgettingEngine, - ForgettingConfig, - get_memory_forgetting_engine -) +from .memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig, get_memory_forgetting_engine # Vector DB存储系统 -from .vector_memory_storage_v2 import ( - VectorMemoryStorage, - VectorStorageConfig, - get_vector_memory_storage -) +from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage # 记忆核心系统 -from .memory_system import ( - MemorySystem, - MemorySystemConfig, - get_memory_system, - initialize_memory_system -) +from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system # 记忆管理器 -from .memory_manager import ( - MemoryManager, - MemoryResult, - memory_manager -) +from .memory_manager import MemoryManager, MemoryResult, memory_manager # 激活器 -from .enhanced_memory_activator import ( - MemoryActivator, - memory_activator, - enhanced_memory_activator -) +from .enhanced_memory_activator import MemoryActivator, memory_activator, enhanced_memory_activator # 兼容性别名 from .memory_chunk import MemoryChunk as Memory @@ -64,28 +43,23 @@ __all__ = [ "ImportanceLevel", "ConfidenceLevel", "create_memory_chunk", - # 遗忘引擎 "MemoryForgettingEngine", "ForgettingConfig", "get_memory_forgetting_engine", - # Vector DB存储 "VectorMemoryStorage", - "VectorStorageConfig", + "VectorStorageConfig", "get_vector_memory_storage", - # 记忆系统 "MemorySystem", "MemorySystemConfig", "get_memory_system", "initialize_memory_system", - # 记忆管理器 "MemoryManager", - "MemoryResult", + "MemoryResult", "memory_manager", - # 激活器 "MemoryActivator", "memory_activator", @@ -95,4 +69,4 @@ __all__ = [ # 版本信息 __version__ = "3.0.0" __author__ = "MoFox Team" -__description__ = "简化记忆系统 - 统一记忆架构与智能遗忘机制" \ No newline at end of file +__description__ = "简化记忆系统 - 统一记忆架构与智能遗忘机制" diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py index 7d34df8d9..aae09c08b 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py @@ -4,15 +4,14 @@ 将增强记忆系统集成到现有MoFox Bot架构中 """ -import asyncio import time -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict, List, Optional, Any from dataclasses import dataclass from src.common.logger import get_logger from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType -from src.chat.memory_system.memory_formatter import MemoryFormatter, FormatterConfig, format_memories_for_llm +from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) @@ -36,6 +35,7 @@ MEMORY_TYPE_LABELS = { @dataclass class AdapterConfig: """适配器配置""" + enable_enhanced_memory: bool = True integration_mode: str = "enhanced_only" # replace, enhanced_only auto_migration: bool = True @@ -61,7 +61,7 @@ class EnhancedMemoryAdapter: "hybrid_used": 0, "memories_created": 0, "memories_retrieved": 0, - "average_processing_time": 0.0 + "average_processing_time": 0.0, } async def initialize(self): @@ -79,14 +79,11 @@ class EnhancedMemoryAdapter: memory_value_threshold=self.config.memory_value_threshold, fusion_threshold=self.config.fusion_threshold, max_retrieval_results=self.config.max_retrieval_results, - enable_learning=True # 启用学习功能 + enable_learning=True, # 启用学习功能 ) # 创建集成层 - self.integration_layer = MemoryIntegrationLayer( - llm_model=self.llm_model, - config=integration_config - ) + self.integration_layer = MemoryIntegrationLayer(llm_model=self.llm_model, config=integration_config) # 初始化集成层 await self.integration_layer.initialize() @@ -99,10 +96,7 @@ class EnhancedMemoryAdapter: # 如果初始化失败,禁用增强记忆功能 self.config.enable_enhanced_memory = False - async def process_conversation_memory( - self, - context: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: + async def process_conversation_memory(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """处理对话记忆,以上下文为唯一输入""" if not self._initialized or not self.config.enable_enhanced_memory: return {"success": False, "error": "Enhanced memory not available"} @@ -152,11 +146,7 @@ class EnhancedMemoryAdapter: return {"success": False, "error": str(e)} async def retrieve_relevant_memories( - self, - query: str, - user_id: str, - context: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None + self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: Optional[int] = None ) -> List[MemoryChunk]: """检索相关记忆""" if not self._initialized or not self.config.enable_enhanced_memory: @@ -164,9 +154,7 @@ class EnhancedMemoryAdapter: try: limit = limit or self.config.max_retrieval_results - memories = await self.integration_layer.retrieve_relevant_memories( - query, None, context, limit - ) + memories = await self.integration_layer.retrieve_relevant_memories(query, None, context, limit) self.adapter_stats["memories_retrieved"] += len(memories) logger.debug(f"检索到 {len(memories)} 条相关记忆") @@ -178,11 +166,7 @@ class EnhancedMemoryAdapter: return [] async def get_memory_context_for_prompt( - self, - query: str, - user_id: str, - context: Optional[Dict[str, Any]] = None, - max_memories: int = 5 + self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5 ) -> str: """获取用于提示词的记忆上下文""" memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories) @@ -197,15 +181,11 @@ class EnhancedMemoryAdapter: include_confidence=False, use_emoji_icons=True, group_by_type=False, - max_display_length=150 - ) - - return format_memories_for_llm( - memories=memories, - query_context=query, - config=formatter_config + max_display_length=150, ) + return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config) + async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]: """获取增强记忆系统摘要""" if not self._initialized or not self.config.enable_enhanced_memory: @@ -227,7 +207,7 @@ class EnhancedMemoryAdapter: "adapter_stats": adapter_stats, "integration_stats": integration_stats, "total_memories_created": adapter_stats["memories_created"], - "total_memories_retrieved": adapter_stats["memories_retrieved"] + "total_memories_retrieved": adapter_stats["memories_retrieved"], } except Exception as e: @@ -285,12 +265,12 @@ async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAd from src.config.config import global_config adapter_config = AdapterConfig( - enable_enhanced_memory=getattr(global_config.memory, 'enable_enhanced_memory', True), - integration_mode=getattr(global_config.memory, 'enhanced_memory_mode', 'enhanced_only'), - auto_migration=getattr(global_config.memory, 'enable_memory_migration', True), - memory_value_threshold=getattr(global_config.memory, 'memory_value_threshold', 0.6), - fusion_threshold=getattr(global_config.memory, 'fusion_threshold', 0.85), - max_retrieval_results=getattr(global_config.memory, 'max_retrieval_results', 10) + enable_enhanced_memory=getattr(global_config.memory, "enable_enhanced_memory", True), + integration_mode=getattr(global_config.memory, "enhanced_memory_mode", "enhanced_only"), + auto_migration=getattr(global_config.memory, "enable_memory_migration", True), + memory_value_threshold=getattr(global_config.memory, "memory_value_threshold", 0.6), + fusion_threshold=getattr(global_config.memory, "fusion_threshold", 0.85), + max_retrieval_results=getattr(global_config.memory, "max_retrieval_results", 10), ) _enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config) @@ -312,13 +292,13 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest): async def process_conversation_with_enhanced_memory( - context: Dict[str, Any], - llm_model: Optional[LLMRequest] = None + context: Dict[str, Any], llm_model: Optional[LLMRequest] = None ) -> Dict[str, Any]: """使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息""" if not llm_model: # 获取默认的LLM模型 from src.llm_models.utils_model import get_global_llm_model + llm_model = get_global_llm_model() try: @@ -345,12 +325,13 @@ async def retrieve_memories_with_enhanced_system( user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 10, - llm_model: Optional[LLMRequest] = None + llm_model: Optional[LLMRequest] = None, ) -> List[MemoryChunk]: """使用增强记忆系统检索记忆""" if not llm_model: # 获取默认的LLM模型 from src.llm_models.utils_model import get_global_llm_model + llm_model = get_global_llm_model() try: @@ -366,12 +347,13 @@ async def get_memory_context_for_prompt( user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5, - llm_model: Optional[LLMRequest] = None + llm_model: Optional[LLMRequest] = None, ) -> str: """获取用于提示词的记忆上下文""" if not llm_model: # 获取默认的LLM模型 from src.llm_models.utils_model import get_global_llm_model + llm_model = get_global_llm_model() try: @@ -379,4 +361,4 @@ async def get_memory_context_for_prompt( return await adapter.get_memory_context_for_prompt(query, user_id, context, max_memories) except Exception as e: logger.error(f"获取记忆上下文失败: {e}", exc_info=True) - return "" \ No newline at end of file + return "" diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py index 37188a08e..a1b374510 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py @@ -4,7 +4,6 @@ 用于在消息处理过程中自动构建和检索记忆 """ -import asyncio from typing import Dict, List, Any, Optional from datetime import datetime @@ -19,8 +18,7 @@ class EnhancedMemoryHooks: """增强记忆系统钩子 - 自动处理消息的记忆构建和检索""" def __init__(self): - self.enabled = (global_config.memory.enable_memory and - global_config.memory.enable_enhanced_memory) + self.enabled = global_config.memory.enable_memory and global_config.memory.enable_enhanced_memory self.processed_messages = set() # 避免重复处理 async def process_message_for_memory( @@ -29,7 +27,7 @@ class EnhancedMemoryHooks: user_id: str, chat_id: str, message_id: str, - context: Optional[Dict[str, Any]] = None + context: Optional[Dict[str, Any]] = None, ) -> bool: """ 处理消息并构建记忆 @@ -76,7 +74,7 @@ class EnhancedMemoryHooks: "timestamp": datetime.now().timestamp(), "message_type": "user_message", **bot_context, - **(context or {}) + **(context or {}), } # 处理对话并构建记忆 @@ -84,7 +82,7 @@ class EnhancedMemoryHooks: conversation_text=message_content, context=memory_context, user_id=user_id, - timestamp=memory_context["timestamp"] + timestamp=memory_context["timestamp"], ) # 标记消息已处理 @@ -108,7 +106,7 @@ class EnhancedMemoryHooks: user_id: str, chat_id: str, limit: int = 5, - extra_context: Optional[Dict[str, Any]] = None + extra_context: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """ 为回复获取相关记忆 @@ -134,9 +132,7 @@ class EnhancedMemoryHooks: context = { "chat_id": chat_id, "query_intent": "response_generation", - "expected_memory_types": [ - "personal_fact", "event", "preference", "opinion" - ] + "expected_memory_types": ["personal_fact", "event", "preference", "opinion"], } if extra_context: @@ -144,10 +140,7 @@ class EnhancedMemoryHooks: # 获取相关记忆 enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context( - query_text=query_text, - user_id=user_id, - context=context, - limit=limit + query_text=query_text, user_id=user_id, context=context, limit=limit ) # 转换为字典格式 @@ -199,4 +192,4 @@ class EnhancedMemoryHooks: # 创建全局实例 -enhanced_memory_hooks = EnhancedMemoryHooks() \ No newline at end of file +enhanced_memory_hooks = EnhancedMemoryHooks() diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py index 2d5e8f44e..913c2aed0 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py @@ -4,7 +4,6 @@ 用于在现有系统中无缝集成增强记忆功能 """ -import asyncio from typing import Dict, Any, Optional from src.common.logger import get_logger @@ -14,11 +13,7 @@ logger = get_logger(__name__) async def process_user_message_memory( - message_content: str, - user_id: str, - chat_id: str, - message_id: str, - context: Optional[Dict[str, Any]] = None + message_content: str, user_id: str, chat_id: str, message_id: str, context: Optional[Dict[str, Any]] = None ) -> bool: """ 处理用户消息并构建记忆 @@ -35,11 +30,7 @@ async def process_user_message_memory( """ try: success = await enhanced_memory_hooks.process_message_for_memory( - message_content=message_content, - user_id=user_id, - chat_id=chat_id, - message_id=message_id, - context=context + message_content=message_content, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context ) if success: @@ -53,11 +44,7 @@ async def process_user_message_memory( async def get_relevant_memories_for_response( - query_text: str, - user_id: str, - chat_id: str, - limit: int = 5, - extra_context: Optional[Dict[str, Any]] = None + query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """ 为回复获取相关记忆 @@ -74,29 +61,17 @@ async def get_relevant_memories_for_response( """ try: memories = await enhanced_memory_hooks.get_memory_for_response( - query_text=query_text, - user_id=user_id, - chat_id=chat_id, - limit=limit, - extra_context=extra_context + query_text=query_text, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=extra_context ) - result = { - "has_memories": len(memories) > 0, - "memories": memories, - "memory_count": len(memories) - } + result = {"has_memories": len(memories) > 0, "memories": memories, "memory_count": len(memories)} logger.debug(f"为回复获取到 {len(memories)} 条相关记忆") return result except Exception as e: logger.error(f"获取回复记忆失败: {e}") - return { - "has_memories": False, - "memories": [], - "memory_count": 0 - } + return {"has_memories": False, "memories": [], "memory_count": 0} def format_memories_for_prompt(memories: Dict[str, Any]) -> str: @@ -152,16 +127,13 @@ def get_memory_system_status() -> Dict[str, Any]: "enabled": enhanced_memory_hooks.enabled, "enhanced_system_initialized": enhanced_memory_manager.is_initialized, "processed_messages_count": len(enhanced_memory_hooks.processed_messages), - "system_type": "enhanced_memory_system" + "system_type": "enhanced_memory_system", } # 便捷函数 async def remember_message( - message: str, - user_id: str = "default_user", - chat_id: str = "default_chat", - context: Optional[Dict[str, Any]] = None + message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: Optional[Dict[str, Any]] = None ) -> bool: """ 便捷的记忆构建函数 @@ -175,13 +147,10 @@ async def remember_message( bool: 是否成功 """ import uuid + message_id = str(uuid.uuid4()) return await process_user_message_memory( - message_content=message, - user_id=user_id, - chat_id=chat_id, - message_id=message_id, - context=context + message_content=message, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context ) @@ -190,7 +159,7 @@ async def recall_memories( user_id: str = "default_user", chat_id: str = "default_chat", limit: int = 5, - context: Optional[Dict[str, Any]] = None + context: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """ 便捷的记忆检索函数 @@ -205,9 +174,5 @@ async def recall_memories( Dict: 记忆信息 """ return await get_relevant_memories_for_response( - query_text=query, - user_id=user_id, - chat_id=chat_id, - limit=limit, - extra_context=context - ) \ No newline at end of file + query_text=query, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=context + ) diff --git a/src/chat/memory_system/deprecated_backup/enhanced_reranker.py b/src/chat/memory_system/deprecated_backup/enhanced_reranker.py index a6dfafb01..e5b368460 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_reranker.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_reranker.py @@ -18,32 +18,34 @@ logger = get_logger(__name__) class IntentType(Enum): """对话意图类型""" - FACT_QUERY = "fact_query" # 事实查询 - EVENT_RECALL = "event_recall" # 事件回忆 + + FACT_QUERY = "fact_query" # 事实查询 + EVENT_RECALL = "event_recall" # 事件回忆 PREFERENCE_CHECK = "preference_check" # 偏好检查 - GENERAL_CHAT = "general_chat" # 一般对话 - UNKNOWN = "unknown" # 未知意图 + GENERAL_CHAT = "general_chat" # 一般对话 + UNKNOWN = "unknown" # 未知意图 @dataclass class ReRankingConfig: """重排序配置""" + # 权重配置 (w1 + w2 + w3 + w4 = 1.0) - semantic_weight: float = 0.5 # 语义相似度权重 - recency_weight: float = 0.2 # 时效性权重 - usage_freq_weight: float = 0.2 # 使用频率权重 - type_match_weight: float = 0.1 # 类型匹配权重 - + semantic_weight: float = 0.5 # 语义相似度权重 + recency_weight: float = 0.2 # 时效性权重 + usage_freq_weight: float = 0.2 # 使用频率权重 + type_match_weight: float = 0.1 # 类型匹配权重 + # 时效性衰减参数 - recency_decay_rate: float = 0.1 # 时效性衰减率 (天) - + recency_decay_rate: float = 0.1 # 时效性衰减率 (天) + # 使用频率计算参数 - freq_log_base: float = 2.0 # 对数底数 - freq_max_score: float = 5.0 # 最大频率得分 - + freq_log_base: float = 2.0 # 对数底数 + freq_max_score: float = 5.0 # 最大频率得分 + # 类型匹配权重映射 type_match_weights: Dict[str, Dict[str, float]] = None - + def __post_init__(self): """初始化类型匹配权重""" if self.type_match_weights is None: @@ -53,102 +55,150 @@ class ReRankingConfig: MemoryType.KNOWLEDGE.value: 0.8, MemoryType.PREFERENCE.value: 0.5, MemoryType.EVENT.value: 0.3, - "default": 0.3 + "default": 0.3, }, IntentType.EVENT_RECALL.value: { MemoryType.EVENT.value: 1.0, MemoryType.EXPERIENCE.value: 0.8, MemoryType.EMOTION.value: 0.6, MemoryType.PERSONAL_FACT.value: 0.5, - "default": 0.5 + "default": 0.5, }, IntentType.PREFERENCE_CHECK.value: { MemoryType.PREFERENCE.value: 1.0, MemoryType.OPINION.value: 0.8, MemoryType.GOAL.value: 0.6, MemoryType.PERSONAL_FACT.value: 0.4, - "default": 0.4 + "default": 0.4, }, - IntentType.GENERAL_CHAT.value: { - "default": 0.8 - }, - IntentType.UNKNOWN.value: { - "default": 0.8 - } + IntentType.GENERAL_CHAT.value: {"default": 0.8}, + IntentType.UNKNOWN.value: {"default": 0.8}, } class IntentClassifier: """轻量级意图识别器""" - + def __init__(self): # 关键词模式匹配规则 self.patterns = { IntentType.FACT_QUERY: [ # 中文模式 - "我是", "我的", "我叫", "我在", "我住在", "我的职业", "我的工作", - "什么时候", "在哪里", "是什么", "多少", "几岁", "年龄", + "我是", + "我的", + "我叫", + "我在", + "我住在", + "我的职业", + "我的工作", + "什么时候", + "在哪里", + "是什么", + "多少", + "几岁", + "年龄", # 英文模式 - "what is", "where is", "when is", "how old", "my name", "i am", "i live" + "what is", + "where is", + "when is", + "how old", + "my name", + "i am", + "i live", ], IntentType.EVENT_RECALL: [ # 中文模式 - "记得", "想起", "还记得", "那次", "上次", "之前", "以前", "曾经", - "发生过", "经历", "做过", "去过", "见过", + "记得", + "想起", + "还记得", + "那次", + "上次", + "之前", + "以前", + "曾经", + "发生过", + "经历", + "做过", + "去过", + "见过", # 英文模式 - "remember", "recall", "last time", "before", "previously", "happened", "experience" + "remember", + "recall", + "last time", + "before", + "previously", + "happened", + "experience", ], IntentType.PREFERENCE_CHECK: [ # 中文模式 - "喜欢", "不喜欢", "偏好", "爱好", "兴趣", "讨厌", "最爱", "最喜欢", - "习惯", "通常", "一般", "倾向于", "更喜欢", + "喜欢", + "不喜欢", + "偏好", + "爱好", + "兴趣", + "讨厌", + "最爱", + "最喜欢", + "习惯", + "通常", + "一般", + "倾向于", + "更喜欢", # 英文模式 - "like", "love", "hate", "prefer", "favorite", "usually", "tend to", "interest" - ] + "like", + "love", + "hate", + "prefer", + "favorite", + "usually", + "tend to", + "interest", + ], } - + def classify_intent(self, query: str, context: Dict[str, Any]) -> IntentType: """识别对话意图""" if not query: return IntentType.UNKNOWN - + query_lower = query.lower() - + # 统计各意图的匹配分数 intent_scores = {intent: 0 for intent in IntentType} - + for intent, patterns in self.patterns.items(): for pattern in patterns: if pattern in query_lower: intent_scores[intent] += 1 - + # 返回得分最高的意图 max_score = max(intent_scores.values()) if max_score == 0: return IntentType.GENERAL_CHAT - + for intent, score in intent_scores.items(): if score == max_score: return intent - + return IntentType.GENERAL_CHAT class EnhancedReRanker: """增强重排序器 - 实现文档设计的多维度评分模型""" - + def __init__(self, config: Optional[ReRankingConfig] = None): self.config = config or ReRankingConfig() self.intent_classifier = IntentClassifier() - + # 验证权重和为1.0 total_weight = ( - self.config.semantic_weight + - self.config.recency_weight + - self.config.usage_freq_weight + - self.config.type_match_weight + self.config.semantic_weight + + self.config.recency_weight + + self.config.usage_freq_weight + + self.config.type_match_weight ) - + if abs(total_weight - 1.0) > 0.01: logger.warning(f"重排序权重和不为1.0: {total_weight}, 将进行归一化") # 归一化权重 @@ -156,94 +206,94 @@ class EnhancedReRanker: self.config.recency_weight /= total_weight self.config.usage_freq_weight /= total_weight self.config.type_match_weight /= total_weight - + def rerank_memories( self, query: str, candidate_memories: List[Tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity) context: Dict[str, Any], - limit: int = 10 + limit: int = 10, ) -> List[Tuple[str, MemoryChunk, float]]: """ 对候选记忆进行重排序 - + Args: query: 查询文本 candidate_memories: 候选记忆列表 [(memory_id, memory, vector_similarity)] context: 上下文信息 limit: 返回数量限制 - + Returns: 重排序后的记忆列表 [(memory_id, memory, final_score)] """ if not candidate_memories: return [] - + # 识别查询意图 intent = self.intent_classifier.classify_intent(query, context) logger.debug(f"识别到查询意图: {intent.value}") - + # 计算每个候选记忆的最终得分 scored_memories = [] current_time = time.time() - + for memory_id, memory, vector_sim in candidate_memories: try: # 1. 语义相似度得分 (已归一化到[0,1]) semantic_score = self._normalize_similarity(vector_sim) - + # 2. 时效性得分 recency_score = self._calculate_recency_score(memory, current_time) - + # 3. 使用频率得分 usage_freq_score = self._calculate_usage_frequency_score(memory) - + # 4. 类型匹配得分 type_match_score = self._calculate_type_match_score(memory, intent) - + # 计算最终得分 final_score = ( - self.config.semantic_weight * semantic_score + - self.config.recency_weight * recency_score + - self.config.usage_freq_weight * usage_freq_score + - self.config.type_match_weight * type_match_score + self.config.semantic_weight * semantic_score + + self.config.recency_weight * recency_score + + self.config.usage_freq_weight * usage_freq_score + + self.config.type_match_weight * type_match_score ) - + scored_memories.append((memory_id, memory, final_score)) - + # 记录调试信息 logger.debug( f"记忆评分 {memory_id[:8]}: semantic={semantic_score:.3f}, " f"recency={recency_score:.3f}, freq={usage_freq_score:.3f}, " f"type={type_match_score:.3f}, final={final_score:.3f}" ) - + except Exception as e: logger.error(f"计算记忆 {memory_id} 得分时出错: {e}") # 使用向量相似度作为后备得分 scored_memories.append((memory_id, memory, vector_sim)) - + # 按最终得分降序排序 scored_memories.sort(key=lambda x: x[2], reverse=True) - + # 返回前N个结果 result = scored_memories[:limit] - + highest_score = result[0][2] if result else 0.0 logger.info( f"重排序完成: 候选={len(candidate_memories)}, 返回={len(result)}, " f"意图={intent.value}, 最高分={highest_score:.3f}" ) - + return result - + def _normalize_similarity(self, raw_similarity: float) -> float: """归一化相似度到[0,1]区间""" # 假设原始相似度已经在[-1,1]或[0,1]区间 if raw_similarity < 0: return (raw_similarity + 1) / 2 # 从[-1,1]映射到[0,1] return min(1.0, max(0.0, raw_similarity)) # 确保在[0,1]区间 - + def _calculate_recency_score(self, memory: MemoryChunk, current_time: float) -> float: """ 计算时效性得分 @@ -251,13 +301,13 @@ class EnhancedReRanker: """ last_accessed = memory.metadata.last_accessed or memory.metadata.created_at days_old = (current_time - last_accessed) / (24 * 3600) # 转换为天数 - + if days_old < 0: days_old = 0 # 处理时间异常 - + score = 1 / (1 + self.config.recency_decay_rate * days_old) return min(1.0, max(0.0, score)) - + def _calculate_usage_frequency_score(self, memory: MemoryChunk) -> float: """ 计算使用频率得分 @@ -266,22 +316,22 @@ class EnhancedReRanker: access_count = memory.metadata.access_count if access_count <= 0: return 0.0 - + log_count = math.log2(access_count + 1) score = log_count / self.config.freq_max_score return min(1.0, max(0.0, score)) - + def _calculate_type_match_score(self, memory: MemoryChunk, intent: IntentType) -> float: """计算类型匹配得分""" memory_type = memory.memory_type.value intent_value = intent.value - + # 获取对应意图的类型权重映射 type_weights = self.config.type_match_weights.get(intent_value, {}) - + # 查找具体类型的权重,如果没有则使用默认权重 score = type_weights.get(memory_type, type_weights.get("default", 0.8)) - + return min(1.0, max(0.0, score)) @@ -294,7 +344,7 @@ def rerank_candidate_memories( candidate_memories: List[Tuple[str, MemoryChunk, float]], context: Dict[str, Any], limit: int = 10, - config: Optional[ReRankingConfig] = None + config: Optional[ReRankingConfig] = None, ) -> List[Tuple[str, MemoryChunk, float]]: """ 便捷函数:对候选记忆进行重排序 @@ -303,5 +353,5 @@ def rerank_candidate_memories( reranker = EnhancedReRanker(config) else: reranker = default_reranker - - return reranker.rerank_memories(query, candidate_memories, context, limit) \ No newline at end of file + + return reranker.rerank_memories(query, candidate_memories, context, limit) diff --git a/src/chat/memory_system/deprecated_backup/integration_layer.py b/src/chat/memory_system/deprecated_backup/integration_layer.py index 3f8b7f1ce..5b9282a84 100644 --- a/src/chat/memory_system/deprecated_backup/integration_layer.py +++ b/src/chat/memory_system/deprecated_backup/integration_layer.py @@ -12,7 +12,7 @@ from enum import Enum from src.common.logger import get_logger from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel +from src.chat.memory_system.memory_chunk import MemoryChunk from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) @@ -20,13 +20,15 @@ logger = get_logger(__name__) class IntegrationMode(Enum): """集成模式""" - REPLACE = "replace" # 完全替换现有记忆系统 + + REPLACE = "replace" # 完全替换现有记忆系统 ENHANCED_ONLY = "enhanced_only" # 仅使用增强记忆系统 @dataclass class IntegrationConfig: """集成配置""" + mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY enable_enhanced_memory: bool = True memory_value_threshold: float = 0.6 @@ -51,7 +53,7 @@ class MemoryIntegrationLayer: "enhanced_queries": 0, "memory_creations": 0, "average_response_time": 0.0, - "success_rate": 0.0 + "success_rate": 0.0, } # 初始化锁 @@ -88,6 +90,7 @@ class MemoryIntegrationLayer: # 创建增强记忆系统配置 from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig + memory_config = MemorySystemConfig.from_global_config() # 使用集成配置覆盖部分值 @@ -96,9 +99,7 @@ class MemoryIntegrationLayer: memory_config.final_recall_limit = self.config.max_retrieval_results # 创建增强记忆系统 - self.enhanced_memory = EnhancedMemorySystem( - config=memory_config - ) + self.enhanced_memory = EnhancedMemorySystem(config=memory_config) # 如果外部提供了LLM模型,注入到系统中 if self.llm_model is not None: @@ -112,10 +113,7 @@ class MemoryIntegrationLayer: logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True) raise - async def process_conversation( - self, - context: Dict[str, Any] - ) -> Dict[str, Any]: + async def process_conversation(self, context: Dict[str, Any]) -> Dict[str, Any]: """处理对话记忆,仅使用上下文信息""" if not self._initialized or not self.enhanced_memory: return {"success": False, "error": "Memory system not available"} @@ -154,7 +152,7 @@ class MemoryIntegrationLayer: query: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None + limit: Optional[int] = None, ) -> List[MemoryChunk]: """检索相关记忆""" if not self._initialized or not self.enhanced_memory: @@ -163,10 +161,7 @@ class MemoryIntegrationLayer: try: limit = limit or self.config.max_retrieval_results memories = await self.enhanced_memory.retrieve_relevant_memories( - query=query, - user_id=None, - context=context or {}, - limit=limit + query=query, user_id=None, context=context or {}, limit=limit ) memory_count = len(memories) @@ -191,7 +186,7 @@ class MemoryIntegrationLayer: "status": "initialized", "mode": self.config.mode.value, "enhanced_memory": enhanced_status, - "integration_stats": self.integration_stats.copy() + "integration_stats": self.integration_stats.copy(), } except Exception as e: @@ -248,4 +243,4 @@ class MemoryIntegrationLayer: logger.info("✅ 记忆系统集成层已关闭") except Exception as e: - logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True) \ No newline at end of file + logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True) diff --git a/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py b/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py index 2dab63b7a..4659389cb 100644 --- a/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py +++ b/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py @@ -4,17 +4,15 @@ 提供与现有MoFox Bot系统的无缝集成点 """ -import asyncio import time -from typing import Dict, List, Optional, Any, Callable +from typing import Dict, Optional, Any from dataclasses import dataclass from src.common.logger import get_logger from src.chat.memory_system.enhanced_memory_adapter import ( - get_enhanced_memory_adapter, process_conversation_with_enhanced_memory, retrieve_memories_with_enhanced_system, - get_memory_context_for_prompt + get_memory_context_for_prompt, ) logger = get_logger(__name__) @@ -23,6 +21,7 @@ logger = get_logger(__name__) @dataclass class HookResult: """钩子执行结果""" + success: bool data: Any = None error: Optional[str] = None @@ -39,7 +38,7 @@ class MemoryIntegrationHooks: "memory_retrieval_hooks": 0, "prompt_enhancement_hooks": 0, "total_hook_executions": 0, - "average_hook_time": 0.0 + "average_hook_time": 0.0, } async def register_hooks(self): @@ -130,10 +129,7 @@ class MemoryIntegrationHooks: from src.plugin_system.base.component_types import EventType # 注册消息后处理事件 - event_manager.subscribe( - EventType.MESSAGE_PROCESSED, - self._on_message_processed_handler - ) + event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler) logger.debug("已注册到事件系统的消息处理钩子") except ImportError: @@ -144,10 +140,8 @@ class MemoryIntegrationHooks: from src.chat.message_manager import message_manager # 如果消息管理器支持钩子注册 - if hasattr(message_manager, 'register_post_process_hook'): - message_manager.register_post_process_hook( - self._on_message_processed_hook - ) + if hasattr(message_manager, "register_post_process_hook"): + message_manager.register_post_process_hook(self._on_message_processed_hook) logger.debug("已注册到消息管理器的处理钩子") except ImportError: @@ -164,10 +158,8 @@ class MemoryIntegrationHooks: from src.chat.message_receive.chat_stream import get_chat_manager chat_manager = get_chat_manager() - if hasattr(chat_manager, 'register_save_hook'): - chat_manager.register_save_hook( - self._on_chat_stream_save_hook - ) + if hasattr(chat_manager, "register_save_hook"): + chat_manager.register_save_hook(self._on_chat_stream_save_hook) logger.debug("已注册到聊天流管理器的保存钩子") except ImportError: @@ -183,10 +175,8 @@ class MemoryIntegrationHooks: try: from src.chat.replyer.default_generator import default_generator - if hasattr(default_generator, 'register_pre_generation_hook'): - default_generator.register_pre_generation_hook( - self._on_pre_response_hook - ) + if hasattr(default_generator, "register_pre_generation_hook"): + default_generator.register_pre_generation_hook(self._on_pre_response_hook) logger.debug("已注册到回复生成器的前置钩子") except ImportError: @@ -202,10 +192,8 @@ class MemoryIntegrationHooks: try: from src.chat.knowledge.knowledge_lib import knowledge_manager - if hasattr(knowledge_manager, 'register_query_enhancer'): - knowledge_manager.register_query_enhancer( - self._on_knowledge_query_hook - ) + if hasattr(knowledge_manager, "register_query_enhancer"): + knowledge_manager.register_query_enhancer(self._on_knowledge_query_hook) logger.debug("已注册到知识库的查询增强钩子") except ImportError: @@ -221,10 +209,8 @@ class MemoryIntegrationHooks: try: from src.chat.utils.prompt import prompt_manager - if hasattr(prompt_manager, 'register_enhancer'): - prompt_manager.register_enhancer( - self._on_prompt_building_hook - ) + if hasattr(prompt_manager, "register_enhancer"): + prompt_manager.register_enhancer(self._on_prompt_building_hook) logger.debug("已注册到提示词管理器的增强钩子") except ImportError: @@ -278,7 +264,7 @@ class MemoryIntegrationHooks: "platform": message_info.get("platform", "unknown"), "interest_value": message_data.get("interest_value", 0.0), "keywords": message_data.get("key_words", []), - "timestamp": message_data.get("time", time.time()) + "timestamp": message_data.get("time", time.time()), } # 使用增强记忆系统处理对话 @@ -296,7 +282,7 @@ class MemoryIntegrationHooks: return HookResult(success=True, data=result, processing_time=processing_time) else: logger.warning(f"消息处理钩子执行失败: {result.get('error')}") - return HookResult(success=False, error=result.get('error'), processing_time=processing_time) + return HookResult(success=False, error=result.get("error"), processing_time=processing_time) except Exception as e: processing_time = time.time() - start_time @@ -334,7 +320,7 @@ class MemoryIntegrationHooks: "stream_id": chat_stream_data.get("stream_id"), "platform": chat_stream_data.get("platform", "unknown"), "message_count": len(messages), - "timestamp": time.time() + "timestamp": time.time(), } # 使用增强记忆系统处理对话 @@ -352,7 +338,7 @@ class MemoryIntegrationHooks: return HookResult(success=True, data=result, processing_time=processing_time) else: logger.warning(f"聊天流保存钩子执行失败: {result.get('error')}") - return HookResult(success=False, error=result.get('error'), processing_time=processing_time) + return HookResult(success=False, error=result.get("error"), processing_time=processing_time) except Exception as e: processing_time = time.time() - start_time @@ -375,9 +361,7 @@ class MemoryIntegrationHooks: return HookResult(success=True, data="No query provided") # 检索相关记忆 - memories = await retrieve_memories_with_enhanced_system( - query, user_id, context, limit=5 - ) + memories = await retrieve_memories_with_enhanced_system(query, user_id, context, limit=5) processing_time = time.time() - start_time self._update_hook_stats(processing_time) @@ -411,9 +395,7 @@ class MemoryIntegrationHooks: return HookResult(success=True, data="No query provided") # 获取记忆上下文并增强查询 - memory_context = await get_memory_context_for_prompt( - query, user_id, context, max_memories=3 - ) + memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=3) processing_time = time.time() - start_time self._update_hook_stats(processing_time) @@ -445,9 +427,7 @@ class MemoryIntegrationHooks: return HookResult(success=True, data="No query provided") # 获取记忆上下文 - memory_context = await get_memory_context_for_prompt( - query, user_id, context, max_memories=5 - ) + memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=5) processing_time = time.time() - start_time self._update_hook_stats(processing_time) @@ -499,6 +479,7 @@ class MemoryMaintenanceTask: # 获取适配器实例 try: from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter + if _enhanced_memory_adapter: await _enhanced_memory_adapter.maintenance() logger.info("✅ 增强记忆系统维护任务完成") @@ -543,4 +524,4 @@ async def initialize_memory_integration_hooks(): return hooks except Exception as e: logger.error(f"❌ 记忆集成钩子初始化失败: {e}", exc_info=True) - return None \ No newline at end of file + return None diff --git a/src/chat/memory_system/deprecated_backup/metadata_index.py b/src/chat/memory_system/deprecated_backup/metadata_index.py index 10f8ff266..f7ab8ecda 100644 --- a/src/chat/memory_system/deprecated_backup/metadata_index.py +++ b/src/chat/memory_system/deprecated_backup/metadata_index.py @@ -4,12 +4,10 @@ 为记忆系统提供多维度的精准过滤和查询能力 """ -import os import time import orjson from typing import Dict, List, Optional, Tuple, Set, Any, Union -from datetime import datetime, timedelta -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import Enum import threading from collections import defaultdict @@ -23,23 +21,25 @@ logger = get_logger(__name__) class IndexType(Enum): """索引类型""" - MEMORY_TYPE = "memory_type" # 记忆类型索引 - USER_ID = "user_id" # 用户ID索引 - SUBJECT = "subject" # 主体索引 - KEYWORD = "keyword" # 关键词索引 - TAG = "tag" # 标签索引 - CATEGORY = "category" # 分类索引 - TIMESTAMP = "timestamp" # 时间索引 - CONFIDENCE = "confidence" # 置信度索引 - IMPORTANCE = "importance" # 重要性索引 + + MEMORY_TYPE = "memory_type" # 记忆类型索引 + USER_ID = "user_id" # 用户ID索引 + SUBJECT = "subject" # 主体索引 + KEYWORD = "keyword" # 关键词索引 + TAG = "tag" # 标签索引 + CATEGORY = "category" # 分类索引 + TIMESTAMP = "timestamp" # 时间索引 + CONFIDENCE = "confidence" # 置信度索引 + IMPORTANCE = "importance" # 重要性索引 RELATIONSHIP_SCORE = "relationship_score" # 关系分索引 ACCESS_FREQUENCY = "access_frequency" # 访问频率索引 - SEMANTIC_HASH = "semantic_hash" # 语义哈希索引 + SEMANTIC_HASH = "semantic_hash" # 语义哈希索引 @dataclass class IndexQuery: """索引查询条件""" + user_ids: Optional[List[str]] = None memory_types: Optional[List[MemoryType]] = None subjects: Optional[List[str]] = None @@ -61,6 +61,7 @@ class IndexQuery: @dataclass class IndexResult: """索引结果""" + memory_ids: List[str] total_count: int query_time: float @@ -102,7 +103,7 @@ class MetadataIndexManager: "average_query_time": 0.0, "total_queries": 0, "cache_hit_rate": 0.0, - "cache_hits": 0 + "cache_hits": 0, } # 线程锁 @@ -171,9 +172,8 @@ class MetadataIndexManager: index_time = time.time() - start_time self.index_stats["index_build_time"] = ( - (self.index_stats["index_build_time"] * (len(memories) - 1) + index_time) / - len(memories) - ) + self.index_stats["index_build_time"] * (len(memories) - 1) + index_time + ) / len(memories) logger.debug(f"元数据索引完成,{len(memories)} 条记忆,耗时 {index_time:.3f}秒") @@ -258,7 +258,7 @@ class MetadataIndexManager: "relationship_score": memory.metadata.relationship_score, "relevance_score": memory.metadata.relevance_score, "semantic_hash": memory.semantic_hash, - "subjects": memory.subjects + "subjects": memory.subjects, } # 记忆类型索引 @@ -355,21 +355,20 @@ class MetadataIndexManager: # 限制数量 if query.limit and len(filtered_ids) > query.limit: - filtered_ids = filtered_ids[:query.limit] + filtered_ids = filtered_ids[: query.limit] # 记录查询统计 query_time = time.time() - start_time self.index_stats["total_queries"] += 1 self.index_stats["average_query_time"] = ( - (self.index_stats["average_query_time"] * (self.index_stats["total_queries"] - 1) + query_time) / - self.index_stats["total_queries"] - ) + self.index_stats["average_query_time"] * (self.index_stats["total_queries"] - 1) + query_time + ) / self.index_stats["total_queries"] return IndexResult( memory_ids=filtered_ids, total_count=len(filtered_ids), query_time=query_time, - filtered_by=self._get_applied_filters(query) + filtered_by=self._get_applied_filters(query), ) except Exception as e: @@ -486,15 +485,15 @@ class MetadataIndexManager: if query.time_range: start_time, end_time = query.time_range filtered_ids = [ - memory_id for memory_id in filtered_ids - if self._is_in_time_range(memory_id, start_time, end_time) + memory_id for memory_id in filtered_ids if self._is_in_time_range(memory_id, start_time, end_time) ] # 置信度过滤 if query.confidence_levels: confidence_set = set(query.confidence_levels) filtered_ids = [ - memory_id for memory_id in filtered_ids + memory_id + for memory_id in filtered_ids if self.memory_metadata_cache[memory_id]["confidence"] in confidence_set ] @@ -502,27 +501,31 @@ class MetadataIndexManager: if query.importance_levels: importance_set = set(query.importance_levels) filtered_ids = [ - memory_id for memory_id in filtered_ids + memory_id + for memory_id in filtered_ids if self.memory_metadata_cache[memory_id]["importance"] in importance_set ] # 关系分范围过滤 if query.min_relationship_score is not None: filtered_ids = [ - memory_id for memory_id in filtered_ids + memory_id + for memory_id in filtered_ids if self.memory_metadata_cache[memory_id]["relationship_score"] >= query.min_relationship_score ] if query.max_relationship_score is not None: filtered_ids = [ - memory_id for memory_id in filtered_ids + memory_id + for memory_id in filtered_ids if self.memory_metadata_cache[memory_id]["relationship_score"] <= query.max_relationship_score ] # 最小访问次数过滤 if query.min_access_count is not None: filtered_ids = [ - memory_id for memory_id in filtered_ids + memory_id + for memory_id in filtered_ids if self.memory_metadata_cache[memory_id]["access_count"] >= query.min_access_count ] @@ -530,7 +533,8 @@ class MetadataIndexManager: if query.semantic_hashes: hash_set = set(query.semantic_hashes) filtered_ids = [ - memory_id for memory_id in filtered_ids + memory_id + for memory_id in filtered_ids if self.memory_metadata_cache[memory_id]["semantic_hash"] in hash_set ] @@ -560,8 +564,7 @@ class MetadataIndexManager: elif sort_by == "relevance_score": # 按相关度排序 memory_ids.sort( - key=lambda mid: self.memory_metadata_cache[mid]["relevance_score"], - reverse=(sort_order == "desc") + key=lambda mid: self.memory_metadata_cache[mid]["relevance_score"], reverse=(sort_order == "desc") ) elif sort_by == "relationship_score": @@ -574,8 +577,7 @@ class MetadataIndexManager: elif sort_by == "last_accessed": # 按最后访问时间排序 memory_ids.sort( - key=lambda mid: self.memory_metadata_cache[mid]["last_accessed"], - reverse=(sort_order == "desc") + key=lambda mid: self.memory_metadata_cache[mid]["last_accessed"], reverse=(sort_order == "desc") ) return memory_ids @@ -665,7 +667,9 @@ class MetadataIndexManager: self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid != memory_id] # 从访问频率索引中移除 - self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid != memory_id] + self.access_frequency_index = [ + (count, mid) for count, mid in self.access_frequency_index if mid != memory_id + ] # 注意:关键词、标签、分类索引需要从原始记忆中获取,这里简化处理 # 实际实现中可能需要重新加载记忆或维护反向索引 @@ -704,7 +708,7 @@ class MetadataIndexManager: "average_importance": 0.0, "average_relationship_score": 0.0, "top_keywords": [], - "top_tags": [] + "top_tags": [], } if user_id: @@ -789,23 +793,23 @@ class MetadataIndexManager: indices_data[index_type.value] = serialized_index indices_file = self.index_path / "indices.json" - with open(indices_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(indices_data, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(indices_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(indices_data, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存时间索引 time_index_file = self.index_path / "time_index.json" - with open(time_index_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(self.time_index, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(time_index_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(self.time_index, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存关系分索引 relationship_index_file = self.index_path / "relationship_index.json" - with open(relationship_index_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(self.relationship_index, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(relationship_index_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(self.relationship_index, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存访问频率索引 access_frequency_index_file = self.index_path / "access_frequency_index.json" - with open(access_frequency_index_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(self.access_frequency_index, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(access_frequency_index_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(self.access_frequency_index, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存元数据缓存 metadata_cache_file = self.index_path / "metadata_cache.json" @@ -813,13 +817,13 @@ class MetadataIndexManager: memory_id: self._serialize_metadata_entry(metadata) for memory_id, metadata in self.memory_metadata_cache.items() } - with open(metadata_cache_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(metadata_cache_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存统计信息 stats_file = self.index_path / "index_stats.json" - with open(stats_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(self.index_stats, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(stats_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(self.index_stats, option=orjson.OPT_INDENT_2).decode("utf-8")) self._dirty = False logger.info("✅ 元数据索引保存完成") @@ -835,7 +839,7 @@ class MetadataIndexManager: # 加载各类索引 indices_file = self.index_path / "indices.json" if indices_file.exists(): - with open(indices_file, 'r', encoding='utf-8') as f: + with open(indices_file, "r", encoding="utf-8") as f: indices_data = orjson.loads(f.read()) for index_type_value, index_data in indices_data.items(): @@ -849,25 +853,25 @@ class MetadataIndexManager: # 加载时间索引 time_index_file = self.index_path / "time_index.json" if time_index_file.exists(): - with open(time_index_file, 'r', encoding='utf-8') as f: + with open(time_index_file, "r", encoding="utf-8") as f: self.time_index = orjson.loads(f.read()) # 加载关系分索引 relationship_index_file = self.index_path / "relationship_index.json" if relationship_index_file.exists(): - with open(relationship_index_file, 'r', encoding='utf-8') as f: + with open(relationship_index_file, "r", encoding="utf-8") as f: self.relationship_index = orjson.loads(f.read()) # 加载访问频率索引 access_frequency_index_file = self.index_path / "access_frequency_index.json" if access_frequency_index_file.exists(): - with open(access_frequency_index_file, 'r', encoding='utf-8') as f: + with open(access_frequency_index_file, "r", encoding="utf-8") as f: self.access_frequency_index = orjson.loads(f.read()) # 加载元数据缓存 metadata_cache_file = self.index_path / "metadata_cache.json" if metadata_cache_file.exists(): - with open(metadata_cache_file, 'r', encoding='utf-8') as f: + with open(metadata_cache_file, "r", encoding="utf-8") as f: cache_data = orjson.loads(f.read()) # 转换置信度和重要性为枚举类型 @@ -910,7 +914,7 @@ class MetadataIndexManager: # 加载统计信息 stats_file = self.index_path / "index_stats.json" if stats_file.exists(): - with open(stats_file, 'r', encoding='utf-8') as f: + with open(stats_file, "r", encoding="utf-8") as f: self.index_stats = orjson.loads(f.read()) # 更新记忆计数 @@ -937,9 +941,7 @@ class MetadataIndexManager: # 更新统计信息 if self.index_stats["total_queries"] > 0: - self.index_stats["cache_hit_rate"] = ( - self.index_stats["cache_hits"] / self.index_stats["total_queries"] - ) + self.index_stats["cache_hit_rate"] = self.index_stats["cache_hits"] / self.index_stats["total_queries"] logger.info("✅ 元数据索引优化完成") @@ -967,7 +969,9 @@ class MetadataIndexManager: self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid in valid_memory_ids] # 清理访问频率索引中的无效引用 - self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid in valid_memory_ids] + self.access_frequency_index = [ + (count, mid) for count, mid in self.access_frequency_index if mid in valid_memory_ids + ] # 更新总记忆数 self.index_stats["total_memories"] = len(valid_memory_ids) @@ -1017,7 +1021,7 @@ class MetadataIndexManager: "categories": len(self.indices[IndexType.CATEGORY]), "confidence_levels": len(self.indices[IndexType.CONFIDENCE]), "importance_levels": len(self.indices[IndexType.IMPORTANCE]), - "semantic_hashes": len(self.indices[IndexType.SEMANTIC_HASH]) + "semantic_hashes": len(self.indices[IndexType.SEMANTIC_HASH]), } - return stats \ No newline at end of file + return stats diff --git a/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py b/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py index 21b99f7f4..bc0a1a0f4 100644 --- a/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py +++ b/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py @@ -5,15 +5,13 @@ """ import time -import asyncio -from typing import Dict, List, Optional, Tuple, Set, Any +from typing import Dict, List, Optional, Set, Any from dataclasses import dataclass, field from enum import Enum -import numpy as np import orjson from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType from src.chat.memory_system.enhanced_reranker import EnhancedReRanker, ReRankingConfig logger = get_logger(__name__) @@ -21,30 +19,32 @@ logger = get_logger(__name__) class RetrievalStage(Enum): """检索阶段""" - METADATA_FILTERING = "metadata_filtering" # 元数据过滤阶段 - VECTOR_SEARCH = "vector_search" # 向量搜索阶段 - SEMANTIC_RERANKING = "semantic_reranking" # 语义重排序阶段 - CONTEXTUAL_FILTERING = "contextual_filtering" # 上下文过滤阶段 + + METADATA_FILTERING = "metadata_filtering" # 元数据过滤阶段 + VECTOR_SEARCH = "vector_search" # 向量搜索阶段 + SEMANTIC_RERANKING = "semantic_reranking" # 语义重排序阶段 + CONTEXTUAL_FILTERING = "contextual_filtering" # 上下文过滤阶段 @dataclass class RetrievalConfig: """检索配置""" + # 各阶段配置 - 优化召回率 - metadata_filter_limit: int = 150 # 元数据过滤阶段返回数量(增加) - vector_search_limit: int = 80 # 向量搜索阶段返回数量(增加) - semantic_rerank_limit: int = 30 # 语义重排序阶段返回数量(增加) - final_result_limit: int = 10 # 最终结果数量 + metadata_filter_limit: int = 150 # 元数据过滤阶段返回数量(增加) + vector_search_limit: int = 80 # 向量搜索阶段返回数量(增加) + semantic_rerank_limit: int = 30 # 语义重排序阶段返回数量(增加) + final_result_limit: int = 10 # 最终结果数量 # 相似度阈值 - 优化召回率 - vector_similarity_threshold: float = 0.5 # 向量相似度阈值(降低以提升召回率) + vector_similarity_threshold: float = 0.5 # 向量相似度阈值(降低以提升召回率) semantic_similarity_threshold: float = 0.05 # 语义相似度阈值(保持较低以获得更多相关记忆) # 权重配置 - vector_weight: float = 0.4 # 向量相似度权重 - semantic_weight: float = 0.3 # 语义相似度权重 - context_weight: float = 0.2 # 上下文权重 - recency_weight: float = 0.1 # 时效性权重 + vector_weight: float = 0.4 # 向量相似度权重 + semantic_weight: float = 0.3 # 语义相似度权重 + context_weight: float = 0.2 # 上下文权重 + recency_weight: float = 0.1 # 时效性权重 @classmethod def from_global_config(cls): @@ -53,26 +53,25 @@ class RetrievalConfig: return cls( # 各阶段配置 - 优化召回率 - metadata_filter_limit=max(150, global_config.memory.metadata_filter_limit), # 增加候选池 - vector_search_limit=max(80, global_config.memory.vector_search_limit), # 增加向量搜索结果 - semantic_rerank_limit=max(30, global_config.memory.semantic_rerank_limit), # 增加重排序候选 + metadata_filter_limit=max(150, global_config.memory.metadata_filter_limit), # 增加候选池 + vector_search_limit=max(80, global_config.memory.vector_search_limit), # 增加向量搜索结果 + semantic_rerank_limit=max(30, global_config.memory.semantic_rerank_limit), # 增加重排序候选 final_result_limit=global_config.memory.final_result_limit, - # 相似度阈值 - 优化召回率 vector_similarity_threshold=max(0.5, global_config.memory.vector_similarity_threshold), # 确保不低于0.5 semantic_similarity_threshold=0.05, # 进一步降低以提升召回率 - # 权重配置 vector_weight=global_config.memory.vector_weight, semantic_weight=global_config.memory.semantic_weight, context_weight=global_config.memory.context_weight, - recency_weight=global_config.memory.recency_weight + recency_weight=global_config.memory.recency_weight, ) @dataclass class StageResult: """阶段结果""" + stage: RetrievalStage memory_ids: List[str] processing_time: float @@ -84,6 +83,7 @@ class StageResult: @dataclass class RetrievalResult: """检索结果""" + query: str user_id: str final_memories: List[MemoryChunk] @@ -98,16 +98,16 @@ class MultiStageRetrieval: def __init__(self, config: Optional[RetrievalConfig] = None): self.config = config or RetrievalConfig.from_global_config() - + # 初始化增强重排序器 reranker_config = ReRankingConfig( semantic_weight=self.config.vector_weight, recency_weight=self.config.recency_weight, usage_freq_weight=0.2, # 新增的使用频率权重 - type_match_weight=0.1 # 新增的类型匹配权重 + type_match_weight=0.1, # 新增的类型匹配权重 ) self.reranker = EnhancedReRanker(reranker_config) - + self.retrieval_stats = { "total_queries": 0, "average_retrieval_time": 0.0, @@ -116,8 +116,8 @@ class MultiStageRetrieval: "vector_search": {"calls": 0, "avg_time": 0.0}, "semantic_reranking": {"calls": 0, "avg_time": 0.0}, "contextual_filtering": {"calls": 0, "avg_time": 0.0}, - "enhanced_reranking": {"calls": 0, "avg_time": 0.0} # 新增统计 - } + "enhanced_reranking": {"calls": 0, "avg_time": 0.0}, # 新增统计 + }, } async def retrieve_memories( @@ -128,7 +128,7 @@ class MultiStageRetrieval: metadata_index, vector_storage, all_memories_cache: Dict[str, MemoryChunk], - limit: Optional[int] = None + limit: Optional[int] = None, ) -> RetrievalResult: """多阶段记忆检索""" start_time = time.time() @@ -143,31 +143,39 @@ class MultiStageRetrieval: # 阶段1:元数据过滤 stage1_result = await self._metadata_filtering_stage( - query, user_id, context, metadata_index, all_memories_cache, - debug_log=memory_debug_info + query, user_id, context, metadata_index, all_memories_cache, debug_log=memory_debug_info ) stage_results.append(stage1_result) current_memory_ids.update(stage1_result.memory_ids) # 阶段2:向量搜索 stage2_result = await self._vector_search_stage( - query, user_id, context, vector_storage, current_memory_ids, all_memories_cache, - debug_log=memory_debug_info + query, + user_id, + context, + vector_storage, + current_memory_ids, + all_memories_cache, + debug_log=memory_debug_info, ) stage_results.append(stage2_result) current_memory_ids.update(stage2_result.memory_ids) # 阶段3:语义重排序 stage3_result = await self._semantic_reranking_stage( - query, user_id, context, current_memory_ids, all_memories_cache, - debug_log=memory_debug_info + query, user_id, context, current_memory_ids, all_memories_cache, debug_log=memory_debug_info ) stage_results.append(stage3_result) # 阶段4:上下文过滤 stage4_result = await self._contextual_filtering_stage( - query, user_id, context, stage3_result.memory_ids, all_memories_cache, limit, - debug_log=memory_debug_info + query, + user_id, + context, + stage3_result.memory_ids, + all_memories_cache, + limit, + debug_log=memory_debug_info, ) stage_results.append(stage4_result) @@ -176,18 +184,27 @@ class MultiStageRetrieval: logger.debug(f"上下文过滤结果过少({len(stage4_result.memory_ids)}),启用回退机制") # 回退到更宽松的检索策略 fallback_result = await self._fallback_retrieval_stage( - query, user_id, context, all_memories_cache, limit, + query, + user_id, + context, + all_memories_cache, + limit, excluded_ids=set(stage4_result.memory_ids), - debug_log=memory_debug_info + debug_log=memory_debug_info, ) if fallback_result.memory_ids: - stage4_result.memory_ids.extend(fallback_result.memory_ids[:limit - len(stage4_result.memory_ids)]) + stage4_result.memory_ids.extend(fallback_result.memory_ids[: limit - len(stage4_result.memory_ids)]) logger.debug(f"回退机制补充了 {len(fallback_result.memory_ids)} 条记忆") # 阶段5:增强重排序 (新增) stage5_result = await self._enhanced_reranking_stage( - query, user_id, context, stage4_result.memory_ids, all_memories_cache, limit, - debug_log=memory_debug_info + query, + user_id, + context, + stage4_result.memory_ids, + all_memories_cache, + limit, + debug_log=memory_debug_info, ) stage_results.append(stage5_result) @@ -226,13 +243,21 @@ class MultiStageRetrieval: "semantic_score": trace.get("semantic_stage", {}).get("score"), "context_score": trace.get("context_stage", {}).get("context_score"), "final_score": trace.get("context_stage", {}).get("final_score"), - "status": trace.get("context_stage", {}).get("status") or trace.get("vector_stage", {}).get("status") or trace.get("semantic_stage", {}).get("status"), + "status": trace.get("context_stage", {}).get("status") + or trace.get("vector_stage", {}).get("status") + or trace.get("semantic_stage", {}).get("status"), "is_final": memory_id in final_ids_set, } debug_entries.append(entry) # 限制日志输出数量 - debug_entries.sort(key=lambda item: (item.get("is_final", False), item.get("final_score") or item.get("vector_similarity") or 0.0), reverse=True) + debug_entries.sort( + key=lambda item: ( + item.get("is_final", False), + item.get("final_score") or item.get("vector_similarity") or 0.0, + ), + reverse=True, + ) debug_payload = { "query": query, "semantic_query": context.get("resolved_query_text", query), @@ -266,7 +291,7 @@ class MultiStageRetrieval: stage_results=stage_results, total_processing_time=total_time, total_filtered=total_filtered, - retrieval_stats=self.retrieval_stats.copy() + retrieval_stats=self.retrieval_stats.copy(), ) except Exception as e: @@ -279,7 +304,7 @@ class MultiStageRetrieval: stage_results=stage_results, total_processing_time=time.time() - start_time, total_filtered=0, - retrieval_stats=self.retrieval_stats.copy() + retrieval_stats=self.retrieval_stats.copy(), ) async def _metadata_filtering_stage( @@ -290,7 +315,7 @@ class MultiStageRetrieval: metadata_index, all_memories_cache: Dict[str, MemoryChunk], *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None + debug_log: Optional[Dict[str, Dict[str, Any]]] = None, ) -> StageResult: """阶段1:元数据过滤""" start_time = time.time() @@ -302,7 +327,9 @@ class MultiStageRetrieval: memory_types = self._extract_memory_types_from_context(context) keywords = self._extract_keywords_from_query(query, query_plan) - subjects = query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None + subjects = ( + query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None + ) index_query = IndexQuery( user_ids=None, @@ -311,7 +338,7 @@ class MultiStageRetrieval: keywords=keywords, limit=self.config.metadata_filter_limit, sort_by="last_accessed", - sort_order="desc" + sort_order="desc", ) # 执行查询 @@ -328,19 +355,16 @@ class MultiStageRetrieval: reverse=True, ) if memory_types: - type_filtered = [ - mid for mid in sorted_ids - if all_memories_cache[mid].memory_type in memory_types - ] + type_filtered = [mid for mid in sorted_ids if all_memories_cache[mid].memory_type in memory_types] sorted_ids = type_filtered or sorted_ids if subjects: subject_candidates = [s.lower() for s in subjects if isinstance(s, str) and s.strip()] if subject_candidates: subject_filtered = [ - mid for mid in sorted_ids + mid + for mid in sorted_ids if any( - subj.strip().lower() in subject_candidates - for subj in all_memories_cache[mid].subjects + subj.strip().lower() in subject_candidates for subj in all_memories_cache[mid].subjects ) ] sorted_ids = subject_filtered or sorted_ids @@ -367,12 +391,14 @@ class MultiStageRetrieval: bool(subjects), bool(keywords), ) - details.append({ - "note": "fallback_recent", - "requested_types": [mt.value for mt in memory_types] if memory_types else [], - "subjects": subjects or [], - "keywords": keywords or [], - }) + details.append( + { + "note": "fallback_recent", + "requested_types": [mt.value for mt in memory_types] if memory_types else [], + "subjects": subjects or [], + "keywords": keywords or [], + } + ) logger.debug( "元数据过滤:候选=%d, 返回=%d", @@ -419,7 +445,7 @@ class MultiStageRetrieval: candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None + debug_log: Optional[Dict[str, Dict[str, Any]]] = None, ) -> StageResult: """阶段2:向量搜索""" start_time = time.time() @@ -441,8 +467,7 @@ class MultiStageRetrieval: # 执行向量搜索 search_result = await vector_storage.search_similar_memories( - query_vector=query_embedding, - limit=self.config.vector_search_limit + query_vector=query_embedding, limit=self.config.vector_search_limit ) if not search_result: @@ -464,16 +489,18 @@ class MultiStageRetrieval: if in_metadata_candidates and above_threshold: filtered_memories.append((memory_id, similarity)) - raw_details.append({ - "memory_id": memory_id, - "similarity": similarity, - "in_metadata": in_metadata_candidates, - "above_threshold": above_threshold, - }) + raw_details.append( + { + "memory_id": memory_id, + "similarity": similarity, + "in_metadata": in_metadata_candidates, + "above_threshold": above_threshold, + } + ) # 按相似度排序 filtered_memories.sort(key=lambda x: x[1], reverse=True) - result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]] + result_ids = [memory_id for memory_id, _ in filtered_memories[: self.config.vector_search_limit]] kept_ids = set(result_ids) for entry in raw_details: @@ -534,11 +561,7 @@ class MultiStageRetrieval: ) def _create_text_search_fallback( - self, - candidate_ids: Set[str], - all_memories_cache: Dict[str, MemoryChunk], - query_text: str, - start_time: float + self, candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], query_text: str, start_time: float ) -> StageResult: """当向量搜索失败时,使用文本搜索作为回退策略""" try: @@ -561,15 +584,13 @@ class MultiStageRetrieval: # 按匹配度排序 text_matches.sort(key=lambda x: x[1], reverse=True) - result_ids = [memory_id for memory_id, _ in text_matches[:self.config.vector_search_limit]] + result_ids = [memory_id for memory_id, _ in text_matches[: self.config.vector_search_limit]] details = [] - for memory_id, score in text_matches[:self.config.vector_search_limit]: - details.append({ - "memory_id": memory_id, - "text_match_score": round(score, 4), - "status": "text_match_fallback" - }) + for memory_id, score in text_matches[: self.config.vector_search_limit]: + details.append( + {"memory_id": memory_id, "text_match_score": round(score, 4), "status": "text_match_fallback"} + ) logger.debug(f"向量搜索回退到文本匹配:找到 {len(result_ids)} 条匹配记忆") @@ -579,18 +600,18 @@ class MultiStageRetrieval: processing_time=time.time() - start_time, filtered_count=len(candidate_ids) - len(result_ids), score_threshold=0.0, # 文本匹配无严格阈值 - details=details + details=details, ) except Exception as e: logger.error(f"文本搜索回退失败: {e}") return StageResult( stage=RetrievalStage.VECTOR_SEARCH, - memory_ids=list(candidate_ids)[:self.config.vector_search_limit], + memory_ids=list(candidate_ids)[: self.config.vector_search_limit], processing_time=time.time() - start_time, filtered_count=0, score_threshold=0.0, - details=[{"error": str(e), "note": "text_fallback_failed"}] + details=[{"error": str(e), "note": "text_fallback_failed"}], ) async def _semantic_reranking_stage( @@ -601,7 +622,7 @@ class MultiStageRetrieval: candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None + debug_log: Optional[Dict[str, Dict[str, Any]]] = None, ) -> StageResult: """阶段3:语义重排序""" start_time = time.time() @@ -643,7 +664,7 @@ class MultiStageRetrieval: # 按语义相似度排序 reranked_memories.sort(key=lambda x: x[1], reverse=True) - result_ids = [memory_id for memory_id, _ in reranked_memories[:self.config.semantic_rerank_limit]] + result_ids = [memory_id for memory_id, _ in reranked_memories[: self.config.semantic_rerank_limit]] kept_ids = set(result_ids) filtered_count = len(candidate_ids) - len(result_ids) @@ -688,7 +709,7 @@ class MultiStageRetrieval: all_memories_cache: Dict[str, MemoryChunk], limit: int, *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None + debug_log: Optional[Dict[str, Dict[str, Any]]] = None, ) -> StageResult: """阶段4:上下文过滤""" start_time = time.time() @@ -777,7 +798,7 @@ class MultiStageRetrieval: limit: int, *, excluded_ids: Optional[Set[str]] = None, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None + debug_log: Optional[Dict[str, Dict[str, Any]]] = None, ) -> StageResult: """回退检索阶段 - 当主检索失败时使用更宽松的策略""" start_time = time.time() @@ -806,13 +827,15 @@ class MultiStageRetrieval: if not fallback_candidates: logger.debug("关键词匹配无结果,使用时序最近策略") recent_memories = sorted( - [(mid, mem.metadata.last_accessed or mem.metadata.created_at) - for mid, mem in all_memories_cache.items() - if mid not in excluded_ids], + [ + (mid, mem.metadata.last_accessed or mem.metadata.created_at) + for mid, mem in all_memories_cache.items() + if mid not in excluded_ids + ], key=lambda x: x[1], - reverse=True + reverse=True, ) - fallback_candidates = [(mid, 0.5) for mid, _ in recent_memories[:limit*2]] + fallback_candidates = [(mid, 0.5) for mid, _ in recent_memories[: limit * 2]] # 按分数排序 fallback_candidates.sort(key=lambda x: x[1], reverse=True) @@ -857,7 +880,9 @@ class MultiStageRetrieval: details=[{"error": str(e)}], ) - async def _generate_query_embedding(self, query: str, context: Dict[str, Any], vector_storage) -> Optional[List[float]]: + async def _generate_query_embedding( + self, query: str, context: Dict[str, Any], vector_storage + ) -> Optional[List[float]]: """生成查询向量""" try: query_plan = context.get("query_plan") @@ -875,15 +900,15 @@ class MultiStageRetrieval: logger.debug(f"正在生成查询向量,文本: '{query_text[:100]}'") embedding = await vector_storage.generate_query_embedding(query_text) - + if embedding is None: logger.warning("向量存储返回空的查询向量") return None - + if len(embedding) == 0: logger.warning("向量存储返回空列表作为查询向量") return None - + logger.debug(f"查询向量生成成功,维度: {len(embedding)}") return embedding @@ -926,8 +951,8 @@ class MultiStageRetrieval: import re # 分词处理 - query_words = list(jieba.cut(query_text)) + re.findall(r'[a-zA-Z]+', query_text) - memory_words = list(jieba.cut(memory_text)) + re.findall(r'[a-zA-Z]+', memory_text) + query_words = list(jieba.cut(query_text)) + re.findall(r"[a-zA-Z]+", query_text) + memory_words = list(jieba.cut(memory_text)) + re.findall(r"[a-zA-Z]+", memory_text) # 清理和标准化 query_words = [w.strip().lower() for w in query_words if w.strip() and len(w.strip()) > 1] @@ -953,8 +978,9 @@ class MultiStageRetrieval: except ImportError: # 如果jieba不可用,使用简单分词 import re - query_words = re.findall(r'[\w\u4e00-\u9fa5]+', query_lower) - memory_words = re.findall(r'[\w\u4e00-\u9fa5]+', memory_lower) + + query_words = re.findall(r"[\w\u4e00-\u9fa5]+", query_lower) + memory_words = re.findall(r"[\w\u4e00-\u9fa5]+", memory_lower) if query_words and memory_words: query_set = set(w for w in query_words if len(w) > 1) @@ -971,13 +997,19 @@ class MultiStageRetrieval: "天气": ["天气", "阳光", "雨", "晴", "阴", "温度", "weather", "sunny", "rain"], "编程": ["编程", "代码", "程序", "开发", "语言", "programming", "code", "develop", "python"], "时间": ["今天", "昨天", "明天", "现在", "时间", "today", "yesterday", "tomorrow", "time"], - "情感": ["好", "坏", "开心", "难过", "有趣", "good", "bad", "happy", "sad", "fun"] + "情感": ["好", "坏", "开心", "难过", "有趣", "good", "bad", "happy", "sad", "fun"], } - query_concepts = {concept for concept, keywords in concept_groups.items() - if any(keyword in query_lower for keyword in keywords)} - memory_concepts = {concept for concept, keywords in concept_groups.items() - if any(keyword in memory_lower for keyword in keywords)} + query_concepts = { + concept + for concept, keywords in concept_groups.items() + if any(keyword in query_lower for keyword in keywords) + } + memory_concepts = { + concept + for concept, keywords in concept_groups.items() + if any(keyword in memory_lower for keyword in keywords) + } if query_concepts and memory_concepts: concept_overlap = query_concepts & memory_concepts @@ -987,19 +1019,19 @@ class MultiStageRetrieval: plan_bonus = 0.0 if query_plan: # 主体匹配 - if hasattr(query_plan, 'subjects') and query_plan.subjects: + if hasattr(query_plan, "subjects") and query_plan.subjects: for subject in query_plan.subjects: if subject.lower() in memory_lower: plan_bonus += 0.15 # 对象匹配 - if hasattr(query_plan, 'objects') and query_plan.objects: + if hasattr(query_plan, "objects") and query_plan.objects: for obj in query_plan.objects: if obj.lower() in memory_lower: plan_bonus += 0.1 # 记忆类型匹配 - if hasattr(query_plan, 'memory_types') and query_plan.memory_types: + if hasattr(query_plan, "memory_types") and query_plan.memory_types: if memory.memory_type in query_plan.memory_types: plan_bonus += 0.1 @@ -1059,14 +1091,22 @@ class MultiStageRetrieval: object_keywords = getattr(query_plan, "object_includes", []) or [] if object_keywords: display_text = (memory.display or memory.text_content or "").lower() - hits = sum(1 for kw in object_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text) + hits = sum( + 1 + for kw in object_keywords + if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text + ) if hits: score += min(0.3, hits * 0.1) optional_keywords = getattr(query_plan, "optional_keywords", []) or [] if optional_keywords: display_text = (memory.display or memory.text_content or "").lower() - hits = sum(1 for kw in optional_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text) + hits = sum( + 1 + for kw in optional_keywords + if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text + ) if hits: score += min(0.2, hits * 0.05) @@ -1091,7 +1131,9 @@ class MultiStageRetrieval: logger.warning(f"计算上下文相关度失败: {e}") return 0.0 - async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float: + async def _calculate_final_score( + self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float + ) -> float: """计算最终评分""" try: query_plan = context.get("query_plan") @@ -1126,10 +1168,10 @@ class MultiStageRetrieval: context_weight += 0.05 final_score = ( - semantic_score * semantic_weight + - vector_score * vector_weight + - context_score * context_weight + - recency_score * recency_weight + semantic_score * semantic_weight + + vector_score * vector_weight + + context_score * context_weight + + recency_score * recency_weight ) # 加入记忆重要性权重 @@ -1259,7 +1301,9 @@ class MultiStageRetrieval: stage_stat["calls"] += 1 current_stage_avg = stage_stat["avg_time"] - new_stage_avg = (current_stage_avg * (stage_stat["calls"] - 1) + result.processing_time) / stage_stat["calls"] + new_stage_avg = (current_stage_avg * (stage_stat["calls"] - 1) + result.processing_time) / stage_stat[ + "calls" + ] stage_stat["avg_time"] = new_stage_avg def get_retrieval_stats(self) -> Dict[str, Any]: @@ -1276,8 +1320,8 @@ class MultiStageRetrieval: "vector_search": {"calls": 0, "avg_time": 0.0}, "semantic_reranking": {"calls": 0, "avg_time": 0.0}, "contextual_filtering": {"calls": 0, "avg_time": 0.0}, - "enhanced_reranking": {"calls": 0, "avg_time": 0.0} - } + "enhanced_reranking": {"calls": 0, "avg_time": 0.0}, + }, } async def _enhanced_reranking_stage( @@ -1289,7 +1333,7 @@ class MultiStageRetrieval: all_memories_cache: Dict[str, MemoryChunk], limit: int, *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None + debug_log: Optional[Dict[str, Dict[str, Any]]] = None, ) -> StageResult: """阶段5:增强重排序 - 使用多维度评分模型""" start_time = time.time() @@ -1326,15 +1370,12 @@ class MultiStageRetrieval: # 使用增强重排序器 reranked_memories = self.reranker.rerank_memories( - query=query, - candidate_memories=candidate_memories, - context=context, - limit=limit + query=query, candidate_memories=candidate_memories, context=context, limit=limit ) # 提取重排序后的记忆ID result_ids = [memory_id for memory_id, _, _ in reranked_memories] - + # 生成调试详情 details = [] for memory_id, memory, final_score in reranked_memories: @@ -1346,7 +1387,7 @@ class MultiStageRetrieval: "access_count": memory.metadata.access_count, } details.append(detail_entry) - + if debug_log is not None: stage_entry = debug_log.setdefault(memory_id, {}).setdefault("enhanced_rerank_stage", {}) stage_entry["final_score"] = round(final_score, 4) @@ -1357,13 +1398,9 @@ class MultiStageRetrieval: kept_ids = set(result_ids) for memory_id in candidate_ids: if memory_id not in kept_ids: - detail_entry = { - "memory_id": memory_id, - "status": "filtered_out", - "reason": "ranked_below_limit" - } + detail_entry = {"memory_id": memory_id, "status": "filtered_out", "reason": "ranked_below_limit"} details.append(detail_entry) - + if debug_log is not None: stage_entry = debug_log.setdefault(memory_id, {}).setdefault("enhanced_rerank_stage", {}) stage_entry["status"] = "filtered_out" @@ -1371,10 +1408,7 @@ class MultiStageRetrieval: filtered_count = len(candidate_ids) - len(result_ids) - logger.debug( - f"增强重排序完成:候选={len(candidate_ids)}, 返回={len(result_ids)}, " - f"过滤={filtered_count}" - ) + logger.debug(f"增强重排序完成:候选={len(candidate_ids)}, 返回={len(result_ids)}, 过滤={filtered_count}") return StageResult( stage=RetrievalStage.CONTEXTUAL_FILTERING, # 保持与原有枚举兼容 @@ -1394,4 +1428,4 @@ class MultiStageRetrieval: filtered_count=0, score_threshold=0.0, details=[{"error": str(e)}], - ) \ No newline at end of file + ) diff --git a/src/chat/memory_system/deprecated_backup/vector_storage.py b/src/chat/memory_system/deprecated_backup/vector_storage.py index 73ddbb6a6..5d2e4fb91 100644 --- a/src/chat/memory_system/deprecated_backup/vector_storage.py +++ b/src/chat/memory_system/deprecated_backup/vector_storage.py @@ -4,22 +4,19 @@ 为记忆系统提供高效的向量存储和语义搜索能力 """ -import os import time import orjson import asyncio -from typing import Dict, List, Optional, Tuple, Set, Any +from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass -from datetime import datetime import threading import numpy as np -import pandas as pd from pathlib import Path from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config, global_config +from src.config.config import model_config from src.common.config_helpers import resolve_embedding_dimension from src.chat.memory_system.memory_chunk import MemoryChunk @@ -28,6 +25,7 @@ logger = get_logger(__name__) # 尝试导入FAISS,如果不可用则使用简单替代 try: import faiss + FAISS_AVAILABLE = True except ImportError: FAISS_AVAILABLE = False @@ -37,6 +35,7 @@ except ImportError: @dataclass class VectorStorageConfig: """向量存储配置""" + dimension: int = 1024 similarity_threshold: float = 0.8 index_type: str = "flat" # flat, ivf, hnsw @@ -79,7 +78,7 @@ class VectorStorageManager: "average_search_time": 0.0, "cache_hit_rate": 0.0, "total_searches": 0, - "cache_hits": 0 + "cache_hits": 0, } # 线程锁 @@ -122,8 +121,7 @@ class VectorStorageManager: """初始化嵌入模型""" if self.embedding_model is None: self.embedding_model = LLMRequest( - model_set=model_config.model_task_config.embedding, - request_type="memory.embedding" + model_set=model_config.model_task_config.embedding, request_type="memory.embedding" ) logger.info("✅ 嵌入模型初始化完成") @@ -137,20 +135,16 @@ class VectorStorageManager: await self.initialize_embedding_model() logger.debug(f"开始生成查询向量,文本: '{query_text[:50]}{'...' if len(query_text) > 50 else ''}'") - + embedding, _ = await self.embedding_model.get_embedding(query_text) if not embedding: logger.warning("嵌入模型返回空向量") return None logger.debug(f"生成的向量维度: {len(embedding)}, 期望维度: {self.config.dimension}") - + if len(embedding) != self.config.dimension: - logger.error( - "查询向量维度不匹配: 期望 %d, 实际 %d", - self.config.dimension, - len(embedding) - ) + logger.error("查询向量维度不匹配: 期望 %d, 实际 %d", self.config.dimension, len(embedding)) return None normalized_vector = self._normalize_vector(embedding) @@ -287,7 +281,7 @@ class VectorStorageManager: logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc) results[memory_id] = [] - tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts)] + tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)] await asyncio.gather(*tasks, return_exceptions=True) except Exception as e: @@ -313,12 +307,12 @@ class VectorStorageManager: memory.set_embedding(embedding) # 添加到向量索引 - if hasattr(self.vector_index, 'add'): + if hasattr(self.vector_index, "add"): # FAISS索引 if isinstance(embedding, np.ndarray): - vector_array = embedding.reshape(1, -1).astype('float32') + vector_array = embedding.reshape(1, -1).astype("float32") else: - vector_array = np.array([embedding], dtype='float32') + vector_array = np.array([embedding], dtype="float32") # 特殊处理IVF索引 if self.config.index_type == "ivf" and self.vector_index.ntotal == 0: @@ -367,14 +361,14 @@ class VectorStorageManager: *, query_text: Optional[str] = None, limit: int = 10, - scope_id: Optional[str] = None + scope_id: Optional[str] = None, ) -> List[Tuple[str, float]]: """搜索相似记忆""" start_time = time.time() try: logger.debug(f"开始向量搜索: query_text='{query_text[:30] if query_text else 'None'}', limit={limit}") - + if query_vector is None: if not query_text: logger.warning("查询向量和查询文本都为空") @@ -395,34 +389,34 @@ class VectorStorageManager: # 规范化查询向量 query_vector = self._normalize_vector(query_vector) - + logger.debug(f"查询向量维度: {len(query_vector)}, 存储总向量数: {self.storage_stats['total_vectors']}") # 检查向量索引状态 if not self.vector_index: logger.error("向量索引未初始化") return [] - + total_vectors = 0 - if hasattr(self.vector_index, 'ntotal'): + if hasattr(self.vector_index, "ntotal"): total_vectors = self.vector_index.ntotal - elif hasattr(self.vector_index, 'vectors'): + elif hasattr(self.vector_index, "vectors"): total_vectors = len(self.vector_index.vectors) - + logger.debug(f"向量索引中实际向量数: {total_vectors}") - + if total_vectors == 0: logger.warning("向量索引为空,无法执行搜索") return [] # 执行向量搜索 with self._lock: - if hasattr(self.vector_index, 'search'): + if hasattr(self.vector_index, "search"): # FAISS索引 if isinstance(query_vector, np.ndarray): - query_array = query_vector.reshape(1, -1).astype('float32') + query_array = query_vector.reshape(1, -1).astype("float32") else: - query_array = np.array([query_vector], dtype='float32') + query_array = np.array([query_vector], dtype="float32") if self.config.index_type == "ivf" and self.vector_index.ntotal > 0: # 设置IVF搜索参数 @@ -432,11 +426,11 @@ class VectorStorageManager: search_limit = min(limit, total_vectors) logger.debug(f"执行FAISS搜索,搜索限制: {search_limit}") - + distances, indices = self.vector_index.search(query_array, search_limit) distances = distances.flatten().tolist() indices = indices.flatten().tolist() - + logger.debug(f"FAISS搜索结果: {len(distances)} 个距离值, {len(indices)} 个索引") else: # 简单索引 @@ -451,8 +445,8 @@ class VectorStorageManager: valid_results = 0 invalid_indices = 0 filtered_by_scope = 0 - - for distance, index in zip(distances, indices): + + for distance, index in zip(distances, indices, strict=False): if index == -1: # FAISS的无效索引标记 invalid_indices += 1 continue @@ -462,7 +456,7 @@ class VectorStorageManager: logger.debug(f"索引 {index} 没有对应的记忆ID") invalid_indices += 1 continue - + if scope_filter: memory = self.memory_cache.get(memory_id) if memory and str(memory.user_id) != scope_filter: @@ -482,16 +476,15 @@ class VectorStorageManager: search_time = time.time() - start_time self.storage_stats["total_searches"] += 1 self.storage_stats["average_search_time"] = ( - (self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time) / - self.storage_stats["total_searches"] - ) + self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time + ) / self.storage_stats["total_searches"] final_results = results[:limit] logger.info( f"向量搜索完成: 查询='{query_text[:20] if query_text else 'vector'}' " f"耗时={search_time:.3f}s, 返回={len(final_results)}个结果" ) - + return final_results except Exception as e: @@ -520,7 +513,7 @@ class VectorStorageManager: old_index = self.memory_id_to_index[memory_id] # 删除旧向量(如果支持) - if hasattr(self.vector_index, 'remove_ids'): + if hasattr(self.vector_index, "remove_ids"): try: self.vector_index.remove_ids(np.array([old_index])) except: @@ -530,11 +523,11 @@ class VectorStorageManager: new_embedding = self._normalize_vector(new_embedding) # 添加新向量 - if hasattr(self.vector_index, 'add'): + if hasattr(self.vector_index, "add"): if isinstance(new_embedding, np.ndarray): - vector_array = new_embedding.reshape(1, -1).astype('float32') + vector_array = new_embedding.reshape(1, -1).astype("float32") else: - vector_array = np.array([new_embedding], dtype='float32') + vector_array = np.array([new_embedding], dtype="float32") self.vector_index.add(vector_array) new_index = self.vector_index.ntotal - 1 @@ -569,7 +562,7 @@ class VectorStorageManager: index = self.memory_id_to_index[memory_id] # 从向量索引中删除(如果支持) - if hasattr(self.vector_index, 'remove_ids'): + if hasattr(self.vector_index, "remove_ids"): try: self.vector_index.remove_ids(np.array([index])) except: @@ -598,44 +591,37 @@ class VectorStorageManager: logger.info("正在保存向量存储...") # 保存记忆缓存 - cache_data = { - memory_id: memory.to_dict() - for memory_id, memory in self.memory_cache.items() - } + cache_data = {memory_id: memory.to_dict() for memory_id, memory in self.memory_cache.items()} cache_file = self.storage_path / "memory_cache.json" - with open(cache_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(cache_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存向量缓存 vector_cache_file = self.storage_path / "vector_cache.json" - with open(vector_cache_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(vector_cache_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存映射关系 mapping_file = self.storage_path / "id_mapping.json" mapping_data = { "memory_id_to_index": { - str(memory_id): int(index) - for memory_id, index in self.memory_id_to_index.items() + str(memory_id): int(index) for memory_id, index in self.memory_id_to_index.items() }, - "index_to_memory_id": { - str(index): memory_id - for index, memory_id in self.index_to_memory_id.items() - } + "index_to_memory_id": {str(index): memory_id for index, memory_id in self.index_to_memory_id.items()}, } - with open(mapping_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(mapping_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存FAISS索引(如果可用) - if FAISS_AVAILABLE and hasattr(self.vector_index, 'save'): + if FAISS_AVAILABLE and hasattr(self.vector_index, "save"): index_file = self.storage_path / "vector_index.faiss" faiss.write_index(self.vector_index, str(index_file)) # 保存统计信息 stats_file = self.storage_path / "storage_stats.json" - with open(stats_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(stats_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode("utf-8")) logger.info("✅ 向量存储保存完成") @@ -650,36 +636,31 @@ class VectorStorageManager: # 加载记忆缓存 cache_file = self.storage_path / "memory_cache.json" if cache_file.exists(): - with open(cache_file, 'r', encoding='utf-8') as f: + with open(cache_file, "r", encoding="utf-8") as f: cache_data = orjson.loads(f.read()) self.memory_cache = { - memory_id: MemoryChunk.from_dict(memory_data) - for memory_id, memory_data in cache_data.items() + memory_id: MemoryChunk.from_dict(memory_data) for memory_id, memory_data in cache_data.items() } # 加载向量缓存 vector_cache_file = self.storage_path / "vector_cache.json" if vector_cache_file.exists(): - with open(vector_cache_file, 'r', encoding='utf-8') as f: + with open(vector_cache_file, "r", encoding="utf-8") as f: self.vector_cache = orjson.loads(f.read()) # 加载映射关系 mapping_file = self.storage_path / "id_mapping.json" if mapping_file.exists(): - with open(mapping_file, 'r', encoding='utf-8') as f: + with open(mapping_file, "r", encoding="utf-8") as f: mapping_data = orjson.loads(f.read()) raw_memory_to_index = mapping_data.get("memory_id_to_index", {}) self.memory_id_to_index = { - str(memory_id): int(index) - for memory_id, index in raw_memory_to_index.items() + str(memory_id): int(index) for memory_id, index in raw_memory_to_index.items() } raw_index_to_memory = mapping_data.get("index_to_memory_id", {}) - self.index_to_memory_id = { - int(index): memory_id - for index, memory_id in raw_index_to_memory.items() - } + self.index_to_memory_id = {int(index): memory_id for index, memory_id in raw_index_to_memory.items()} # 加载FAISS索引(如果可用) index_loaded = False @@ -699,7 +680,7 @@ class VectorStorageManager: logger.warning(f"加载FAISS索引失败: {e},重新构建") else: logger.info("FAISS索引文件不存在,将重新构建") - + # 如果索引没有成功加载且有向量数据,则重建索引 if not index_loaded and self.vector_cache: logger.info(f"检测到 {len(self.vector_cache)} 个向量缓存,重建索引") @@ -708,7 +689,7 @@ class VectorStorageManager: # 加载统计信息 stats_file = self.storage_path / "storage_stats.json" if stats_file.exists(): - with open(stats_file, 'r', encoding='utf-8') as f: + with open(stats_file, "r", encoding="utf-8") as f: self.storage_stats = orjson.loads(f.read()) # 更新向量计数 @@ -738,7 +719,7 @@ class VectorStorageManager: # 准备向量数据 memory_ids = [] vectors = [] - + for memory_id, embedding in self.vector_cache.items(): if embedding and len(embedding) == self.config.dimension: memory_ids.append(memory_id) @@ -753,18 +734,18 @@ class VectorStorageManager: logger.info(f"准备重建 {len(vectors)} 个向量到索引") # 批量添加向量到FAISS索引 - if hasattr(self.vector_index, 'add'): + if hasattr(self.vector_index, "add"): # FAISS索引 - vector_array = np.array(vectors, dtype='float32') - + vector_array = np.array(vectors, dtype="float32") + # 特殊处理IVF索引 - if self.config.index_type == "ivf" and hasattr(self.vector_index, 'train'): + if self.config.index_type == "ivf" and hasattr(self.vector_index, "train"): logger.info("训练IVF索引...") self.vector_index.train(vector_array) # 添加向量 self.vector_index.add(vector_array) - + # 重建映射关系 for i, memory_id in enumerate(memory_ids): self.memory_id_to_index[memory_id] = i @@ -772,15 +753,15 @@ class VectorStorageManager: else: # 简单索引 - for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors)): + for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors, strict=False)): index_id = self.vector_index.add_vector(vector) self.memory_id_to_index[memory_id] = index_id self.index_to_memory_id[index_id] = memory_id # 更新统计 self.storage_stats["total_vectors"] = len(self.memory_id_to_index) - - final_count = getattr(self.vector_index, 'ntotal', len(self.memory_id_to_index)) + + final_count = getattr(self.vector_index, "ntotal", len(self.memory_id_to_index)) logger.info(f"✅ 向量索引重建完成,索引中向量数: {final_count}") except Exception as e: @@ -875,7 +856,7 @@ class SimpleVectorIndex: def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float: """计算余弦相似度""" try: - dot_product = sum(x * y for x, y in zip(v1, v2)) + dot_product = sum(x * y for x, y in zip(v1, v2, strict=False)) norm1 = sum(x * x for x in v1) ** 0.5 norm2 = sum(x * x for x in v2) ** 0.5 @@ -890,4 +871,4 @@ class SimpleVectorIndex: @property def ntotal(self) -> int: """向量总数""" - return len(self.vectors) \ No newline at end of file + return len(self.vectors) diff --git a/src/chat/memory_system/enhanced_memory_activator.py b/src/chat/memory_system/enhanced_memory_activator.py index a9c5fc9cb..7570715ee 100644 --- a/src/chat/memory_system/enhanced_memory_activator.py +++ b/src/chat/memory_system/enhanced_memory_activator.py @@ -6,7 +6,6 @@ import difflib import orjson -import time from typing import List, Dict, Optional from datetime import datetime @@ -15,7 +14,7 @@ from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.common.logger import get_logger from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.memory_system.memory_manager import memory_manager, MemoryResult +from src.chat.memory_system.memory_manager import MemoryResult logger = get_logger("memory_activator") @@ -127,8 +126,8 @@ class MemoryActivator: for result in memory_results: # 检查是否已存在相似内容的记忆 exists = any( - m["content"] == result.content or - difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7 + m["content"] == result.content + or difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7 for m in self.running_memory ) if not exists: @@ -140,7 +139,7 @@ class MemoryActivator: "confidence": result.confidence, "importance": result.importance, "source": result.source, - "relevance_score": result.relevance_score # 添加相关度评分 + "relevance_score": result.relevance_score, # 添加相关度评分 } self.running_memory.append(memory_entry) logger.debug(f"添加新记忆: {result.memory_type} - {result.content}") @@ -168,17 +167,14 @@ class MemoryActivator: return [] # 构建查询上下文 - context = { - "keywords": keywords, - "query_intent": "conversation_response" - } + context = {"keywords": keywords, "query_intent": "conversation_response"} # 查询记忆 memories = await memory_system.retrieve_relevant_memories( query_text=query_text, user_id="global", # 使用全局作用域 context=context, - limit=5 + limit=5, ) # 转换为 MemoryResult 格式 @@ -191,7 +187,7 @@ class MemoryActivator: importance=memory.metadata.importance.value, timestamp=memory.metadata.created_at, source="unified_memory", - relevance_score=memory.metadata.relevance_score + relevance_score=memory.metadata.relevance_score, ) memory_results.append(result) @@ -214,16 +210,10 @@ class MemoryActivator: if not memory_system or memory_system.status.value != "ready": return None - context = { - "query_intent": "instant_response", - "chat_id": chat_id - } + context = {"query_intent": "instant_response", "chat_id": chat_id} memories = await memory_system.retrieve_relevant_memories( - query_text=target_message, - user_id="global", - context=context, - limit=1 + query_text=target_message, user_id="global", context=context, limit=1 ) if memories: @@ -248,4 +238,4 @@ memory_activator = MemoryActivator() # 兼容性别名 enhanced_memory_activator = memory_activator -init_prompt() \ No newline at end of file +init_prompt() diff --git a/src/chat/memory_system/memory_activator_new.py b/src/chat/memory_system/memory_activator_new.py index e11e66e25..491034de4 100644 --- a/src/chat/memory_system/memory_activator_new.py +++ b/src/chat/memory_system/memory_activator_new.py @@ -6,7 +6,6 @@ import difflib import orjson -import time from typing import List, Dict, Optional from datetime import datetime @@ -15,7 +14,7 @@ from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.common.logger import get_logger from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.memory_system.memory_manager import memory_manager, MemoryResult +from src.chat.memory_system.memory_manager import MemoryResult logger = get_logger("memory_activator") @@ -127,8 +126,8 @@ class MemoryActivator: for result in memory_results: # 检查是否已存在相似内容的记忆 exists = any( - m["content"] == result.content or - difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7 + m["content"] == result.content + or difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7 for m in self.running_memory ) if not exists: @@ -140,7 +139,7 @@ class MemoryActivator: "confidence": result.confidence, "importance": result.importance, "source": result.source, - "relevance_score": result.relevance_score # 添加相关度评分 + "relevance_score": result.relevance_score, # 添加相关度评分 } self.running_memory.append(memory_entry) logger.debug(f"添加新记忆: {result.memory_type} - {result.content}") @@ -168,17 +167,14 @@ class MemoryActivator: return [] # 构建查询上下文 - context = { - "keywords": keywords, - "query_intent": "conversation_response" - } + context = {"keywords": keywords, "query_intent": "conversation_response"} # 查询记忆 memories = await memory_system.retrieve_relevant_memories( query_text=query_text, user_id="global", # 使用全局作用域 context=context, - limit=5 + limit=5, ) # 转换为 MemoryResult 格式 @@ -191,7 +187,7 @@ class MemoryActivator: importance=memory.metadata.importance.value, timestamp=memory.metadata.created_at, source="unified_memory", - relevance_score=memory.metadata.relevance_score + relevance_score=memory.metadata.relevance_score, ) memory_results.append(result) @@ -214,16 +210,10 @@ class MemoryActivator: if not memory_system or memory_system.status.value != "ready": return None - context = { - "query_intent": "instant_response", - "chat_id": chat_id - } + context = {"query_intent": "instant_response", "chat_id": chat_id} memories = await memory_system.retrieve_relevant_memories( - query_text=target_message, - user_id="global", - context=context, - limit=1 + query_text=target_message, user_id="global", context=context, limit=1 ) if memories: @@ -246,4 +236,4 @@ class MemoryActivator: memory_activator = MemoryActivator() -init_prompt() \ No newline at end of file +init_prompt() diff --git a/src/chat/memory_system/memory_builder.py b/src/chat/memory_system/memory_builder.py index 5937b491d..0c3f47043 100644 --- a/src/chat/memory_system/memory_builder.py +++ b/src/chat/memory_system/memory_builder.py @@ -33,7 +33,7 @@ import time from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Union, Type +from typing import Any, Dict, List, Optional, Union, Type import orjson @@ -53,14 +53,15 @@ logger = get_logger(__name__) class ExtractionStrategy(Enum): """提取策略""" - LLM_BASED = "llm_based" # 基于LLM的智能提取 - RULE_BASED = "rule_based" # 基于规则的提取 - HYBRID = "hybrid" # 混合策略 + LLM_BASED = "llm_based" # 基于LLM的智能提取 + RULE_BASED = "rule_based" # 基于规则的提取 + HYBRID = "hybrid" # 混合策略 @dataclass class ExtractionResult: """提取结果""" + memories: List[MemoryChunk] confidence_scores: List[float] extraction_time: float @@ -80,15 +81,11 @@ class MemoryBuilder: "total_extractions": 0, "successful_extractions": 0, "failed_extractions": 0, - "average_confidence": 0.0 + "average_confidence": 0.0, } async def build_memories( - self, - conversation_text: str, - context: Dict[str, Any], - user_id: str, - timestamp: float + self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: float ) -> List[MemoryChunk]: """从对话中构建记忆""" start_time = time.time() @@ -119,19 +116,13 @@ class MemoryBuilder: raise async def _extract_with_llm( - self, - text: str, - context: Dict[str, Any], - user_id: str, - timestamp: float + self, text: str, context: Dict[str, Any], user_id: str, timestamp: float ) -> List[MemoryChunk]: """使用LLM提取记忆""" try: prompt = self._build_llm_extraction_prompt(text, context) - response, _ = await self.llm_model.generate_response_async( - prompt, temperature=0.3 - ) + response, _ = await self.llm_model.generate_response_async(prompt, temperature=0.3) # 解析LLM响应 memories = self._parse_llm_response(response, user_id, timestamp, context) @@ -342,16 +333,12 @@ class MemoryBuilder: start = stripped.find("{") end = stripped.rfind("}") if start != -1 and end != -1 and end > start: - return stripped[start:end + 1].strip() + return stripped[start : end + 1].strip() return stripped if stripped.startswith("{") and stripped.endswith("}") else None def _parse_llm_response( - self, - response: str, - user_id: str, - timestamp: float, - context: Dict[str, Any] + self, response: str, user_id: str, timestamp: float, context: Dict[str, Any] ) -> List[MemoryChunk]: """解析LLM响应""" if not response: @@ -366,9 +353,7 @@ class MemoryBuilder: data = orjson.loads(json_payload) except Exception as e: preview = json_payload[:200] - raise MemoryExtractionError( - f"LLM响应JSON解析失败: {e}, 片段: {preview}" - ) from e + raise MemoryExtractionError(f"LLM响应JSON解析失败: {e}, 片段: {preview}") from e memory_list = data.get("memories", []) @@ -406,17 +391,15 @@ class MemoryBuilder: try: # 检查是否包含模糊代称 display_text = mem_data.get("display", "") - if any(ambiguous_term in display_text for ambiguous_term in ["用户", "user", "the user", "对方", "对手"]): + if any( + ambiguous_term in display_text for ambiguous_term in ["用户", "user", "the user", "对方", "对手"] + ): logger.debug(f"拒绝构建包含模糊代称的记忆,display字段: {display_text}") continue subject_value = mem_data.get("subject") normalized_subject = self._normalize_subjects( - subject_value, - bot_identifiers, - system_identifiers, - default_subjects, - bot_display + subject_value, bot_identifiers, system_identifiers, default_subjects, bot_display ) if not normalized_subject: @@ -425,17 +408,11 @@ class MemoryBuilder: # 创建记忆块 importance_level = self._parse_enum_value( - ImportanceLevel, - mem_data.get("importance"), - ImportanceLevel.NORMAL, - "importance" + ImportanceLevel, mem_data.get("importance"), ImportanceLevel.NORMAL, "importance" ) confidence_level = self._parse_enum_value( - ConfidenceLevel, - mem_data.get("confidence"), - ConfidenceLevel.MEDIUM, - "confidence" + ConfidenceLevel, mem_data.get("confidence"), ConfidenceLevel.MEDIUM, "confidence" ) predicate_value = mem_data.get("predicate", "") @@ -457,7 +434,7 @@ class MemoryBuilder: source_context=mem_data.get("reasoning", ""), importance=importance_level, confidence=confidence_level, - display=display_text + display=display_text, ) if used_fallback_display: @@ -483,13 +460,7 @@ class MemoryBuilder: return memories - def _parse_enum_value( - self, - enum_cls: Type[Enum], - raw_value: Any, - default: Enum, - field_name: str - ) -> Enum: + def _parse_enum_value(self, enum_cls: Type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum: """解析枚举值,兼容数字/字符串表示""" if isinstance(raw_value, enum_cls): return raw_value @@ -533,12 +504,14 @@ class MemoryBuilder: try: return enum_cls(raw_value) except Exception: - logger.debug("%s=%s 类型 %s 无法解析为 %s,使用默认值 %s", - field_name, - raw_value, - type(raw_value).__name__, - enum_cls.__name__, - default.name) + logger.debug( + "%s=%s 类型 %s 无法解析为 %s,使用默认值 %s", + field_name, + raw_value, + type(raw_value).__name__, + enum_cls.__name__, + default.name, + ) return default def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]: @@ -606,7 +579,7 @@ class MemoryBuilder: "members", "member_names", "mention_users", - "audiences" + "audiences", ] for key in candidate_keys: @@ -727,7 +700,7 @@ class MemoryBuilder: bot_identifiers: set[str], system_identifiers: set[str], default_subjects: List[str], - bot_display: Optional[str] = None + bot_display: Optional[str] = None, ) -> List[str]: defaults = default_subjects or ["对话参与者"] @@ -800,7 +773,9 @@ class MemoryBuilder: return obj.strip() or None return None - def _compose_display_text(self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]) -> str: + def _compose_display_text( + self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]] + ) -> str: subject_phrase = "、".join(subjects) if subjects else "对话参与者" predicate = (predicate or "").strip() @@ -866,11 +841,7 @@ class MemoryBuilder: return f"{subject_phrase}{predicate}".strip() return subject_phrase - def _validate_and_enhance_memories( - self, - memories: List[MemoryChunk], - context: Dict[str, Any] - ) -> List[MemoryChunk]: + def _validate_and_enhance_memories(self, memories: List[MemoryChunk], context: Dict[str, Any]) -> List[MemoryChunk]: """验证和增强记忆""" validated_memories = [] @@ -905,11 +876,7 @@ class MemoryBuilder: return True - def _enhance_memory( - self, - memory: MemoryChunk, - context: Dict[str, Any] - ) -> MemoryChunk: + def _enhance_memory(self, memory: MemoryChunk, context: Dict[str, Any]) -> MemoryChunk: """增强记忆块""" # 时间规范化处理 self._normalize_time_in_memory(memory) @@ -919,7 +886,7 @@ class MemoryBuilder: memory.temporal_context = { "timestamp": memory.metadata.created_at, "timezone": context.get("timezone", "UTC"), - "day_of_week": datetime.fromtimestamp(memory.metadata.created_at).strftime("%A") + "day_of_week": datetime.fromtimestamp(memory.metadata.created_at).strftime("%A"), } # 添加情感上下文(如果有) @@ -941,22 +908,22 @@ class MemoryBuilder: # 定义相对时间映射 relative_time_patterns = { - r'今天|今日': current_time.strftime('%Y-%m-%d'), - r'昨天|昨日': (current_time - timedelta(days=1)).strftime('%Y-%m-%d'), - r'明天|明日': (current_time + timedelta(days=1)).strftime('%Y-%m-%d'), - r'后天': (current_time + timedelta(days=2)).strftime('%Y-%m-%d'), - r'大后天': (current_time + timedelta(days=3)).strftime('%Y-%m-%d'), - r'前天': (current_time - timedelta(days=2)).strftime('%Y-%m-%d'), - r'大前天': (current_time - timedelta(days=3)).strftime('%Y-%m-%d'), - r'本周|这周|这星期': current_time.strftime('%Y-%m-%d'), - r'上周|上星期': (current_time - timedelta(weeks=1)).strftime('%Y-%m-%d'), - r'下周|下星期': (current_time + timedelta(weeks=1)).strftime('%Y-%m-%d'), - r'本月|这个月': current_time.strftime('%Y-%m-01'), - r'上月|上个月': (current_time.replace(day=1) - timedelta(days=1)).strftime('%Y-%m-01'), - r'下月|下个月': (current_time.replace(day=1) + timedelta(days=32)).replace(day=1).strftime('%Y-%m-01'), - r'今年|今年': current_time.strftime('%Y'), - r'去年|上一年': str(current_time.year - 1), - r'明年|下一年': str(current_time.year + 1), + r"今天|今日": current_time.strftime("%Y-%m-%d"), + r"昨天|昨日": (current_time - timedelta(days=1)).strftime("%Y-%m-%d"), + r"明天|明日": (current_time + timedelta(days=1)).strftime("%Y-%m-%d"), + r"后天": (current_time + timedelta(days=2)).strftime("%Y-%m-%d"), + r"大后天": (current_time + timedelta(days=3)).strftime("%Y-%m-%d"), + r"前天": (current_time - timedelta(days=2)).strftime("%Y-%m-%d"), + r"大前天": (current_time - timedelta(days=3)).strftime("%Y-%m-%d"), + r"本周|这周|这星期": current_time.strftime("%Y-%m-%d"), + r"上周|上星期": (current_time - timedelta(weeks=1)).strftime("%Y-%m-%d"), + r"下周|下星期": (current_time + timedelta(weeks=1)).strftime("%Y-%m-%d"), + r"本月|这个月": current_time.strftime("%Y-%m-01"), + r"上月|上个月": (current_time.replace(day=1) - timedelta(days=1)).strftime("%Y-%m-01"), + r"下月|下个月": (current_time.replace(day=1) + timedelta(days=32)).replace(day=1).strftime("%Y-%m-01"), + r"今年|今年": current_time.strftime("%Y"), + r"去年|上一年": str(current_time.year - 1), + r"明年|下一年": str(current_time.year + 1), } def _normalize_value(value): @@ -1009,10 +976,14 @@ class MemoryBuilder: # 更新平均置信度 if self.extraction_stats["successful_extractions"] > 0: - total_confidence = self.extraction_stats["average_confidence"] * (self.extraction_stats["successful_extractions"] - success_count) + total_confidence = self.extraction_stats["average_confidence"] * ( + self.extraction_stats["successful_extractions"] - success_count + ) # 假设新记忆的平均置信度为0.8 total_confidence += 0.8 * success_count - self.extraction_stats["average_confidence"] = total_confidence / self.extraction_stats["successful_extractions"] + self.extraction_stats["average_confidence"] = ( + total_confidence / self.extraction_stats["successful_extractions"] + ) def get_extraction_stats(self) -> Dict[str, Any]: """获取提取统计信息""" @@ -1024,5 +995,5 @@ class MemoryBuilder: "total_extractions": 0, "successful_extractions": 0, "failed_extractions": 0, - "average_confidence": 0.0 - } \ No newline at end of file + "average_confidence": 0.0, + } diff --git a/src/chat/memory_system/memory_chunk.py b/src/chat/memory_system/memory_chunk.py index 0ddf6bd0f..b5b609af6 100644 --- a/src/chat/memory_system/memory_chunk.py +++ b/src/chat/memory_system/memory_chunk.py @@ -8,8 +8,7 @@ import time import uuid import orjson from typing import Dict, List, Optional, Any, Union, Iterable -from dataclasses import dataclass, field, asdict -from datetime import datetime +from dataclasses import dataclass, field from enum import Enum import hashlib @@ -21,33 +20,36 @@ logger = get_logger(__name__) class MemoryType(Enum): """记忆类型分类""" - PERSONAL_FACT = "personal_fact" # 个人事实(姓名、职业、住址等) - EVENT = "event" # 事件(重要经历、约会等) - PREFERENCE = "preference" # 偏好(喜好、习惯等) - OPINION = "opinion" # 观点(对事物的看法) - RELATIONSHIP = "relationship" # 关系(与他人的关系) - EMOTION = "emotion" # 情感状态 - KNOWLEDGE = "knowledge" # 知识信息 - SKILL = "skill" # 技能能力 - GOAL = "goal" # 目标计划 - EXPERIENCE = "experience" # 经验教训 - CONTEXTUAL = "contextual" # 上下文信息 + + PERSONAL_FACT = "personal_fact" # 个人事实(姓名、职业、住址等) + EVENT = "event" # 事件(重要经历、约会等) + PREFERENCE = "preference" # 偏好(喜好、习惯等) + OPINION = "opinion" # 观点(对事物的看法) + RELATIONSHIP = "relationship" # 关系(与他人的关系) + EMOTION = "emotion" # 情感状态 + KNOWLEDGE = "knowledge" # 知识信息 + SKILL = "skill" # 技能能力 + GOAL = "goal" # 目标计划 + EXPERIENCE = "experience" # 经验教训 + CONTEXTUAL = "contextual" # 上下文信息 class ConfidenceLevel(Enum): """置信度等级""" - LOW = 1 # 低置信度,可能不准确 - MEDIUM = 2 # 中等置信度,有一定依据 - HIGH = 3 # 高置信度,有明确来源 - VERIFIED = 4 # 已验证,非常可靠 + + LOW = 1 # 低置信度,可能不准确 + MEDIUM = 2 # 中等置信度,有一定依据 + HIGH = 3 # 高置信度,有明确来源 + VERIFIED = 4 # 已验证,非常可靠 class ImportanceLevel(Enum): """重要性等级""" - LOW = 1 # 低重要性,普通信息 - NORMAL = 2 # 一般重要性,日常信息 - HIGH = 3 # 高重要性,重要信息 - CRITICAL = 4 # 关键重要性,核心信息 + + LOW = 1 # 低重要性,普通信息 + NORMAL = 2 # 一般重要性,日常信息 + HIGH = 3 # 高重要性,重要信息 + CRITICAL = 4 # 关键重要性,核心信息 @dataclass @@ -61,12 +63,7 @@ class ContentStructure: def to_dict(self) -> Dict[str, Any]: """转换为字典格式""" - return { - "subject": self.subject, - "predicate": self.predicate, - "object": self.object, - "display": self.display - } + return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display} @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure": @@ -75,7 +72,7 @@ class ContentStructure: subject=data.get("subject", ""), predicate=data.get("predicate", ""), object=data.get("object", ""), - display=data.get("display", "") + display=data.get("display", ""), ) def to_subject_list(self) -> List[str]: @@ -98,24 +95,25 @@ class ContentStructure: @dataclass class MemoryMetadata: """记忆元数据 - 简化版本""" + # 基础信息 - memory_id: str # 唯一标识符 - user_id: str # 用户ID - chat_id: Optional[str] = None # 聊天ID(群聊或私聊) + memory_id: str # 唯一标识符 + user_id: str # 用户ID + chat_id: Optional[str] = None # 聊天ID(群聊或私聊) # 时间信息 - created_at: float = 0.0 # 创建时间戳 - last_accessed: float = 0.0 # 最后访问时间 - last_modified: float = 0.0 # 最后修改时间 + created_at: float = 0.0 # 创建时间戳 + last_accessed: float = 0.0 # 最后访问时间 + last_modified: float = 0.0 # 最后修改时间 # 激活频率管理 - last_activation_time: float = 0.0 # 最后激活时间 - activation_frequency: int = 0 # 激活频率(单位时间内的激活次数) - total_activations: int = 0 # 总激活次数 + last_activation_time: float = 0.0 # 最后激活时间 + activation_frequency: int = 0 # 激活频率(单位时间内的激活次数) + total_activations: int = 0 # 总激活次数 # 统计信息 - access_count: int = 0 # 访问次数 - relevance_score: float = 0.0 # 相关度评分 + access_count: int = 0 # 访问次数 + relevance_score: float = 0.0 # 相关度评分 # 信心和重要性(核心字段) confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM @@ -123,10 +121,10 @@ class MemoryMetadata: # 遗忘机制相关 forgetting_threshold: float = 0.0 # 遗忘阈值(动态计算) - last_forgetting_check: float = 0.0 # 上次遗忘检查时间 + last_forgetting_check: float = 0.0 # 上次遗忘检查时间 # 来源信息 - source_context: Optional[str] = None # 来源上下文片段 + source_context: Optional[str] = None # 来源上下文片段 # 兼容旧字段: 一些代码或旧版本可能直接访问 metadata.source source: Optional[str] = None @@ -153,13 +151,13 @@ class MemoryMetadata: self.last_forgetting_check = current_time # 兼容性:如果旧字段 source 被使用,保证 source 与 source_context 同步 - if not getattr(self, 'source', None) and getattr(self, 'source_context', None): + if not getattr(self, "source", None) and getattr(self, "source_context", None): try: self.source = str(self.source_context) except Exception: self.source = None # 如果有 source 字段但 source_context 为空,也同步回去 - if not getattr(self, 'source_context', None) and getattr(self, 'source', None): + if not getattr(self, "source_context", None) and getattr(self, "source", None): try: self.source_context = str(self.source) except Exception: @@ -177,7 +175,6 @@ class MemoryMetadata: def _update_activation_frequency(self, current_time: float): """更新激活频率(24小时内的激活次数)""" - from datetime import datetime, timedelta # 如果超过24小时,重置激活频率 if current_time - self.last_activation_time > 86400: # 24小时 = 86400秒 @@ -251,7 +248,7 @@ class MemoryMetadata: "importance": self.importance.value, "forgetting_threshold": self.forgetting_threshold, "last_forgetting_check": self.last_forgetting_check, - "source_context": self.source_context + "source_context": self.source_context, } @classmethod @@ -273,7 +270,7 @@ class MemoryMetadata: importance=ImportanceLevel(data.get("importance", ImportanceLevel.NORMAL.value)), forgetting_threshold=data.get("forgetting_threshold", 0.0), last_forgetting_check=data.get("last_forgetting_check", 0), - source_context=data.get("source_context") + source_context=data.get("source_context"), ) @@ -285,21 +282,21 @@ class MemoryChunk: metadata: MemoryMetadata # 内容结构 - content: ContentStructure # 主谓宾结构 - memory_type: MemoryType # 记忆类型 + content: ContentStructure # 主谓宾结构 + memory_type: MemoryType # 记忆类型 # 扩展信息 - keywords: List[str] = field(default_factory=list) # 关键词列表 - tags: List[str] = field(default_factory=list) # 标签列表 - categories: List[str] = field(default_factory=list) # 分类列表 + keywords: List[str] = field(default_factory=list) # 关键词列表 + tags: List[str] = field(default_factory=list) # 标签列表 + categories: List[str] = field(default_factory=list) # 分类列表 # 语义信息 - embedding: Optional[List[float]] = None # 语义向量 - semantic_hash: Optional[str] = None # 语义哈希值 + embedding: Optional[List[float]] = None # 语义向量 + semantic_hash: Optional[str] = None # 语义哈希值 # 关联信息 related_memories: List[str] = field(default_factory=list) # 关联记忆ID列表 - temporal_context: Optional[Dict[str, Any]] = None # 时间上下文 + temporal_context: Optional[Dict[str, Any]] = None # 时间上下文 def __post_init__(self): """后初始化处理""" @@ -317,7 +314,7 @@ class MemoryChunk: embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding])) hash_input = f"{content_str}|{embedding_str}" - hash_object = hashlib.sha256(hash_input.encode('utf-8')) + hash_object = hashlib.sha256(hash_input.encode("utf-8")) self.semantic_hash = hash_object.hexdigest()[:16] except Exception as e: @@ -430,7 +427,7 @@ class MemoryChunk: "embedding": self.embedding, "semantic_hash": self.semantic_hash, "related_memories": self.related_memories, - "temporal_context": self.temporal_context + "temporal_context": self.temporal_context, } @classmethod @@ -449,14 +446,14 @@ class MemoryChunk: embedding=data.get("embedding"), semantic_hash=data.get("semantic_hash"), related_memories=data.get("related_memories", []), - temporal_context=data.get("temporal_context") + temporal_context=data.get("temporal_context"), ) return chunk def to_json(self) -> str: """转换为JSON字符串""" - return orjson.dumps(self.to_dict(), ensure_ascii=False).decode('utf-8') + return orjson.dumps(self.to_dict(), ensure_ascii=False).decode("utf-8") @classmethod def from_json(cls, json_str: str) -> "MemoryChunk": @@ -530,7 +527,7 @@ class MemoryChunk: MemoryType.SKILL: "🛠️", MemoryType.GOAL: "🎯", MemoryType.EXPERIENCE: "💡", - MemoryType.CONTEXTUAL: "📝" + MemoryType.CONTEXTUAL: "📝", } emoji = type_emoji.get(self.memory_type, "📝") @@ -581,7 +578,7 @@ def create_memory_chunk( importance: ImportanceLevel = ImportanceLevel.NORMAL, confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM, display: Optional[str] = None, - **kwargs + **kwargs, ) -> MemoryChunk: """便捷的内存块创建函数""" metadata = MemoryMetadata( @@ -593,7 +590,7 @@ def create_memory_chunk( last_modified=0, confidence=confidence, importance=importance, - source_context=source_context + source_context=source_context, ) subjects: List[str] @@ -607,18 +604,8 @@ def create_memory_chunk( display_text = display or _build_display_text(subjects, predicate, obj) - content = ContentStructure( - subject=subject_payload, - predicate=predicate, - object=obj, - display=display_text - ) + content = ContentStructure(subject=subject_payload, predicate=predicate, object=obj, display=display_text) - chunk = MemoryChunk( - metadata=metadata, - content=content, - memory_type=memory_type, - **kwargs - ) + chunk = MemoryChunk(metadata=metadata, content=content, memory_type=memory_type, **kwargs) - return chunk \ No newline at end of file + return chunk diff --git a/src/chat/memory_system/memory_forgetting_engine.py b/src/chat/memory_system/memory_forgetting_engine.py index a52580a9f..3e243e433 100644 --- a/src/chat/memory_system/memory_forgetting_engine.py +++ b/src/chat/memory_system/memory_forgetting_engine.py @@ -6,8 +6,8 @@ import time import asyncio -from typing import List, Dict, Optional, Set, Tuple -from datetime import datetime, timedelta +from typing import List, Dict, Optional, Tuple +from datetime import datetime from dataclasses import dataclass from src.common.logger import get_logger @@ -19,6 +19,7 @@ logger = get_logger(__name__) @dataclass class ForgettingStats: """遗忘统计信息""" + total_checked: int = 0 marked_for_forgetting: int = 0 actually_forgotten: int = 0 @@ -30,34 +31,35 @@ class ForgettingStats: @dataclass class ForgettingConfig: """遗忘引擎配置""" + # 检查频率配置 - check_interval_hours: int = 24 # 定期检查间隔(小时) - batch_size: int = 100 # 批处理大小 + check_interval_hours: int = 24 # 定期检查间隔(小时) + batch_size: int = 100 # 批处理大小 # 遗忘阈值配置 - base_forgetting_days: float = 30.0 # 基础遗忘天数 - min_forgetting_days: float = 7.0 # 最小遗忘天数 - max_forgetting_days: float = 365.0 # 最大遗忘天数 + base_forgetting_days: float = 30.0 # 基础遗忘天数 + min_forgetting_days: float = 7.0 # 最小遗忘天数 + max_forgetting_days: float = 365.0 # 最大遗忘天数 # 重要程度权重 critical_importance_bonus: float = 45.0 # 关键重要性额外天数 - high_importance_bonus: float = 30.0 # 高重要性额外天数 - normal_importance_bonus: float = 15.0 # 一般重要性额外天数 - low_importance_bonus: float = 0.0 # 低重要性额外天数 + high_importance_bonus: float = 30.0 # 高重要性额外天数 + normal_importance_bonus: float = 15.0 # 一般重要性额外天数 + low_importance_bonus: float = 0.0 # 低重要性额外天数 # 置信度权重 verified_confidence_bonus: float = 30.0 # 已验证置信度额外天数 - high_confidence_bonus: float = 20.0 # 高置信度额外天数 - medium_confidence_bonus: float = 10.0 # 中等置信度额外天数 - low_confidence_bonus: float = 0.0 # 低置信度额外天数 + high_confidence_bonus: float = 20.0 # 高置信度额外天数 + medium_confidence_bonus: float = 10.0 # 中等置信度额外天数 + low_confidence_bonus: float = 0.0 # 低置信度额外天数 # 激活频率权重 activation_frequency_weight: float = 0.5 # 每次激活增加的天数权重 - max_frequency_bonus: float = 10.0 # 最大激活频率奖励天数 + max_frequency_bonus: float = 10.0 # 最大激活频率奖励天数 # 休眠配置 - dormant_threshold_days: int = 90 # 休眠状态判定天数 - force_forget_dormant_days: int = 180 # 强制遗忘休眠记忆的天数 + dormant_threshold_days: int = 90 # 休眠状态判定天数 + force_forget_dormant_days: int = 180 # 强制遗忘休眠记忆的天数 class MemoryForgettingEngine: @@ -107,13 +109,12 @@ class MemoryForgettingEngine: # 激活频率权重 frequency_bonus = min( memory.metadata.activation_frequency * self.config.activation_frequency_weight, - self.config.max_frequency_bonus + self.config.max_frequency_bonus, ) threshold += frequency_bonus # 确保在合理范围内 - return max(self.config.min_forgetting_days, - min(threshold, self.config.max_forgetting_days)) + return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days)) def should_forget_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool: """ @@ -265,8 +266,8 @@ class MemoryForgettingEngine: "actually_forgotten": self.stats.actually_forgotten, "dormant_memories": self.stats.dormant_memories, "check_duration": self.stats.check_duration, - "last_check_time": self.stats.last_check_time - } + "last_check_time": self.stats.last_check_time, + }, } def is_forgetting_check_needed(self) -> bool: @@ -302,7 +303,9 @@ class MemoryForgettingEngine: # 如果启用自动清理,执行实际的遗忘操作 if enable_auto_cleanup and (result["normal_forgetting"] or result["force_forgetting"]): - logger.info(f"检测到 {len(result['normal_forgetting'])} 条普通遗忘和 {len(result['force_forgetting'])} 条强制遗忘记忆") + logger.info( + f"检测到 {len(result['normal_forgetting'])} 条普通遗忘和 {len(result['force_forgetting'])} 条强制遗忘记忆" + ) # 这里可以调用实际的删除逻辑 # await self.cleanup_forgotten_memories(result["normal_forgetting"] + result["force_forgetting"]) @@ -318,14 +321,16 @@ class MemoryForgettingEngine: "marked_for_forgetting": self.stats.marked_for_forgetting, "actually_forgotten": self.stats.actually_forgotten, "dormant_memories": self.stats.dormant_memories, - "last_check_time": datetime.fromtimestamp(self.stats.last_check_time).isoformat() if self.stats.last_check_time else None, + "last_check_time": datetime.fromtimestamp(self.stats.last_check_time).isoformat() + if self.stats.last_check_time + else None, "last_check_duration": self.stats.check_duration, "config": { "check_interval_hours": self.config.check_interval_hours, "base_forgetting_days": self.config.base_forgetting_days, "min_forgetting_days": self.config.min_forgetting_days, - "max_forgetting_days": self.config.max_forgetting_days - } + "max_forgetting_days": self.config.max_forgetting_days, + }, } def reset_stats(self): @@ -349,4 +354,4 @@ memory_forgetting_engine = MemoryForgettingEngine() def get_memory_forgetting_engine() -> MemoryForgettingEngine: """获取全局遗忘引擎实例""" - return memory_forgetting_engine \ No newline at end of file + return memory_forgetting_engine diff --git a/src/chat/memory_system/memory_fusion.py b/src/chat/memory_system/memory_fusion.py index 54c77d5bc..3ecc4cd71 100644 --- a/src/chat/memory_system/memory_fusion.py +++ b/src/chat/memory_system/memory_fusion.py @@ -5,14 +5,12 @@ """ import time -from typing import Dict, List, Optional, Tuple, Set, Any +from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import ( - MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel -) +from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel logger = get_logger(__name__) @@ -20,6 +18,7 @@ logger = get_logger(__name__) @dataclass class FusionResult: """融合结果""" + original_count: int fused_count: int removed_duplicates: int @@ -31,6 +30,7 @@ class FusionResult: @dataclass class DuplicateGroup: """重复记忆组""" + group_id: str memories: List[MemoryChunk] similarity_matrix: List[List[float]] @@ -46,22 +46,20 @@ class MemoryFusionEngine: "total_fusions": 0, "memories_fused": 0, "duplicates_removed": 0, - "average_similarity": 0.0 + "average_similarity": 0.0, } # 融合策略配置 self.fusion_strategies = { - "semantic_similarity": True, # 语义相似性融合 - "temporal_proximity": True, # 时间接近性融合 - "logical_consistency": True, # 逻辑一致性融合 - "confidence_boosting": True, # 置信度提升 - "importance_preservation": True # 重要性保持 + "semantic_similarity": True, # 语义相似性融合 + "temporal_proximity": True, # 时间接近性融合 + "logical_consistency": True, # 逻辑一致性融合 + "confidence_boosting": True, # 置信度提升 + "importance_preservation": True, # 重要性保持 } async def fuse_memories( - self, - new_memories: List[MemoryChunk], - existing_memories: Optional[List[MemoryChunk]] = None + self, new_memories: List[MemoryChunk], existing_memories: Optional[List[MemoryChunk]] = None ) -> List[MemoryChunk]: """融合记忆列表""" start_time = time.time() @@ -73,9 +71,7 @@ class MemoryFusionEngine: logger.info(f"开始记忆融合,新记忆: {len(new_memories)},现有记忆: {len(existing_memories or [])}") # 1. 检测重复记忆组 - duplicate_groups = await self._detect_duplicate_groups( - new_memories, existing_memories or [] - ) + duplicate_groups = await self._detect_duplicate_groups(new_memories, existing_memories or []) if not duplicate_groups: fusion_time = time.time() - start_time @@ -110,9 +106,7 @@ class MemoryFusionEngine: return new_memories # 失败时返回原始记忆 async def _detect_duplicate_groups( - self, - new_memories: List[MemoryChunk], - existing_memories: List[MemoryChunk] + self, new_memories: List[MemoryChunk], existing_memories: List[MemoryChunk] ) -> List[DuplicateGroup]: """检测重复记忆组""" all_memories = new_memories + existing_memories @@ -125,16 +119,12 @@ class MemoryFusionEngine: continue # 创建新的重复组 - group = DuplicateGroup( - group_id=f"group_{len(groups)}", - memories=[memory1], - similarity_matrix=[[1.0]] - ) + group = DuplicateGroup(group_id=f"group_{len(groups)}", memories=[memory1], similarity_matrix=[[1.0]]) processed_ids.add(memory1.memory_id) # 寻找相似记忆 - for j, memory2 in enumerate(all_memories[i+1:], i+1): + for j, memory2 in enumerate(all_memories[i + 1 :], i + 1): if memory2.memory_id in processed_ids: continue @@ -182,9 +172,7 @@ class MemoryFusionEngine: # 5. 时间接近性 if self.fusion_strategies["temporal_proximity"]: - temporal_sim = self._calculate_temporal_similarity( - mem1.metadata.created_at, mem2.metadata.created_at - ) + temporal_sim = self._calculate_temporal_similarity(mem1.metadata.created_at, mem2.metadata.created_at) similarity_scores.append(("temporal", temporal_sim)) # 6. 逻辑一致性 @@ -193,14 +181,7 @@ class MemoryFusionEngine: similarity_scores.append(("logical", logical_sim)) # 计算加权平均相似度 - weights = { - "semantic": 0.35, - "text": 0.25, - "keyword": 0.15, - "type": 0.10, - "temporal": 0.10, - "logical": 0.05 - } + weights = {"semantic": 0.35, "text": 0.25, "keyword": 0.15, "type": 0.10, "temporal": 0.10, "logical": 0.05} weighted_sum = 0.0 total_weight = 0.0 @@ -276,9 +257,7 @@ class MemoryFusionEngine: # 宾语相似性 if isinstance(mem1.content.object, str) and isinstance(mem2.content.object, str): - object_sim = self._calculate_text_similarity( - str(mem1.content.object), str(mem2.content.object) - ) + object_sim = self._calculate_text_similarity(str(mem1.content.object), str(mem2.content.object)) consistency_score += object_sim * 0.3 return consistency_score @@ -349,11 +328,7 @@ class MemoryFusionEngine: # 返回置信度最高的记忆 return max(group.memories, key=lambda m: m.metadata.confidence.value) - async def _merge_memory_attributes( - self, - base_memory: MemoryChunk, - memories: List[MemoryChunk] - ) -> MemoryChunk: + async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: List[MemoryChunk]) -> MemoryChunk: """合并记忆属性""" # 创建基础记忆的深拷贝 fused_memory = MemoryChunk.from_dict(base_memory.to_dict()) @@ -436,7 +411,7 @@ class MemoryFusionEngine: "earliest_timestamp": earliest_time, "latest_timestamp": latest_time, "time_span_hours": (latest_time - earliest_time) / 3600, - "source_memories": len(memories) + "source_memories": len(memories), } # 合并其他上下文信息 @@ -451,9 +426,7 @@ class MemoryFusionEngine: return merged_context async def incremental_fusion( - self, - new_memory: MemoryChunk, - existing_memories: List[MemoryChunk] + self, new_memory: MemoryChunk, existing_memories: List[MemoryChunk] ) -> Tuple[MemoryChunk, List[MemoryChunk]]: """增量融合(单个新记忆与现有记忆融合)""" # 寻找相似记忆 @@ -478,7 +451,7 @@ class MemoryFusionEngine: group = DuplicateGroup( group_id=f"incremental_{int(time.time())}", memories=[new_memory, best_match], - similarity_matrix=[[1.0, similarity], [similarity, 1.0]] + similarity_matrix=[[1.0, similarity], [similarity, 1.0]], ) # 执行融合 @@ -530,5 +503,5 @@ class MemoryFusionEngine: "total_fusions": 0, "memories_fused": 0, "duplicates_removed": 0, - "average_similarity": 0.0 - } \ No newline at end of file + "average_similarity": 0.0, + } diff --git a/src/chat/memory_system/memory_manager.py b/src/chat/memory_system/memory_manager.py index 80b9f7dcf..4c6b2696e 100644 --- a/src/chat/memory_system/memory_manager.py +++ b/src/chat/memory_system/memory_manager.py @@ -9,12 +9,9 @@ from typing import Dict, List, Optional, Any, Tuple from dataclasses import dataclass from src.common.logger import get_logger -from src.config.config import global_config from src.chat.memory_system.memory_system import MemorySystem from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType -from src.chat.memory_system.memory_system import ( - initialize_memory_system -) +from src.chat.memory_system.memory_system import initialize_memory_system logger = get_logger(__name__) @@ -22,6 +19,7 @@ logger = get_logger(__name__) @dataclass class MemoryResult: """记忆查询结果""" + content: str memory_type: str confidence: float @@ -67,6 +65,7 @@ class MemoryManager: # 获取LLM模型 from src.llm_models.utils_model import LLMRequest from src.config.config import model_config + llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory") # 初始化记忆系统 @@ -121,7 +120,7 @@ class MemoryManager: max_memory_num: int = 3, max_memory_length: int = 2, time_weight: float = 1.0, - keyword_weight: float = 1.0 + keyword_weight: float = 1.0, ) -> List[Tuple[str, str]]: """从文本获取相关记忆 - 兼容原有接口""" if not self.is_initialized or not self.memory_system: @@ -131,14 +130,11 @@ class MemoryManager: # 使用增强记忆系统检索 context = { "chat_id": chat_id, - "expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE] + "expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE], } relevant_memories = await self.memory_system.retrieve_relevant_memories( - query=text, - user_id=user_id, - context=context, - limit=max_memory_num + query=text, user_id=user_id, context=context, limit=max_memory_num ) # 转换为原有格式 (topic, content) @@ -156,11 +152,7 @@ class MemoryManager: return [] async def get_memory_from_topic( - self, - valid_keywords: List[str], - max_memory_num: int = 3, - max_memory_length: int = 2, - max_depth: int = 3 + self, valid_keywords: List[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 ) -> List[Tuple[str, str]]: """从关键词获取记忆 - 兼容原有接口""" if not self.is_initialized or not self.memory_system: @@ -177,15 +169,15 @@ class MemoryManager: MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE, - MemoryType.OPINION - ] + MemoryType.OPINION, + ], } relevant_memories = await self.memory_system.retrieve_relevant_memories( query_text=query_text, user_id="default_user", # 可以根据实际需要传递 context=context, - limit=max_memory_num + limit=max_memory_num, ) # 转换为原有格式 (topic, content) @@ -216,11 +208,7 @@ class MemoryManager: return [] async def process_conversation( - self, - conversation_text: str, - context: Dict[str, Any], - user_id: str, - timestamp: Optional[float] = None + self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: Optional[float] = None ) -> List[MemoryChunk]: """处理对话并构建记忆 - 新增功能""" if not self.is_initialized or not self.memory_system: @@ -247,11 +235,7 @@ class MemoryManager: return [] async def get_enhanced_memory_context( - self, - query_text: str, - user_id: str, - context: Optional[Dict[str, Any]] = None, - limit: int = 5 + self, query_text: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 5 ) -> List[MemoryResult]: """获取增强记忆上下文 - 新增功能""" if not self.is_initialized or not self.memory_system: @@ -259,10 +243,7 @@ class MemoryManager: try: relevant_memories = await self.memory_system.retrieve_relevant_memories( - query=query_text, - user_id=None, - context=context or {}, - limit=limit + query=query_text, user_id=None, context=context or {}, limit=limit ) results = [] @@ -276,7 +257,7 @@ class MemoryManager: timestamp=memory.metadata.created_at, source="enhanced_memory", relevance_score=memory.metadata.relevance_score, - structure=structure + structure=structure, ) results.append(result) @@ -342,7 +323,9 @@ class MemoryManager: return None return f"{subject}的职业是{profession}" if predicate == "lives_in": - location = self._extract_from_object(obj_value, ["location", "city", "place"]) or self._format_object(obj_value) + location = self._extract_from_object(obj_value, ["location", "city", "place"]) or self._format_object( + obj_value + ) location = self._clean_text(location) if not location: return None @@ -385,7 +368,9 @@ class MemoryManager: return None return f"{subject}最喜欢{favorite}" if predicate == "mentioned_event": - event_text = self._extract_from_object(obj_value, ["event_text", "description"]) or self._format_object(obj_value) + event_text = self._extract_from_object(obj_value, ["event_text", "description"]) or self._format_object( + obj_value + ) event_text = self._clean_text(self._truncate(event_text)) if not event_text: return None @@ -494,4 +479,4 @@ class MemoryManager: # 全局记忆管理器实例 -memory_manager = MemoryManager() \ No newline at end of file +memory_manager = MemoryManager() diff --git a/src/chat/memory_system/memory_metadata_index.py b/src/chat/memory_system/memory_metadata_index.py index 32104ffab..ad27971a6 100644 --- a/src/chat/memory_system/memory_metadata_index.py +++ b/src/chat/memory_system/memory_metadata_index.py @@ -12,7 +12,6 @@ from dataclasses import dataclass, asdict from datetime import datetime from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import MemoryType, ImportanceLevel, ConfidenceLevel logger = get_logger(__name__) @@ -20,22 +19,23 @@ logger = get_logger(__name__) @dataclass class MemoryMetadataIndexEntry: """元数据索引条目(轻量级,只用于快速过滤)""" + memory_id: str user_id: str - + # 分类信息 memory_type: str # MemoryType.value subjects: List[str] # 主语列表 objects: List[str] # 宾语列表 keywords: List[str] # 关键词列表 tags: List[str] # 标签列表 - + # 数值字段(用于范围过滤) importance: int # ImportanceLevel.value (1-4) confidence: int # ConfidenceLevel.value (1-4) created_at: float # 创建时间戳 access_count: int # 访问次数 - + # 可选字段 chat_id: Optional[str] = None content_preview: Optional[str] = None # 内容预览(前100字符) @@ -43,152 +43,152 @@ class MemoryMetadataIndexEntry: class MemoryMetadataIndex: """记忆元数据索引管理器""" - + def __init__(self, index_file: str = "data/memory_metadata_index.json"): self.index_file = Path(index_file) self.index: Dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry - + # 倒排索引(用于快速查找) self.type_index: Dict[str, Set[str]] = {} # type -> {memory_ids} self.subject_index: Dict[str, Set[str]] = {} # subject -> {memory_ids} self.keyword_index: Dict[str, Set[str]] = {} # keyword -> {memory_ids} self.tag_index: Dict[str, Set[str]] = {} # tag -> {memory_ids} - + self.lock = threading.RLock() - + # 加载已有索引 self._load_index() - + def _load_index(self): """从文件加载索引""" if not self.index_file.exists(): logger.info(f"元数据索引文件不存在,将创建新索引: {self.index_file}") return - + try: - with open(self.index_file, 'rb') as f: + with open(self.index_file, "rb") as f: data = orjson.loads(f.read()) - + # 重建内存索引 - for entry_data in data.get('entries', []): + for entry_data in data.get("entries", []): entry = MemoryMetadataIndexEntry(**entry_data) self.index[entry.memory_id] = entry self._update_inverted_indices(entry) - + logger.info(f"✅ 加载元数据索引: {len(self.index)} 条记录") - + except Exception as e: logger.error(f"加载元数据索引失败: {e}", exc_info=True) - + def _save_index(self): """保存索引到文件""" try: # 确保目录存在 self.index_file.parent.mkdir(parents=True, exist_ok=True) - + # 序列化所有条目 entries = [asdict(entry) for entry in self.index.values()] data = { - 'version': '1.0', - 'count': len(entries), - 'last_updated': datetime.now().isoformat(), - 'entries': entries + "version": "1.0", + "count": len(entries), + "last_updated": datetime.now().isoformat(), + "entries": entries, } - + # 写入文件(使用临时文件 + 原子重命名) - temp_file = self.index_file.with_suffix('.tmp') - with open(temp_file, 'wb') as f: + temp_file = self.index_file.with_suffix(".tmp") + with open(temp_file, "wb") as f: f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2)) - + temp_file.replace(self.index_file) logger.debug(f"元数据索引已保存: {len(entries)} 条记录") - + except Exception as e: logger.error(f"保存元数据索引失败: {e}", exc_info=True) - + def _update_inverted_indices(self, entry: MemoryMetadataIndexEntry): """更新倒排索引""" memory_id = entry.memory_id - + # 类型索引 self.type_index.setdefault(entry.memory_type, set()).add(memory_id) - + # 主语索引 for subject in entry.subjects: subject_norm = subject.strip().lower() if subject_norm: self.subject_index.setdefault(subject_norm, set()).add(memory_id) - + # 关键词索引 for keyword in entry.keywords: keyword_norm = keyword.strip().lower() if keyword_norm: self.keyword_index.setdefault(keyword_norm, set()).add(memory_id) - + # 标签索引 for tag in entry.tags: tag_norm = tag.strip().lower() if tag_norm: self.tag_index.setdefault(tag_norm, set()).add(memory_id) - + def add_or_update(self, entry: MemoryMetadataIndexEntry): """添加或更新索引条目""" with self.lock: # 如果已存在,先从倒排索引中移除旧记录 if entry.memory_id in self.index: self._remove_from_inverted_indices(entry.memory_id) - + # 添加新记录 self.index[entry.memory_id] = entry self._update_inverted_indices(entry) - + def _remove_from_inverted_indices(self, memory_id: str): """从倒排索引中移除记录""" if memory_id not in self.index: return - + entry = self.index[memory_id] - + # 从类型索引移除 if entry.memory_type in self.type_index: self.type_index[entry.memory_type].discard(memory_id) - + # 从主语索引移除 for subject in entry.subjects: subject_norm = subject.strip().lower() if subject_norm in self.subject_index: self.subject_index[subject_norm].discard(memory_id) - + # 从关键词索引移除 for keyword in entry.keywords: keyword_norm = keyword.strip().lower() if keyword_norm in self.keyword_index: self.keyword_index[keyword_norm].discard(memory_id) - + # 从标签索引移除 for tag in entry.tags: tag_norm = tag.strip().lower() if tag_norm in self.tag_index: self.tag_index[tag_norm].discard(memory_id) - + def remove(self, memory_id: str): """移除索引条目""" with self.lock: if memory_id in self.index: self._remove_from_inverted_indices(memory_id) del self.index[memory_id] - + def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]): """批量添加或更新""" with self.lock: for entry in entries: self.add_or_update(entry) - + def save(self): """保存索引到磁盘""" with self.lock: self._save_index() - + def search( self, memory_types: Optional[List[str]] = None, @@ -201,11 +201,11 @@ class MemoryMetadataIndex: created_before: Optional[float] = None, user_id: Optional[str] = None, limit: Optional[int] = None, - flexible_mode: bool = True # 新增:灵活匹配模式 + flexible_mode: bool = True, # 新增:灵活匹配模式 ) -> List[str]: """ 搜索符合条件的记忆ID列表(支持模糊匹配) - + Returns: List[str]: 符合条件的 memory_id 列表 """ @@ -219,7 +219,7 @@ class MemoryMetadataIndex: created_after=created_after, created_before=created_before, user_id=user_id, - limit=limit + limit=limit, ) else: return self._search_strict( @@ -232,7 +232,7 @@ class MemoryMetadataIndex: created_after=created_after, created_before=created_before, user_id=user_id, - limit=limit + limit=limit, ) def _search_flexible( @@ -243,7 +243,7 @@ class MemoryMetadataIndex: created_before: Optional[float] = None, user_id: Optional[str] = None, limit: Optional[int] = None, - **kwargs # 接受但不使用的参数 + **kwargs, # 接受但不使用的参数 ) -> List[str]: """ 灵活搜索模式:2/4项匹配即可,支持部分匹配 @@ -258,10 +258,7 @@ class MemoryMetadataIndex: """ # 用户过滤(必选) if user_id: - base_candidates = { - mid for mid, entry in self.index.items() - if entry.user_id == user_id - } + base_candidates = {mid for mid, entry in self.index.items() if entry.user_id == user_id} else: base_candidates = set(self.index.keys()) @@ -386,7 +383,7 @@ class MemoryMetadataIndex: created_after: Optional[float] = None, created_before: Optional[float] = None, user_id: Optional[str] = None, - limit: Optional[int] = None + limit: Optional[int] = None, ) -> List[str]: """严格搜索模式(原有逻辑)""" # 初始候选集(所有记忆) @@ -394,10 +391,7 @@ class MemoryMetadataIndex: # 用户过滤(必选) if user_id: - candidate_ids = { - mid for mid, entry in self.index.items() - if entry.user_id == user_id - } + candidate_ids = {mid for mid, entry in self.index.items() if entry.user_id == user_id} else: candidate_ids = set(self.index.keys()) @@ -447,7 +441,8 @@ class MemoryMetadataIndex: # 重要性过滤 if importance_min is not None or importance_max is not None: importance_ids = { - mid for mid in candidate_ids + mid + for mid in candidate_ids if (importance_min is None or self.index[mid].importance >= importance_min) and (importance_max is None or self.index[mid].importance <= importance_max) } @@ -456,41 +451,37 @@ class MemoryMetadataIndex: # 时间范围过滤 if created_after is not None or created_before is not None: time_ids = { - mid for mid in candidate_ids + mid + for mid in candidate_ids if (created_after is None or self.index[mid].created_at >= created_after) and (created_before is None or self.index[mid].created_at <= created_before) } candidate_ids &= time_ids # 转换为列表并排序(按创建时间倒序) - result_ids = sorted( - candidate_ids, - key=lambda mid: self.index[mid].created_at, - reverse=True - ) + result_ids = sorted(candidate_ids, key=lambda mid: self.index[mid].created_at, reverse=True) # 限制数量 if limit: result_ids = result_ids[:limit] logger.debug( - f"[严格搜索] types={memory_types}, subjects={subjects}, " - f"keywords={keywords}, 返回={len(result_ids)}条" + f"[严格搜索] types={memory_types}, subjects={subjects}, keywords={keywords}, 返回={len(result_ids)}条" ) return result_ids - + def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]: """获取单个索引条目""" return self.index.get(memory_id) - + def get_stats(self) -> Dict[str, Any]: """获取索引统计信息""" with self.lock: return { - 'total_memories': len(self.index), - 'types': {mtype: len(ids) for mtype, ids in self.type_index.items()}, - 'subjects_count': len(self.subject_index), - 'keywords_count': len(self.keyword_index), - 'tags_count': len(self.tag_index), + "total_memories": len(self.index), + "types": {mtype: len(ids) for mtype, ids in self.type_index.items()}, + "subjects_count": len(self.subject_index), + "keywords_count": len(self.keyword_index), + "tags_count": len(self.tag_index), } diff --git a/src/chat/memory_system/memory_query_planner.py b/src/chat/memory_system/memory_query_planner.py index 690e26627..a8be9d951 100644 --- a/src/chat/memory_system/memory_query_planner.py +++ b/src/chat/memory_system/memory_query_planner.py @@ -80,10 +80,7 @@ class MemoryQueryPlanner: return self._default_plan(query_text) def _default_plan(self, query_text: str) -> MemoryQueryPlan: - return MemoryQueryPlan( - semantic_query=query_text, - limit=self.default_limit - ) + return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit) def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan: semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query @@ -122,7 +119,7 @@ class MemoryQueryPlanner: recency_preference=self._safe_str(data.get("recency")) or "any", limit=self._safe_int(data.get("limit"), self.default_limit), emphasis=self._safe_str(data.get("emphasis")) or "balanced", - raw_plan=data + raw_plan=data, ) return plan @@ -154,18 +151,18 @@ class MemoryQueryPlanner: context_section = f""" -## 📋 未读消息上下文 (共{unread_context.get('total_count', 0)}条未读消息) +## 📋 未读消息上下文 (共{unread_context.get("total_count", 0)}条未读消息) ### 最近消息预览: {chr(10).join(message_previews)} ### 上下文关键词: -{', '.join(unread_keywords[:15]) if unread_keywords else '无'} +{", ".join(unread_keywords[:15]) if unread_keywords else "无"} ### 对话参与者: -{', '.join(unread_participants) if unread_participants else '无'} +{", ".join(unread_participants) if unread_participants else "无"} ### 上下文摘要: -{context_summary[:300] if context_summary else '无'} +{context_summary[:300] if context_summary else "无"} """ else: context_section = """ @@ -223,7 +220,7 @@ class MemoryQueryPlanner: start = stripped.find("{") end = stripped.rfind("}") if start != -1 and end != -1 and end > start: - return stripped[start:end + 1] + return stripped[start : end + 1] return stripped if stripped.startswith("{") and stripped.endswith("}") else None @@ -243,4 +240,4 @@ class MemoryQueryPlanner: return default return number except (TypeError, ValueError): - return default \ No newline at end of file + return default diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index 86134e9fc..4a275babd 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -35,6 +35,7 @@ GLOBAL_MEMORY_SCOPE = "global" class MemorySystemStatus(Enum): """记忆系统状态""" + INITIALIZING = "initializing" READY = "ready" BUILDING = "building" @@ -45,6 +46,7 @@ class MemorySystemStatus(Enum): @dataclass class MemorySystemConfig: """记忆系统配置""" + # 记忆构建配置 min_memory_length: int = 10 max_memory_length: int = 500 @@ -97,11 +99,9 @@ class MemorySystemConfig: max_memory_length=global_config.memory.max_memory_length, memory_value_threshold=global_config.memory.memory_value_threshold, min_build_interval_seconds=getattr(global_config.memory, "memory_build_interval", 300.0), - # 向量存储配置 vector_dimension=int(embedding_dimension), similarity_threshold=global_config.memory.vector_similarity_threshold, - # 召回配置 coarse_recall_limit=global_config.memory.metadata_filter_limit, fine_recall_limit=global_config.memory.vector_search_limit, @@ -112,21 +112,16 @@ class MemorySystemConfig: semantic_weight=global_config.memory.semantic_weight, context_weight=global_config.memory.context_weight, recency_weight=global_config.memory.recency_weight, - # 融合配置 fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold, - deduplication_window=timedelta(hours=global_config.memory.deduplication_window_hours) + deduplication_window=timedelta(hours=global_config.memory.deduplication_window_hours), ) class MemorySystem: """精准记忆系统核心类""" - def __init__( - self, - llm_model: Optional[LLMRequest] = None, - config: Optional[MemorySystemConfig] = None - ): + def __init__(self, llm_model: Optional[LLMRequest] = None, config: Optional[MemorySystemConfig] = None): self.config = config or MemorySystemConfig.from_global_config() self.llm_model = llm_model self.status = MemorySystemStatus.INITIALIZING @@ -175,16 +170,16 @@ class MemorySystem: extraction_task_config = value_task_config or fallback_task if value_task_config is None or extraction_task_config is None: - raise RuntimeError("无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。") + raise RuntimeError( + "无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。" + ) self.value_assessment_model = LLMRequest( - model_set=value_task_config, - request_type="memory.value_assessment" + model_set=value_task_config, request_type="memory.value_assessment" ) self.memory_extraction_model = LLMRequest( - model_set=extraction_task_config, - request_type="memory.extraction" + model_set=extraction_task_config, request_type="memory.extraction" ) # 初始化核心组件(简化版) @@ -198,13 +193,13 @@ class MemorySystem: memory_collection="unified_memory_v2", metadata_collection="memory_metadata_v2", similarity_threshold=self.config.similarity_threshold, - search_limit=getattr(global_config.memory, 'unified_storage_search_limit', 20), - batch_size=getattr(global_config.memory, 'unified_storage_batch_size', 100), - enable_caching=getattr(global_config.memory, 'unified_storage_enable_caching', True), - cache_size_limit=getattr(global_config.memory, 'unified_storage_cache_limit', 1000), - auto_cleanup_interval=getattr(global_config.memory, 'unified_storage_auto_cleanup_interval', 3600), - enable_forgetting=getattr(global_config.memory, 'enable_memory_forgetting', True), - retention_hours=getattr(global_config.memory, 'memory_retention_hours', 720) # 30天 + search_limit=getattr(global_config.memory, "unified_storage_search_limit", 20), + batch_size=getattr(global_config.memory, "unified_storage_batch_size", 100), + enable_caching=getattr(global_config.memory, "unified_storage_enable_caching", True), + cache_size_limit=getattr(global_config.memory, "unified_storage_cache_limit", 1000), + auto_cleanup_interval=getattr(global_config.memory, "unified_storage_auto_cleanup_interval", 3600), + enable_forgetting=getattr(global_config.memory, "enable_memory_forgetting", True), + retention_hours=getattr(global_config.memory, "memory_retention_hours", 720), # 30天 ) try: @@ -220,32 +215,27 @@ class MemorySystem: # 从全局配置创建遗忘引擎配置 forgetting_config = ForgettingConfig( # 检查频率配置 - check_interval_hours=getattr(global_config.memory, 'forgetting_check_interval_hours', 24), + check_interval_hours=getattr(global_config.memory, "forgetting_check_interval_hours", 24), batch_size=100, # 固定值,暂不配置 - # 遗忘阈值配置 - base_forgetting_days=getattr(global_config.memory, 'base_forgetting_days', 30.0), - min_forgetting_days=getattr(global_config.memory, 'min_forgetting_days', 7.0), - max_forgetting_days=getattr(global_config.memory, 'max_forgetting_days', 365.0), - + base_forgetting_days=getattr(global_config.memory, "base_forgetting_days", 30.0), + min_forgetting_days=getattr(global_config.memory, "min_forgetting_days", 7.0), + max_forgetting_days=getattr(global_config.memory, "max_forgetting_days", 365.0), # 重要程度权重 - critical_importance_bonus=getattr(global_config.memory, 'critical_importance_bonus', 45.0), - high_importance_bonus=getattr(global_config.memory, 'high_importance_bonus', 30.0), - normal_importance_bonus=getattr(global_config.memory, 'normal_importance_bonus', 15.0), - low_importance_bonus=getattr(global_config.memory, 'low_importance_bonus', 0.0), - + critical_importance_bonus=getattr(global_config.memory, "critical_importance_bonus", 45.0), + high_importance_bonus=getattr(global_config.memory, "high_importance_bonus", 30.0), + normal_importance_bonus=getattr(global_config.memory, "normal_importance_bonus", 15.0), + low_importance_bonus=getattr(global_config.memory, "low_importance_bonus", 0.0), # 置信度权重 - verified_confidence_bonus=getattr(global_config.memory, 'verified_confidence_bonus', 30.0), - high_confidence_bonus=getattr(global_config.memory, 'high_confidence_bonus', 20.0), - medium_confidence_bonus=getattr(global_config.memory, 'medium_confidence_bonus', 10.0), - low_confidence_bonus=getattr(global_config.memory, 'low_confidence_bonus', 0.0), - + verified_confidence_bonus=getattr(global_config.memory, "verified_confidence_bonus", 30.0), + high_confidence_bonus=getattr(global_config.memory, "high_confidence_bonus", 20.0), + medium_confidence_bonus=getattr(global_config.memory, "medium_confidence_bonus", 10.0), + low_confidence_bonus=getattr(global_config.memory, "low_confidence_bonus", 0.0), # 激活频率权重 - activation_frequency_weight=getattr(global_config.memory, 'activation_frequency_weight', 0.5), - max_frequency_bonus=getattr(global_config.memory, 'max_frequency_bonus', 10.0), - + activation_frequency_weight=getattr(global_config.memory, "activation_frequency_weight", 0.5), + max_frequency_bonus=getattr(global_config.memory, "max_frequency_bonus", 10.0), # 休眠配置 - dormant_threshold_days=getattr(global_config.memory, 'dormant_threshold_days', 90) + dormant_threshold_days=getattr(global_config.memory, "dormant_threshold_days", 90), ) self.forgetting_engine = MemoryForgettingEngine(forgetting_config) @@ -253,17 +243,11 @@ class MemorySystem: planner_task_config = getattr(model_config.model_task_config, "utils_small", None) planner_model: Optional[LLMRequest] = None try: - planner_model = LLMRequest( - model_set=planner_task_config, - request_type="memory.query_planner" - ) + planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner") except Exception as planner_exc: logger.warning("查询规划模型初始化失败,将使用默认规划策略: %s", planner_exc, exc_info=True) - self.query_planner = MemoryQueryPlanner( - planner_model, - default_limit=self.config.final_recall_limit - ) + self.query_planner = MemoryQueryPlanner(planner_model, default_limit=self.config.final_recall_limit) # 统一存储已经自动加载数据,无需额外加载 logger.info("✅ 简化版记忆系统初始化完成") @@ -277,11 +261,7 @@ class MemorySystem: raise async def retrieve_memories_for_building( - self, - query_text: str, - user_id: Optional[str] = None, - context: Optional[Dict[str, Any]] = None, - limit: int = 5 + self, query_text: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5 ) -> List[MemoryChunk]: """在构建记忆时检索相关记忆,使用统一存储系统 @@ -304,9 +284,7 @@ class MemorySystem: try: # 使用统一存储检索相似记忆 search_results = await self.unified_storage.search_similar_memories( - query_text=query_text, - limit=limit, - scope_id=user_id + query_text=query_text, limit=limit, scope_id=user_id ) # 转换为记忆对象 @@ -324,10 +302,7 @@ class MemorySystem: return [] async def build_memory_from_conversation( - self, - conversation_text: str, - context: Dict[str, Any], - timestamp: Optional[float] = None + self, conversation_text: str, context: Dict[str, Any], timestamp: Optional[float] = None ) -> List[MemoryChunk]: """从对话中构建记忆 @@ -383,7 +358,7 @@ class MemorySystem: conversation_text, normalized_context, GLOBAL_MEMORY_SCOPE, # 强制使用 global,不区分用户 - timestamp or time.time() + timestamp or time.time(), ) if not memory_chunks: @@ -393,10 +368,7 @@ class MemorySystem: # 3. 记忆融合与去重(包含与历史记忆的融合) existing_candidates = await self._collect_fusion_candidates(memory_chunks) - fused_chunks = await self.fusion_engine.fuse_memories( - memory_chunks, - existing_candidates - ) + fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks, existing_candidates) # 4. 存储记忆到统一存储 stored_count = await self._store_memories_unified(fused_chunks) @@ -459,11 +431,7 @@ class MemorySystem: return [] candidate_ids: Set[str] = set() - new_memory_ids = { - memory.memory_id - for memory in new_memories - if memory and getattr(memory, "memory_id", None) - } + new_memory_ids = {memory.memory_id for memory in new_memories if memory and getattr(memory, "memory_id", None)} # 基于指纹的直接匹配 for memory in new_memories: @@ -501,9 +469,7 @@ class MemorySystem: continue search_tasks.append( self.unified_storage.search_similar_memories( - query_text=display_text, - limit=8, - scope_id=GLOBAL_MEMORY_SCOPE + query_text=display_text, limit=8, scope_id=GLOBAL_MEMORY_SCOPE ) ) @@ -545,10 +511,7 @@ class MemorySystem: return existing_candidates - async def process_conversation_memory( - self, - context: Dict[str, Any] - ) -> Dict[str, Any]: + async def process_conversation_memory(self, context: Dict[str, Any]) -> Dict[str, Any]: """对外暴露的对话记忆处理接口,仅依赖上下文信息""" start_time = time.time() @@ -563,7 +526,9 @@ class MemorySystem: or "" ) - conversation_text = conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate) + conversation_text = ( + conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate) + ) timestamp = context.get("timestamp") if timestamp is None: @@ -573,9 +538,7 @@ class MemorySystem: normalized_context.setdefault("conversation_text", conversation_text) memories = await self.build_memory_from_conversation( - conversation_text=conversation_text, - context=normalized_context, - timestamp=timestamp + conversation_text=conversation_text, context=normalized_context, timestamp=timestamp ) processing_time = time.time() - start_time @@ -586,18 +549,13 @@ class MemorySystem: "created_memories": memories, "memory_count": memory_count, "processing_time": processing_time, - "status": self.status.value + "status": self.status.value, } except Exception as e: processing_time = time.time() - start_time logger.error(f"对话记忆处理失败: {e}", exc_info=True) - return { - "success": False, - "error": str(e), - "processing_time": processing_time, - "status": self.status.value - } + return {"success": False, "error": str(e), "processing_time": processing_time, "status": self.status.value} async def retrieve_relevant_memories( self, @@ -605,7 +563,7 @@ class MemorySystem: user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5, - **kwargs + **kwargs, ) -> List[MemoryChunk]: """检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)""" raw_query = query_text or kwargs.get("query") @@ -617,7 +575,7 @@ class MemorySystem: return [] context = context or {} - + # 所有记忆完全共享,统一使用 global 作用域,不区分用户 resolved_user_id = GLOBAL_MEMORY_SCOPE @@ -627,7 +585,7 @@ class MemorySystem: try: normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None) effective_limit = self.config.final_recall_limit - + # === 阶段一:元数据粗筛(软性过滤) === coarse_filters = { "user_id": GLOBAL_MEMORY_SCOPE, # 必选:确保作用域正确 @@ -642,119 +600,126 @@ class MemorySystem: # 构建包含未读消息的增强上下文 enhanced_context = await self._build_enhanced_query_context(raw_query, normalized_context) query_plan = await self.query_planner.plan_query(raw_query, enhanced_context) - + # 使用LLM优化后的查询语句(更精确的语义表达) if getattr(query_plan, "semantic_query", None): optimized_query = query_plan.semantic_query - + # 构建JSON元数据过滤条件(用于阶段一粗筛) # 将查询规划的结果转换为元数据过滤条件 if getattr(query_plan, "memory_types", None): - metadata_filters['memory_types'] = [mt.value for mt in query_plan.memory_types] - + metadata_filters["memory_types"] = [mt.value for mt in query_plan.memory_types] + if getattr(query_plan, "subject_includes", None): - metadata_filters['subjects'] = query_plan.subject_includes - + metadata_filters["subjects"] = query_plan.subject_includes + if getattr(query_plan, "required_keywords", None): - metadata_filters['keywords'] = query_plan.required_keywords - + metadata_filters["keywords"] = query_plan.required_keywords + # 时间范围过滤 recency = getattr(query_plan, "recency_preference", "any") current_time = time.time() if recency == "recent": # 最近7天 - metadata_filters['created_after'] = current_time - (7 * 24 * 3600) + metadata_filters["created_after"] = current_time - (7 * 24 * 3600) elif recency == "historical": # 30天以前 - metadata_filters['created_before'] = current_time - (30 * 24 * 3600) - + metadata_filters["created_before"] = current_time - (30 * 24 * 3600) + # 添加用户ID到元数据过滤 - metadata_filters['user_id'] = GLOBAL_MEMORY_SCOPE - + metadata_filters["user_id"] = GLOBAL_MEMORY_SCOPE + logger.debug(f"[阶段一] 查询优化: '{raw_query}' → '{optimized_query}'") logger.debug(f"[阶段一] 元数据过滤条件: {metadata_filters}") - + except Exception as plan_exc: logger.warning("查询规划失败,使用原始查询: %s", plan_exc, exc_info=True) # 即使查询规划失败,也保留基本的user_id过滤 - metadata_filters = {'user_id': GLOBAL_MEMORY_SCOPE} + metadata_filters = {"user_id": GLOBAL_MEMORY_SCOPE} # === 阶段二:向量精筛 === coarse_limit = self.config.coarse_recall_limit # 粗筛阶段返回更多候选 - + logger.debug(f"[阶段二] 开始向量搜索: query='{optimized_query[:60]}...', limit={coarse_limit}") - + search_results = await self.unified_storage.search_similar_memories( query_text=optimized_query, limit=coarse_limit, filters=coarse_filters, # ChromaDB where条件(保留兼容) - metadata_filters=metadata_filters # JSON元数据索引过滤 + metadata_filters=metadata_filters, # JSON元数据索引过滤 ) - + logger.info(f"[阶段二] 向量搜索完成: 返回 {len(search_results)} 条候选") # === 阶段三:综合重排 === scored_memories = [] current_time = time.time() - + for memory, vector_similarity in search_results: # 1. 向量相似度得分(已归一化到 0-1) vector_score = vector_similarity - + # 2. 时效性得分(指数衰减,30天半衰期) age_seconds = current_time - memory.metadata.created_at age_days = age_seconds / (24 * 3600) # 使用 math.exp 而非 np.exp(避免依赖numpy) import math + recency_score = math.exp(-age_days / 30) - + # 3. 重要性得分(枚举值转换为归一化得分 0-1) # ImportanceLevel: LOW=1, NORMAL=2, HIGH=3, CRITICAL=4 importance_enum = memory.metadata.importance - if hasattr(importance_enum, 'value'): + if hasattr(importance_enum, "value"): # 枚举类型,转换为0-1范围:(value - 1) / 3 importance_score = (importance_enum.value - 1) / 3.0 else: # 如果已经是数值,直接使用 importance_score = float(importance_enum) if importance_enum else 0.5 - + # 4. 访问频率得分(归一化,访问10次以上得满分) access_count = memory.metadata.access_count frequency_score = min(access_count / 10.0, 1.0) - + # 综合得分(加权平均) final_score = ( - self.config.vector_weight * vector_score + - self.config.recency_weight * recency_score + - self.config.context_weight * importance_score + - 0.1 * frequency_score # 访问频率权重(固定10%) + self.config.vector_weight * vector_score + + self.config.recency_weight * recency_score + + self.config.context_weight * importance_score + + 0.1 * frequency_score # 访问频率权重(固定10%) ) - - scored_memories.append((memory, final_score, { - "vector": vector_score, - "recency": recency_score, - "importance": importance_score, - "frequency": frequency_score, - "final": final_score - })) - + + scored_memories.append( + ( + memory, + final_score, + { + "vector": vector_score, + "recency": recency_score, + "importance": importance_score, + "frequency": frequency_score, + "final": final_score, + }, + ) + ) + # 更新访问记录 memory.update_access() # 按综合得分排序 scored_memories.sort(key=lambda x: x[1], reverse=True) - + # 返回 Top-K final_memories = [mem for mem, score, details in scored_memories[:effective_limit]] - + retrieval_time = time.time() - start_time # 详细日志 if scored_memories: - logger.info(f"[阶段三] 综合重排完成: Top 3 得分详情") + logger.info("[阶段三] 综合重排完成: Top 3 得分详情") for i, (mem, score, details) in enumerate(scored_memories[:3], 1): try: - summary = mem.content[:60] if hasattr(mem, 'content') and mem.content else "" + summary = mem.content[:60] if hasattr(mem, "content") and mem.content else "" except: summary = "" logger.info( @@ -803,15 +768,12 @@ class MemorySystem: start = stripped.find("{") end = stripped.rfind("}") if start != -1 and end != -1 and end > start: - return stripped[start:end + 1].strip() + return stripped[start : end + 1].strip() return stripped if stripped.startswith("{") and stripped.endswith("}") else None def _normalize_context( - self, - raw_context: Optional[Dict[str, Any]], - user_id: Optional[str], - timestamp: Optional[float] + self, raw_context: Optional[Dict[str, Any]], user_id: Optional[str], timestamp: Optional[float] ) -> Dict[str, Any]: """标准化上下文,确保必备字段存在且格式正确""" context: Dict[str, Any] = {} @@ -850,9 +812,7 @@ class MemorySystem: # 历史窗口配置 window_candidate = ( - context.get("history_limit") - or context.get("history_window") - or context.get("memory_history_limit") + context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit") ) if window_candidate is not None: try: @@ -888,7 +848,9 @@ class MemorySystem: enhanced_context["unread_messages_context"] = unread_messages_summary enhanced_context["has_unread_context"] = True - logger.debug(f"为查询规划构建了增强上下文,包含 {len(unread_messages_summary.get('messages', []))} 条未读消息") + logger.debug( + f"为查询规划构建了增强上下文,包含 {len(unread_messages_summary.get('messages', []))} 条未读消息" + ) else: enhanced_context["has_unread_context"] = False logger.debug("未找到未读消息,使用基础上下文进行查询规划") @@ -934,26 +896,30 @@ class MemorySystem: for msg in unread_messages[:10]: # 限制处理最近10条未读消息 try: # 提取消息内容 - content = (getattr(msg, "processed_plain_text", None) or - getattr(msg, "display_message", None) or "") + content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None) or "" if not content: continue # 提取发送者信息 sender_name = "未知用户" if hasattr(msg, "user_info") and msg.user_info: - sender_name = (getattr(msg.user_info, "user_nickname", None) or - getattr(msg.user_info, "user_cardname", None) or - getattr(msg.user_info, "user_id", None) or "未知用户") + sender_name = ( + getattr(msg.user_info, "user_nickname", None) + or getattr(msg.user_info, "user_cardname", None) + or getattr(msg.user_info, "user_id", None) + or "未知用户" + ) participant_names.add(sender_name) # 添加到消息摘要 - messages_summary.append({ - "sender": sender_name, - "content": content[:200], # 限制长度避免过长 - "timestamp": getattr(msg, "time", None) - }) + messages_summary.append( + { + "sender": sender_name, + "content": content[:200], # 限制长度避免过长 + "timestamp": getattr(msg, "time", None), + } + ) # 提取关键词(简单实现) content_lower = content.lower() @@ -975,10 +941,12 @@ class MemorySystem: "processed_count": len(messages_summary), "keywords": list(all_keywords)[:20], # 最多20个关键词 "participants": list(participant_names), - "context_summary": self._build_unread_context_summary(messages_summary) + "context_summary": self._build_unread_context_summary(messages_summary), } - logger.debug(f"收集到未读消息上下文: {len(messages_summary)}条消息,{len(all_keywords)}个关键词,{len(participant_names)}个参与者") + logger.debug( + f"收集到未读消息上下文: {len(messages_summary)}条消息,{len(all_keywords)}个关键词,{len(participant_names)}个参与者" + ) return unread_context except Exception as e: @@ -1051,10 +1019,7 @@ class MemorySystem: if user_id and fallback_text: try: relevant_memories = await self.retrieve_memories_for_building( - query_text=fallback_text, - user_id=user_id, - context=context, - limit=3 + query_text=fallback_text, user_id=user_id, context=context, limit=3 ) if relevant_memories: @@ -1068,9 +1033,7 @@ class MemorySystem: memory_transcript = f"{memory_transcript}\n[当前消息] {cleaned_fallback}" logger.debug( - "使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s", - len(relevant_memories), - user_id + "使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s", len(relevant_memories), user_id ) return memory_transcript @@ -1087,11 +1050,7 @@ class MemorySystem: def _determine_history_limit(self, context: Dict[str, Any]) -> int: """确定历史消息获取数量,限制在30-50之间""" default_limit = 40 - candidate = ( - context.get("history_limit") - or context.get("history_window") - or context.get("memory_history_limit") - ) + candidate = context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit") if isinstance(candidate, str): try: @@ -1186,9 +1145,9 @@ class MemorySystem: {text} 上下文信息: -- 用户ID: {context.get('user_id', 'unknown')} -- 消息类型: {context.get('message_type', 'unknown')} -- 时间: {datetime.fromtimestamp(context.get('timestamp', time.time()))} +- 用户ID: {context.get("user_id", "unknown")} +- 消息类型: {context.get("message_type", "unknown")} +- 时间: {datetime.fromtimestamp(context.get("timestamp", time.time()))} ## 📋 评估要求: @@ -1214,9 +1173,7 @@ class MemorySystem: }} """ - response, _ = await self.value_assessment_model.generate_response_async( - prompt, temperature=0.3 - ) + response, _ = await self.value_assessment_model.generate_response_async(prompt, temperature=0.3) # 解析响应 try: @@ -1236,7 +1193,7 @@ class MemorySystem: return max(0.0, min(1.0, value_score)) except (orjson.JSONDecodeError, ValueError) as e: - preview = response[:200].replace('\n', ' ') + preview = response[:200].replace("\n", " ") logger.warning(f"解析价值评估响应失败: {e}, 响应片段: {preview}") return 0.5 # 默认中等价值 @@ -1331,13 +1288,15 @@ class MemorySystem: else: obj_part = str(obj).strip() - base = "|".join([ - str(memory.user_id or "unknown"), - memory.memory_type.value, - subject_part, - predicate_part, - obj_part, - ]) + base = "|".join( + [ + str(memory.user_id or "unknown"), + memory.memory_type.value, + subject_part, + predicate_part, + obj_part, + ] + ) return hashlib.sha256(base.encode("utf-8")).hexdigest() @@ -1352,7 +1311,7 @@ class MemorySystem: "total_memories": self.total_memories, "last_build_time": self.last_build_time, "last_retrieval_time": self.last_retrieval_time, - "config": asdict(self.config) + "config": asdict(self.config), } def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float: @@ -1369,7 +1328,9 @@ class MemorySystem: keyword_overlap = 0.0 if context_keywords: memory_keywords = set(k.lower() for k in memory.keywords) - keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max(len(context_keywords), 1) + keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max( + len(context_keywords), 1 + ) importance_boost = (memory.metadata.importance.value - 1) / 3 * 0.1 confidence_boost = (memory.metadata.confidence.value - 1) / 3 * 0.05 @@ -1429,7 +1390,7 @@ class MemorySystem: """重建向量存储(如果需要)""" try: # 检查是否有记忆缓存数据 - if not hasattr(self.unified_storage, 'memory_cache') or not self.unified_storage.memory_cache: + if not hasattr(self.unified_storage, "memory_cache") or not self.unified_storage.memory_cache: logger.info("无记忆缓存数据,跳过向量存储重建") return @@ -1443,19 +1404,19 @@ class MemorySystem: memories_to_rebuild.append(memory) elif memory.text_content and memory.text_content.strip(): memories_to_rebuild.append(memory) - + if not memories_to_rebuild: logger.warning("没有找到可重建向量的记忆") return - + logger.info(f"准备为 {len(memories_to_rebuild)} 条记忆重建向量") - + # 批量重建向量 batch_size = 10 rebuild_count = 0 - + for i in range(0, len(memories_to_rebuild), batch_size): - batch = memories_to_rebuild[i:i + batch_size] + batch = memories_to_rebuild[i : i + batch_size] try: await self.unified_storage.store_memories(batch) rebuild_count += len(batch) @@ -1472,7 +1433,7 @@ class MemorySystem: final_count = self.unified_storage.storage_stats.get("total_vectors", 0) logger.info(f"✅ 向量存储重建完成,最终向量数量: {final_count}") - + except Exception as e: logger.error(f"❌ 向量存储重建失败: {e}", exc_info=True) @@ -1495,4 +1456,4 @@ async def initialize_memory_system(llm_model: Optional[LLMRequest] = None): if memory_system is None: memory_system = MemorySystem(llm_model=llm_model) await memory_system.initialize() - return memory_system \ No newline at end of file + return memory_system diff --git a/src/chat/memory_system/vector_memory_storage_v2.py b/src/chat/memory_system/vector_memory_storage_v2.py index 6c590d888..3c924ba30 100644 --- a/src/chat/memory_system/vector_memory_storage_v2.py +++ b/src/chat/memory_system/vector_memory_storage_v2.py @@ -15,12 +15,10 @@ import time import orjson import asyncio import threading -from typing import Dict, List, Optional, Tuple, Set, Any, Union -from dataclasses import dataclass, asdict +from typing import Dict, List, Optional, Tuple, Any +from dataclasses import dataclass from datetime import datetime -from pathlib import Path -import numpy as np from src.common.logger import get_logger from src.common.vector_db import vector_db_service from src.chat.utils.utils import get_embedding @@ -33,6 +31,7 @@ logger = get_logger(__name__) # 全局枚举映射表缓存 _ENUM_MAPPINGS_CACHE = {} + def _build_enum_mapping(enum_class: type) -> Dict[str, Any]: """构建枚举类的完整映射表 @@ -49,10 +48,10 @@ def _build_enum_mapping(enum_class: type) -> Dict[str, Any]: return _ENUM_MAPPINGS_CACHE[cache_key] mapping = { - "name_to_enum": {}, # 枚举名称 -> 枚举实例 (HIGH -> ImportanceLevel.HIGH) - "value_to_enum": {}, # 整数值 -> 枚举实例 (3 -> ImportanceLevel.HIGH) - "value_str_to_enum": {}, # 字符串value -> 枚举实例 ("3" -> ImportanceLevel.HIGH) - "enum_value_to_name": {}, # 枚举实例 -> 名称映射 (反向) + "name_to_enum": {}, # 枚举名称 -> 枚举实例 (HIGH -> ImportanceLevel.HIGH) + "value_to_enum": {}, # 整数值 -> 枚举实例 (3 -> ImportanceLevel.HIGH) + "value_str_to_enum": {}, # 字符串value -> 枚举实例 ("3" -> ImportanceLevel.HIGH) + "enum_value_to_name": {}, # 枚举实例 -> 名称映射 (反向) "all_possible_strings": set(), # 所有可能的字符串表示 } @@ -77,7 +76,9 @@ def _build_enum_mapping(enum_class: type) -> Dict[str, Any]: # 缓存结果 _ENUM_MAPPINGS_CACHE[cache_key] = mapping - logger.debug(f"构建枚举映射表: {enum_class.__name__} -> {len(mapping['name_to_enum'])} 个名称映射, {len(mapping['value_to_enum'])} 个值映射") + logger.debug( + f"构建枚举映射表: {enum_class.__name__} -> {len(mapping['name_to_enum'])} 个名称映射, {len(mapping['value_to_enum'])} 个值映射" + ) return mapping @@ -85,42 +86,43 @@ def _build_enum_mapping(enum_class: type) -> Dict[str, Any]: @dataclass class VectorStorageConfig: """Vector存储配置""" + # 集合配置 memory_collection: str = "unified_memory_v2" metadata_collection: str = "memory_metadata_v2" - + # 检索配置 similarity_threshold: float = 0.5 # 降低阈值以提高召回率(0.5-0.6 是合理范围) search_limit: int = 20 batch_size: int = 100 - + # 性能配置 enable_caching: bool = True cache_size_limit: int = 1000 auto_cleanup_interval: int = 3600 # 1小时 - + # 遗忘配置 enable_forgetting: bool = True retention_hours: int = 24 * 30 # 30天 - + @classmethod def from_global_config(cls): """从全局配置创建实例""" from src.config.config import global_config - + memory_cfg = global_config.memory - + return cls( - memory_collection=getattr(memory_cfg, 'vector_db_memory_collection', 'unified_memory_v2'), - metadata_collection=getattr(memory_cfg, 'vector_db_metadata_collection', 'memory_metadata_v2'), - similarity_threshold=getattr(memory_cfg, 'vector_db_similarity_threshold', 0.5), - search_limit=getattr(memory_cfg, 'vector_db_search_limit', 20), - batch_size=getattr(memory_cfg, 'vector_db_batch_size', 100), - enable_caching=getattr(memory_cfg, 'vector_db_enable_caching', True), - cache_size_limit=getattr(memory_cfg, 'vector_db_cache_size_limit', 1000), - auto_cleanup_interval=getattr(memory_cfg, 'vector_db_auto_cleanup_interval', 3600), - enable_forgetting=getattr(memory_cfg, 'enable_memory_forgetting', True), - retention_hours=getattr(memory_cfg, 'vector_db_retention_hours', 720), + memory_collection=getattr(memory_cfg, "vector_db_memory_collection", "unified_memory_v2"), + metadata_collection=getattr(memory_cfg, "vector_db_metadata_collection", "memory_metadata_v2"), + similarity_threshold=getattr(memory_cfg, "vector_db_similarity_threshold", 0.5), + search_limit=getattr(memory_cfg, "vector_db_search_limit", 20), + batch_size=getattr(memory_cfg, "vector_db_batch_size", 100), + enable_caching=getattr(memory_cfg, "vector_db_enable_caching", True), + cache_size_limit=getattr(memory_cfg, "vector_db_cache_size_limit", 1000), + auto_cleanup_interval=getattr(memory_cfg, "vector_db_auto_cleanup_interval", 3600), + enable_forgetting=getattr(memory_cfg, "enable_memory_forgetting", True), + retention_hours=getattr(memory_cfg, "vector_db_retention_hours", 720), ) @@ -133,15 +135,16 @@ class VectorMemoryStorage: """ index = {} for memory in self.memory_cache.values(): - for kw in getattr(memory, 'keywords', []): + for kw in getattr(memory, "keywords", []): if not kw: continue kw_norm = kw.strip().lower() if kw_norm: - index.setdefault(kw_norm, []).append(getattr(memory.metadata, 'memory_id', None)) + index.setdefault(kw_norm, []).append(getattr(memory.metadata, "memory_id", None)) return index + """基于Vector DB的记忆存储系统""" - + def __init__(self, config: Optional[VectorStorageConfig] = None): # 默认从全局配置读取,如果没有传入config if config is None: @@ -153,25 +156,25 @@ class VectorMemoryStorage: self.config = VectorStorageConfig() else: self.config = config - + # 从配置中获取批处理大小和集合名称 self.batch_size = self.config.batch_size self.collection_name = self.config.memory_collection self.vector_db_service = vector_db_service - + # 内存缓存 self.memory_cache: Dict[str, MemoryChunk] = {} self.cache_timestamps: Dict[str, float] = {} self._cache = self.memory_cache # 别名,兼容旧代码 - + # 元数据索引管理器(JSON文件索引) self.metadata_index = MemoryMetadataIndex() - + # 遗忘引擎 self.forgetting_engine: Optional[MemoryForgettingEngine] = None if self.config.enable_forgetting: self.forgetting_engine = MemoryForgettingEngine() - + # 统计信息 self.stats = { "total_memories": 0, @@ -180,55 +183,48 @@ class VectorMemoryStorage: "total_searches": 0, "total_stores": 0, "last_cleanup_time": 0.0, - "forgetting_stats": {} + "forgetting_stats": {}, } - + # 线程锁 self._lock = threading.RLock() - + # 定时清理任务 self._cleanup_task = None self._stop_cleanup = False - + # 初始化系统 self._initialize_storage() self._start_cleanup_task() - + def _initialize_storage(self): """初始化Vector DB存储""" try: # 创建记忆集合 vector_db_service.get_or_create_collection( name=self.config.memory_collection, - metadata={ - "description": "统一记忆存储V2", - "hnsw:space": "cosine", - "version": "2.0" - } + metadata={"description": "统一记忆存储V2", "hnsw:space": "cosine", "version": "2.0"}, ) - + # 创建元数据集合(用于复杂查询) vector_db_service.get_or_create_collection( name=self.config.metadata_collection, - metadata={ - "description": "记忆元数据索引", - "hnsw:space": "cosine", - "version": "2.0" - } + metadata={"description": "记忆元数据索引", "hnsw:space": "cosine", "version": "2.0"}, ) - + # 获取当前记忆总数 self.stats["total_memories"] = vector_db_service.count(self.config.memory_collection) - + logger.info(f"Vector记忆存储初始化完成,当前记忆数: {self.stats['total_memories']}") - + except Exception as e: logger.error(f"Vector存储系统初始化失败: {e}", exc_info=True) raise - + def _start_cleanup_task(self): """启动定时清理任务""" if self.config.auto_cleanup_interval > 0: + def cleanup_worker(): while not self._stop_cleanup: try: @@ -237,47 +233,50 @@ class VectorMemoryStorage: asyncio.create_task(self._perform_auto_cleanup()) except Exception as e: logger.error(f"定时清理任务出错: {e}") - + self._cleanup_task = threading.Thread(target=cleanup_worker, daemon=True) self._cleanup_task.start() logger.info(f"定时清理任务已启动,间隔: {self.config.auto_cleanup_interval}秒") - + async def _perform_auto_cleanup(self): """执行自动清理""" try: current_time = time.time() - + # 清理过期缓存 if self.config.enable_caching: expired_keys = [ - memory_id for memory_id, timestamp in self.cache_timestamps.items() + memory_id + for memory_id, timestamp in self.cache_timestamps.items() if current_time - timestamp > 3600 # 1小时过期 ] - + for key in expired_keys: self.memory_cache.pop(key, None) self.cache_timestamps.pop(key, None) - + if expired_keys: logger.debug(f"清理了 {len(expired_keys)} 个过期缓存项") - + # 执行遗忘检查 if self.forgetting_engine: await self.perform_forgetting_check() - + self.stats["last_cleanup_time"] = current_time - + except Exception as e: logger.error(f"自动清理失败: {e}") - + def _memory_to_vector_format(self, memory: MemoryChunk) -> Dict[str, Any]: """将MemoryChunk转换为向量存储格式""" try: # 获取memory_id - memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None) - + memory_id = getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", None) + # 生成向量表示的文本 - display_text = getattr(memory, 'display', None) or getattr(memory, 'text_content', None) or str(memory.content) + display_text = ( + getattr(memory, "display", None) or getattr(memory, "text_content", None) or str(memory.content) + ) if not display_text.strip(): logger.warning(f"记忆 {memory_id} 缺少有效的显示文本") display_text = f"{memory.memory_type.value}: {', '.join(memory.subjects)}" @@ -296,16 +295,16 @@ class VectorMemoryStorage: "keywords": orjson.dumps(memory.keywords).decode("utf-8"), # 列表转JSON字符串 "tags": orjson.dumps(memory.tags).decode("utf-8"), # 列表转JSON字符串 "categories": orjson.dumps(memory.categories).decode("utf-8"), # 列表转JSON字符串 - "relevance_score": memory.metadata.relevance_score + "relevance_score": memory.metadata.relevance_score, } # 添加可选字段 if memory.metadata.source_context: metadata["source_context"] = str(memory.metadata.source_context) - + if memory.content.predicate: metadata["predicate"] = memory.content.predicate - + if memory.content.object: if isinstance(memory.content.object, (dict, list)): metadata["object"] = orjson.dumps(memory.content.object).decode() @@ -316,14 +315,14 @@ class VectorMemoryStorage: "id": memory_id, "embedding": None, # 将由vector_db_service生成 "metadata": metadata, - "document": display_text + "document": display_text, } except Exception as e: - memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown') + memory_id = getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", "unknown") logger.error(f"转换记忆 {memory_id} 到向量格式失败: {e}", exc_info=True) raise - + def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]: """将Vector DB结果转换为MemoryChunk""" try: @@ -331,10 +330,10 @@ class VectorMemoryStorage: if "memory_data" in metadata: memory_dict = orjson.loads(metadata["memory_data"]) return MemoryChunk.from_dict(memory_dict) - + # 兜底:从基础字段重建(使用新的结构化格式) logger.warning(f"未找到memory_data,使用兜底逻辑重建记忆 (id={metadata.get('memory_id', 'unknown')})") - + # 构建符合MemoryChunk.from_dict期望的结构 memory_dict = { "metadata": { @@ -345,24 +344,30 @@ class VectorMemoryStorage: "last_modified": metadata.get("timestamp", time.time()), "access_count": metadata.get("access_count", 0), "relevance_score": 0.0, - "confidence": self._parse_enum_value(metadata.get("confidence", 2), ConfidenceLevel, ConfidenceLevel.MEDIUM), - "importance": self._parse_enum_value(metadata.get("importance", 2), ImportanceLevel, ImportanceLevel.NORMAL), + "confidence": self._parse_enum_value( + metadata.get("confidence", 2), ConfidenceLevel, ConfidenceLevel.MEDIUM + ), + "importance": self._parse_enum_value( + metadata.get("importance", 2), ImportanceLevel, ImportanceLevel.NORMAL + ), "source_context": None, }, "content": { "subject": "", "predicate": "", "object": "", - "display": document # 使用document作为显示文本 + "display": document, # 使用document作为显示文本 }, "memory_type": metadata.get("memory_type", "contextual"), - "keywords": orjson.loads(metadata.get("keywords", "[]")) if isinstance(metadata.get("keywords"), str) else metadata.get("keywords", []), + "keywords": orjson.loads(metadata.get("keywords", "[]")) + if isinstance(metadata.get("keywords"), str) + else metadata.get("keywords", []), "tags": [], "categories": [], "embedding": None, "semantic_hash": None, "related_memories": [], - "temporal_context": None + "temporal_context": None, } return MemoryChunk.from_dict(memory_dict) @@ -434,40 +439,39 @@ class VectorMemoryStorage: # 其他类型,返回默认值 logger.warning(f"不支持的{enum_class.__name__}值类型: {type(value)},使用默认值") return default - + def _get_from_cache(self, memory_id: str) -> Optional[MemoryChunk]: """从缓存获取记忆""" if not self.config.enable_caching: return None - + with self._lock: if memory_id in self.memory_cache: self.cache_timestamps[memory_id] = time.time() self.stats["cache_hits"] += 1 return self.memory_cache[memory_id] - + self.stats["cache_misses"] += 1 return None - + def _add_to_cache(self, memory: MemoryChunk): """添加记忆到缓存""" if not self.config.enable_caching: return - + with self._lock: # 检查缓存大小限制 if len(self.memory_cache) >= self.config.cache_size_limit: # 移除最老的缓存项 - oldest_id = min(self.cache_timestamps.keys(), - key=lambda k: self.cache_timestamps[k]) + oldest_id = min(self.cache_timestamps.keys(), key=lambda k: self.cache_timestamps[k]) self.memory_cache.pop(oldest_id, None) self.cache_timestamps.pop(oldest_id, None) - - memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None) + + memory_id = getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", None) if memory_id: self.memory_cache[memory_id] = memory self.cache_timestamps[memory_id] = time.time() - + async def store_memories(self, memories: List[MemoryChunk]) -> int: """批量存储记忆""" if not memories: @@ -475,7 +479,7 @@ class VectorMemoryStorage: start_time = datetime.now() success_count = 0 - + try: # 转换为向量格式 vector_data_list = [] @@ -484,7 +488,7 @@ class VectorMemoryStorage: vector_data = self._memory_to_vector_format(memory) vector_data_list.append(vector_data) except Exception as e: - memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown') + memory_id = getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", "unknown") logger.error(f"处理记忆 {memory_id} 失败: {e}") continue @@ -494,8 +498,8 @@ class VectorMemoryStorage: # 批量存储到向量数据库 for i in range(0, len(vector_data_list), self.batch_size): - batch = vector_data_list[i:i + self.batch_size] - + batch = vector_data_list[i : i + self.batch_size] + try: # 生成embeddings embeddings = [] @@ -507,29 +511,37 @@ class VectorMemoryStorage: logger.error(f"生成embedding失败: {e}") # 使用零向量作为后备 embeddings.append([0.0] * 768) # 默认维度 - + # vector_db_service.add 需要embeddings参数 self.vector_db_service.add( collection_name=self.collection_name, embeddings=embeddings, ids=[item["id"] for item in batch], documents=[item["document"] for item in batch], - metadatas=[item["metadata"] for item in batch] + metadatas=[item["metadata"] for item in batch], ) success = True - + if success: # 更新缓存和元数据索引 metadata_entries = [] for item in batch: memory_id = item["id"] # 从原始 memories 列表中找到对应的 MemoryChunk - memory = next((m for m in memories if (getattr(m.metadata, 'memory_id', None) or getattr(m, 'memory_id', None)) == memory_id), None) + memory = next( + ( + m + for m in memories + if (getattr(m.metadata, "memory_id", None) or getattr(m, "memory_id", None)) + == memory_id + ), + None, + ) if memory: # 更新缓存 self._cache[memory_id] = memory success_count += 1 - + # 创建元数据索引条目 try: index_entry = MemoryMetadataIndexEntry( @@ -545,12 +557,12 @@ class VectorMemoryStorage: created_at=memory.metadata.created_at, access_count=memory.metadata.access_count, chat_id=memory.metadata.chat_id, - content_preview=str(memory.content)[:100] if memory.content else None + content_preview=str(memory.content)[:100] if memory.content else None, ) metadata_entries.append(index_entry) except Exception as e: logger.warning(f"创建元数据索引条目失败 (memory_id={memory_id}): {e}") - + # 批量更新元数据索引 if metadata_entries: try: @@ -560,14 +572,14 @@ class VectorMemoryStorage: logger.error(f"批量更新元数据索引失败: {e}") else: logger.warning(f"批次存储失败,跳过 {len(batch)} 条记忆") - + except Exception as e: logger.error(f"批量存储失败: {e}", exc_info=True) continue duration = (datetime.now() - start_time).total_seconds() logger.info(f"成功存储 {success_count}/{len(memories)} 条记忆,耗时 {duration:.2f}秒") - + # 保存元数据索引到磁盘 if success_count > 0: try: @@ -575,18 +587,18 @@ class VectorMemoryStorage: logger.debug("元数据索引已保存到磁盘") except Exception as e: logger.error(f"保存元数据索引失败: {e}") - + return success_count except Exception as e: logger.error(f"批量存储记忆失败: {e}", exc_info=True) return success_count - + async def store_memory(self, memory: MemoryChunk) -> bool: """存储单条记忆""" result = await self.store_memories([memory]) return result > 0 - + async def search_similar_memories( self, query_text: str, @@ -594,11 +606,11 @@ class VectorMemoryStorage: similarity_threshold: Optional[float] = None, filters: Optional[Dict[str, Any]] = None, # 新增:元数据过滤参数(用于JSON索引粗筛) - metadata_filters: Optional[Dict[str, Any]] = None + metadata_filters: Optional[Dict[str, Any]] = None, ) -> List[Tuple[MemoryChunk, float]]: """ 搜索相似记忆(混合索引模式) - + Args: query_text: 查询文本 limit: 返回数量限制 @@ -617,47 +629,49 @@ class VectorMemoryStorage: """ if not query_text.strip(): return [] - + try: # === 阶段一:JSON元数据粗筛(可选) === candidate_ids: Optional[List[str]] = None if metadata_filters: logger.debug(f"[JSON元数据粗筛] 开始,过滤条件: {metadata_filters}") candidate_ids = self.metadata_index.search( - memory_types=metadata_filters.get('memory_types'), - subjects=metadata_filters.get('subjects'), - keywords=metadata_filters.get('keywords'), - tags=metadata_filters.get('tags'), - importance_min=metadata_filters.get('importance_min'), - importance_max=metadata_filters.get('importance_max'), - created_after=metadata_filters.get('created_after'), - created_before=metadata_filters.get('created_before'), - user_id=metadata_filters.get('user_id'), + memory_types=metadata_filters.get("memory_types"), + subjects=metadata_filters.get("subjects"), + keywords=metadata_filters.get("keywords"), + tags=metadata_filters.get("tags"), + importance_min=metadata_filters.get("importance_min"), + importance_max=metadata_filters.get("importance_max"), + created_after=metadata_filters.get("created_after"), + created_before=metadata_filters.get("created_before"), + user_id=metadata_filters.get("user_id"), limit=self.config.search_limit * 2, # 粗筛返回更多候选 - flexible_mode=True # 使用灵活匹配模式 + flexible_mode=True, # 使用灵活匹配模式 ) logger.info(f"[JSON元数据粗筛] 完成,筛选出 {len(candidate_ids)} 个候选ID") # 如果粗筛后没有结果,回退到全部记忆搜索 if not candidate_ids: total_memories = len(self.metadata_index.index) - logger.warning(f"JSON元数据粗筛后无候选,启用回退机制:在全部 {total_memories} 条记忆中进行向量搜索") + logger.warning( + f"JSON元数据粗筛后无候选,启用回退机制:在全部 {total_memories} 条记忆中进行向量搜索" + ) logger.info("💡 提示:这可能是因为查询条件过于严格,或相关记忆的元数据与查询条件不完全匹配") candidate_ids = None # 设为None表示不限制候选ID else: - logger.debug(f"[JSON元数据粗筛] 成功筛选出候选,进入向量精筛阶段") - + logger.debug("[JSON元数据粗筛] 成功筛选出候选,进入向量精筛阶段") + # === 阶段二:向量精筛 === # 生成查询向量 query_embedding = await get_embedding(query_text) if not query_embedding: return [] - + threshold = similarity_threshold or self.config.similarity_threshold - + # 构建where条件 where_conditions = filters or {} - + # 如果有候选ID列表,添加到where条件 if candidate_ids: # ChromaDB的where条件需要使用$in操作符 @@ -665,117 +679,114 @@ class VectorMemoryStorage: logger.debug(f"[向量精筛] 限制在 {len(candidate_ids)} 个候选ID内搜索") else: logger.info("[向量精筛] 在全部记忆中搜索(元数据筛选无结果回退)") - + # 查询Vector DB logger.debug(f"[向量精筛] 开始,limit={min(limit, self.config.search_limit)}") results = vector_db_service.query( collection_name=self.config.memory_collection, query_embeddings=[query_embedding], n_results=min(limit, self.config.search_limit), - where=where_conditions if where_conditions else None + where=where_conditions if where_conditions else None, ) - + # 处理结果 similar_memories = [] - + if results.get("documents") and results["documents"][0]: documents = results["documents"][0] distances = results.get("distances", [[]])[0] metadatas = results.get("metadatas", [[]])[0] ids = results.get("ids", [[]])[0] - - logger.info(f"向量检索返回原始结果:documents={len(documents)}, ids={len(ids)}, metadatas={len(metadatas)}") - for i, (doc, metadata, memory_id) in enumerate(zip(documents, metadatas, ids)): + + logger.info( + f"向量检索返回原始结果:documents={len(documents)}, ids={len(ids)}, metadatas={len(metadatas)}" + ) + for i, (doc, metadata, memory_id) in enumerate(zip(documents, metadatas, ids, strict=False)): # 计算相似度 distance = distances[i] if i < len(distances) else 1.0 similarity = 1 - distance # ChromaDB返回距离,转换为相似度 - + if similarity < threshold: continue - + # 首先尝试从缓存获取 memory = self._get_from_cache(memory_id) - + if not memory: # 从Vector结果重建 memory = self._vector_result_to_memory(doc, metadata) if memory: self._add_to_cache(memory) - + if memory: similar_memories.append((memory, similarity)) # 记录单条结果的关键日志(id,相似度,简短文本) try: - short_text = (str(memory.content)[:120]) if hasattr(memory, 'content') else (doc[:120] if isinstance(doc, str) else '') + short_text = ( + (str(memory.content)[:120]) + if hasattr(memory, "content") + else (doc[:120] if isinstance(doc, str) else "") + ) except Exception: - short_text = '' + short_text = "" logger.info(f"检索结果 - id={memory_id}, similarity={similarity:.4f}, summary={short_text}") - + # 按相似度排序 similar_memories.sort(key=lambda x: x[1], reverse=True) - + self.stats["total_searches"] += 1 - logger.info(f"搜索相似记忆: query='{query_text[:60]}...', limit={limit}, threshold={threshold}, filters={where_conditions}, 返回数={len(similar_memories)}") + logger.info( + f"搜索相似记忆: query='{query_text[:60]}...', limit={limit}, threshold={threshold}, filters={where_conditions}, 返回数={len(similar_memories)}" + ) logger.debug(f"搜索相似记忆 详细结果数={len(similar_memories)}") - + return similar_memories - + except Exception as e: logger.error(f"搜索相似记忆失败: {e}") return [] - + async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]: """根据ID获取记忆""" # 首先尝试从缓存获取 memory = self._get_from_cache(memory_id) if memory: return memory - + try: # 从Vector DB获取 - results = vector_db_service.get( - collection_name=self.config.memory_collection, - ids=[memory_id] - ) - + results = vector_db_service.get(collection_name=self.config.memory_collection, ids=[memory_id]) + if results.get("documents") and results["documents"]: document = results["documents"][0] metadata = results["metadatas"][0] if results.get("metadatas") else {} - + memory = self._vector_result_to_memory(document, metadata) if memory: self._add_to_cache(memory) - + return memory - + except Exception as e: logger.error(f"获取记忆 {memory_id} 失败: {e}") - + return None - - async def get_memories_by_filters( - self, - filters: Dict[str, Any], - limit: int = 100 - ) -> List[MemoryChunk]: + + async def get_memories_by_filters(self, filters: Dict[str, Any], limit: int = 100) -> List[MemoryChunk]: """根据过滤条件获取记忆""" try: - results = vector_db_service.get( - collection_name=self.config.memory_collection, - where=filters, - limit=limit - ) - + results = vector_db_service.get(collection_name=self.config.memory_collection, where=filters, limit=limit) + memories = [] if results.get("documents"): documents = results["documents"] metadatas = results.get("metadatas", [{}] * len(documents)) ids = results.get("ids", []) - + logger.info(f"按过滤条件获取返回: docs={len(documents)}, ids={len(ids)}") - for i, (doc, metadata) in enumerate(zip(documents, metadatas)): + for i, (doc, metadata) in enumerate(zip(documents, metadatas, strict=False)): memory_id = ids[i] if i < len(ids) else None - + # 首先尝试从缓存获取 if memory_id: memory = self._get_from_cache(memory_id) @@ -783,7 +794,7 @@ class VectorMemoryStorage: memories.append(memory) logger.debug(f"过滤获取命中缓存: id={memory_id}") continue - + # 从Vector结果重建 memory = self._vector_result_to_memory(doc, metadata) if memory: @@ -791,137 +802,129 @@ class VectorMemoryStorage: if memory_id: self._add_to_cache(memory) logger.debug(f"过滤获取结果: id={memory_id}, meta_keys={list(metadata.keys())}") - + return memories - + except Exception as e: logger.error(f"根据过滤条件获取记忆失败: {e}") return [] - + async def update_memory(self, memory: MemoryChunk) -> bool: """更新记忆""" try: - memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None) + memory_id = getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", None) if not memory_id: logger.error("无法更新记忆:缺少memory_id") return False - + # 先删除旧记忆 await self.delete_memory(memory_id) - + # 重新存储更新后的记忆 return await self.store_memory(memory) - + except Exception as e: - memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown') + memory_id = getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", "unknown") logger.error(f"更新记忆 {memory_id} 失败: {e}") return False - + async def delete_memory(self, memory_id: str) -> bool: """删除记忆""" try: # 从Vector DB删除 - vector_db_service.delete( - collection_name=self.config.memory_collection, - ids=[memory_id] - ) - + vector_db_service.delete(collection_name=self.config.memory_collection, ids=[memory_id]) + # 从缓存删除 with self._lock: self.memory_cache.pop(memory_id, None) self.cache_timestamps.pop(memory_id, None) - + self.stats["total_memories"] = max(0, self.stats["total_memories"] - 1) logger.debug(f"删除记忆: {memory_id}") - + return True - + except Exception as e: logger.error(f"删除记忆 {memory_id} 失败: {e}") return False - + async def delete_memories_by_filters(self, filters: Dict[str, Any]) -> int: """根据过滤条件批量删除记忆""" try: # 先获取要删除的记忆ID results = vector_db_service.get( - collection_name=self.config.memory_collection, - where=filters, - include=["metadatas"] + collection_name=self.config.memory_collection, where=filters, include=["metadatas"] ) - + if not results.get("ids"): return 0 - + memory_ids = results["ids"] - + # 批量删除 - vector_db_service.delete( - collection_name=self.config.memory_collection, - where=filters - ) - + vector_db_service.delete(collection_name=self.config.memory_collection, where=filters) + # 从缓存删除 with self._lock: for memory_id in memory_ids: self.memory_cache.pop(memory_id, None) self.cache_timestamps.pop(memory_id, None) - + deleted_count = len(memory_ids) self.stats["total_memories"] = max(0, self.stats["total_memories"] - deleted_count) logger.info(f"批量删除记忆: {deleted_count} 条") - + return deleted_count - + except Exception as e: logger.error(f"批量删除记忆失败: {e}") return 0 - + async def perform_forgetting_check(self) -> Dict[str, Any]: """执行遗忘检查""" if not self.forgetting_engine: return {"error": "遗忘引擎未启用"} - + try: # 获取所有记忆进行遗忘检查 # 注意:对于大型数据集,这里应该分批处理 current_time = time.time() cutoff_time = current_time - (self.config.retention_hours * 3600) - + # 先删除明显过期的记忆 expired_filters = {"timestamp": {"$lt": cutoff_time}} expired_count = await self.delete_memories_by_filters(expired_filters) - + # 对剩余记忆执行智能遗忘检查 # 这里为了性能考虑,只检查一部分记忆 sample_memories = await self.get_memories_by_filters({}, limit=500) - + if sample_memories: result = await self.forgetting_engine.perform_forgetting_check(sample_memories) - + # 遗忘标记的记忆 forgetting_ids = result.get("normal_forgetting", []) + result.get("force_forgetting", []) forgotten_count = 0 - + for memory_id in forgetting_ids: if await self.delete_memory(memory_id): forgotten_count += 1 - + result["forgotten_count"] = forgotten_count result["expired_count"] = expired_count - + # 更新统计 self.stats["forgetting_stats"] = self.forgetting_engine.get_forgetting_stats() - + logger.info(f"遗忘检查完成: 过期删除 {expired_count}, 智能遗忘 {forgotten_count}") return result - + return {"expired_count": expired_count, "forgotten_count": 0} - + except Exception as e: logger.error(f"执行遗忘检查失败: {e}") return {"error": str(e)} - + def get_storage_stats(self) -> Dict[str, Any]: """获取存储统计信息""" try: @@ -929,27 +932,27 @@ class VectorMemoryStorage: self.stats["total_memories"] = current_total except Exception: pass - + return { **self.stats, "cache_size": len(self.memory_cache), "collection_name": self.config.memory_collection, "storage_type": "vector_db_v2", - "uptime": time.time() - self.stats.get("start_time", time.time()) + "uptime": time.time() - self.stats.get("start_time", time.time()), } - + def stop(self): """停止存储系统""" self._stop_cleanup = True - + if self._cleanup_task and self._cleanup_task.is_alive(): logger.info("正在停止定时清理任务...") - + # 清空缓存 with self._lock: self.memory_cache.clear() self.cache_timestamps.clear() - + logger.info("Vector记忆存储系统已停止") @@ -960,36 +963,33 @@ _global_vector_storage = None def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> VectorMemoryStorage: """获取全局Vector记忆存储实例""" global _global_vector_storage - + if _global_vector_storage is None: _global_vector_storage = VectorMemoryStorage(config) - + return _global_vector_storage # 兼容性接口 class VectorMemoryStorageAdapter: """适配器类,提供与原UnifiedMemoryStorage兼容的接口""" - + def __init__(self, config: Optional[VectorStorageConfig] = None): self.storage = VectorMemoryStorage(config) - + async def store_memories(self, memories: List[MemoryChunk]) -> int: return await self.storage.store_memories(memories) - + async def search_similar_memories( - self, - query_text: str, - limit: int = 10, - scope_id: Optional[str] = None, - filters: Optional[Dict[str, Any]] = None + self, query_text: str, limit: int = 10, scope_id: Optional[str] = None, filters: Optional[Dict[str, Any]] = None ) -> List[Tuple[str, float]]: - results = await self.storage.search_similar_memories( - query_text, limit, filters=filters - ) + results = await self.storage.search_similar_memories(query_text, limit, filters=filters) # 转换为原格式:(memory_id, similarity) - return [(getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown'), similarity) for memory, similarity in results] - + return [ + (getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", "unknown"), similarity) + for memory, similarity in results + ] + def get_stats(self) -> Dict[str, Any]: return self.storage.get_storage_stats() @@ -998,33 +998,34 @@ if __name__ == "__main__": # 简单测试 async def test_vector_storage(): storage = VectorMemoryStorage() - + # 创建测试记忆 from src.chat.memory_system.memory_chunk import MemoryType + test_memory = MemoryChunk( memory_id="test_001", user_id="test_user", text_content="今天天气很好,适合出门散步", memory_type=MemoryType.FACT, keywords=["天气", "散步"], - importance=0.7 + importance=0.7, ) - + # 存储记忆 success = await storage.store_memory(test_memory) print(f"存储结果: {success}") - + # 搜索记忆 results = await storage.search_similar_memories("天气怎么样", limit=5) print(f"搜索结果: {len(results)} 条") - + for memory, similarity in results: print(f" - {memory.text_content[:50]}... (相似度: {similarity:.3f})") - + # 获取统计信息 stats = storage.get_storage_stats() print(f"存储统计: {stats}") - + storage.stop() - - asyncio.run(test_vector_storage()) \ No newline at end of file + + asyncio.run(test_vector_storage()) diff --git a/src/chat/message_manager/__init__.py b/src/chat/message_manager/__init__.py index 3e27858a0..fe5e90785 100644 --- a/src/chat/message_manager/__init__.py +++ b/src/chat/message_manager/__init__.py @@ -5,15 +5,12 @@ from .message_manager import MessageManager, message_manager from .context_manager import SingleStreamContextManager -from .distribution_manager import ( - StreamLoopManager, - stream_loop_manager -) +from .distribution_manager import StreamLoopManager, stream_loop_manager __all__ = [ "MessageManager", "message_manager", "SingleStreamContextManager", "StreamLoopManager", - "stream_loop_manager" -] \ No newline at end of file + "stream_loop_manager", +] diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index 0f744b129..5f3212065 100644 --- a/src/chat/message_manager/context_manager.py +++ b/src/chat/message_manager/context_manager.py @@ -230,12 +230,14 @@ class SingleStreamContextManager: 异步计算消息的兴趣度。 此方法通过检查当前是否存在正在运行的 asyncio 事件循环来兼容同步和异步调用。 """ + # 内部异步函数,封装实际的计算逻辑 async def _get_score(): try: from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ( chatter_interest_scoring_system, ) + interest_score = await chatter_interest_scoring_system._calculate_single_message_score( message=message, bot_nickname=global_config.bot.nickname ) diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index f16a22b70..69f3e662d 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -34,17 +34,13 @@ class StreamLoopManager: } # 配置参数 - self.max_concurrent_streams = ( - max_concurrent_streams or global_config.chat.max_concurrent_distributions - ) + self.max_concurrent_streams = max_concurrent_streams or global_config.chat.max_concurrent_distributions # 强制分发策略 self.force_dispatch_unread_threshold: Optional[int] = getattr( global_config.chat, "force_dispatch_unread_threshold", 20 ) - self.force_dispatch_min_interval: float = getattr( - global_config.chat, "force_dispatch_min_interval", 0.1 - ) + self.force_dispatch_min_interval: float = getattr(global_config.chat, "force_dispatch_min_interval", 0.1) # Chatter管理器 self.chatter_manager: Optional[ChatterManager] = None @@ -108,7 +104,9 @@ class StreamLoopManager: if force and len(self.stream_loops) >= self.max_concurrent_streams: logger.warning( - "流 %s 未读消息积压严重(>%s),突破并发限制强制启动分发", stream_id, self.force_dispatch_unread_threshold + "流 %s 未读消息积压严重(>%s),突破并发限制强制启动分发", + stream_id, + self.force_dispatch_unread_threshold, ) # 创建流循环任务 @@ -168,9 +166,7 @@ class StreamLoopManager: if has_messages: if force_dispatch: - logger.info( - "流 %s 未读消息 %d 条,触发强制分发", stream_id, unread_count - ) + logger.info("流 %s 未读消息 %d 条,触发强制分发", stream_id, unread_count) # 3. 激活chatter处理 success = await self._process_stream_messages(stream_id, context) diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 5b715715b..bd55bd43f 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -11,7 +11,7 @@ from typing import Dict, Optional, Any, TYPE_CHECKING, List from src.chat.message_receive.chat_stream import ChatStream from src.common.logger import get_logger from src.common.data_models.database_data_model import DatabaseMessages -from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats +from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats from src.chat.chatter_manager import ChatterManager from src.chat.planner_actions.action_manager import ChatterActionManager from .sleep_manager.sleep_manager import SleepManager @@ -21,7 +21,7 @@ from src.plugin_system.apis.chat_api import get_chat_manager from .distribution_manager import stream_loop_manager if TYPE_CHECKING: - from src.common.data_models.message_manager_data_model import StreamContext + pass logger = get_logger("message_manager") @@ -63,7 +63,7 @@ class MessageManager: stream_loop_manager.set_chatter_manager(self.chatter_manager) logger.info("🚀 消息管理器已启动 | 流循环管理器已启动") - + async def stop(self): """停止消息管理器""" if not self.is_running: @@ -88,7 +88,9 @@ class MessageManager: logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在") return await self._check_and_handle_interruption(chat_stream) - chat_stream.context_manager.context.processing_task = asyncio.create_task(chat_stream.context_manager.add_message(message)) + chat_stream.context_manager.context.processing_task = asyncio.create_task( + chat_stream.context_manager.add_message(message) + ) except Exception as e: logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}") @@ -141,11 +143,7 @@ class MessageManager: if not message_id: continue - payload = { - key: value - for key, value in item.items() - if key != "message_id" and value is not None - } + payload = {key: value for key, value in item.items() if key != "message_id" and value is not None} if not payload: continue @@ -169,9 +167,7 @@ class MessageManager: if not chat_stream: logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在") return - success = await chat_stream.context_manager.update_message( - message_id, {"actions": [action]} - ) + success = await chat_stream.context_manager.update_message(message_id, {"actions": [action]}) if success: logger.debug(f"为消息 {message_id} 添加动作 {action} 成功") else: @@ -193,7 +189,7 @@ class MessageManager: context.is_active = False # 取消处理任务 - if hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done(): + if hasattr(context, "processing_task") and context.processing_task and not context.processing_task.done(): context.processing_task.cancel() logger.info(f"停用聊天流: {stream_id}") @@ -236,7 +232,11 @@ class MessageManager: unread_count=unread_count, history_count=len(context.history_messages), last_check_time=context.last_check_time, - has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()), + has_active_task=bool( + hasattr(context, "processing_task") + and context.processing_task + and not context.processing_task.done() + ), ) except Exception as e: @@ -284,7 +284,10 @@ class MessageManager: return # 检查是否有正在进行的处理任务 - if chat_stream.context_manager.context.processing_task and not chat_stream.context_manager.context.processing_task.done(): + if ( + chat_stream.context_manager.context.processing_task + and not chat_stream.context_manager.context.processing_task.done() + ): # 计算打断概率 interruption_probability = chat_stream.context_manager.context.calculate_interruption_probability( global_config.chat.interruption_max_limit, global_config.chat.interruption_probability_factor @@ -310,7 +313,9 @@ class MessageManager: # 增加打断计数并应用afc阈值降低 chat_stream.context_manager.context.increment_interruption_count() - chat_stream.context_manager.context.apply_interruption_afc_reduction(global_config.chat.interruption_afc_reduction) + chat_stream.context_manager.context.apply_interruption_afc_reduction( + global_config.chat.interruption_afc_reduction + ) # 检查是否已达到最大次数 if chat_stream.context_manager.context.interruption_count >= global_config.chat.interruption_max_limit: @@ -364,7 +369,7 @@ class MessageManager: return context = chat_stream.context_manager.context - if hasattr(context, 'unread_messages') and context.unread_messages: + if hasattr(context, "unread_messages") and context.unread_messages: logger.debug(f"正在为流 {stream_id} 清除 {len(context.unread_messages)} 条未读消息") context.unread_messages.clear() else: diff --git a/src/chat/message_manager/sleep_manager/notification_sender.py b/src/chat/message_manager/sleep_manager/notification_sender.py index 07e8b09d4..d40cd612d 100644 --- a/src/chat/message_manager/sleep_manager/notification_sender.py +++ b/src/chat/message_manager/sleep_manager/notification_sender.py @@ -1,33 +1,33 @@ from src.common.logger import get_logger -#from ..hfc_context import HfcContext +# from ..hfc_context import HfcContext logger = get_logger("notification_sender") class NotificationSender: @staticmethod - async def send_goodnight_notification(context): # type: ignore + async def send_goodnight_notification(context): # type: ignore """发送晚安通知""" - #try: - #from ..proactive.events import ProactiveTriggerEvent - #from ..proactive.proactive_thinker import ProactiveThinker - - #event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight") - #proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor) - #await proactive_thinker.think(event) - #except Exception as e: - #logger.error(f"发送晚安通知失败: {e}") + # try: + # from ..proactive.events import ProactiveTriggerEvent + # from ..proactive.proactive_thinker import ProactiveThinker + + # event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight") + # proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor) + # await proactive_thinker.think(event) + # except Exception as e: + # logger.error(f"发送晚安通知失败: {e}") @staticmethod - async def send_insomnia_notification(context, reason: str): # type: ignore + async def send_insomnia_notification(context, reason: str): # type: ignore """发送失眠通知""" - #try: - #from ..proactive.events import ProactiveTriggerEvent - #from ..proactive.proactive_thinker import ProactiveThinker + # try: + # from ..proactive.events import ProactiveTriggerEvent + # from ..proactive.proactive_thinker import ProactiveThinker - #event = ProactiveTriggerEvent(source="sleep_manager", reason=reason) - #proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor) - #await proactive_thinker.think(event) - #except Exception as e: - #logger.error(f"发送失眠通知失败: {e}") \ No newline at end of file + # event = ProactiveTriggerEvent(source="sleep_manager", reason=reason) + # proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor) + # await proactive_thinker.think(event) + # except Exception as e: + # logger.error(f"发送失眠通知失败: {e}") diff --git a/src/chat/message_manager/sleep_manager/sleep_manager.py b/src/chat/message_manager/sleep_manager/sleep_manager.py index 0ed21e685..b0cf79b1b 100644 --- a/src/chat/message_manager/sleep_manager/sleep_manager.py +++ b/src/chat/message_manager/sleep_manager/sleep_manager.py @@ -1,6 +1,6 @@ import asyncio import random -from datetime import datetime, timedelta, date +from datetime import datetime, timedelta from typing import Optional, TYPE_CHECKING from src.common.logger import get_logger @@ -21,6 +21,7 @@ class SleepManager: 它实现了一个状态机,根据预设的时间表、睡眠压力和随机因素, 在不同的睡眠状态(如清醒、准备入睡、睡眠、失眠)之间进行切换。 """ + def __init__(self): """ 初始化睡眠管理器。 @@ -97,7 +98,7 @@ class SleepManager: logger.info(f"进入理论休眠时间 '{activity}',开始进行睡眠决策...") else: logger.info("进入理论休眠时间,开始进行睡眠决策...") - + if global_config.sleep_system.enable_flexible_sleep: # --- 新的弹性睡眠逻辑 --- if wakeup_manager: @@ -112,7 +113,7 @@ class SleepManager: pressure_diff = (pressure_threshold - sleep_pressure) / pressure_threshold # 延迟分钟数,压力越低,延迟越长 delay_minutes = int(pressure_diff * max_delay_minutes) - + # 确保总延迟不超过当日最大值 remaining_delay = max_delay_minutes - self.context.total_delayed_minutes_today delay_minutes = min(delay_minutes, remaining_delay) @@ -151,9 +152,10 @@ class SleepManager: if wakeup_manager and global_config.sleep_system.enable_pre_sleep_notification: asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context)) self.context.current_state = SleepState.SLEEPING - - def _handle_preparing_sleep(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]): + def _handle_preparing_sleep( + self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"] + ): """处理“准备入睡”状态下的逻辑。""" # 如果在准备期间离开了理论睡眠时间,则取消入睡 if not is_in_theoretical_sleep: @@ -166,16 +168,22 @@ class SleepManager: logger.info("睡眠缓冲期结束,正式进入休眠状态。") self.context.current_state = SleepState.SLEEPING self._last_fully_slept_log_time = now.timestamp() - + # 设置一个随机的延迟,用于触发“睡后失眠”检查 delay_minutes_range = global_config.sleep_system.insomnia_trigger_delay_minutes delay_minutes = random.randint(delay_minutes_range[0], delay_minutes_range[1]) self.context.sleep_buffer_end_time = now + timedelta(minutes=delay_minutes) logger.info(f"已设置睡后失眠检查,将在 {delay_minutes} 分钟后触发。") - + self.context.save() - def _handle_sleeping(self, now: datetime, is_in_theoretical_sleep: bool, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]): + def _handle_sleeping( + self, + now: datetime, + is_in_theoretical_sleep: bool, + activity: Optional[str], + wakeup_manager: Optional["WakeUpManager"], + ): """处理“正在睡觉”状态下的逻辑。""" # 如果理论睡眠时间结束,则自然醒来 if not is_in_theoretical_sleep: @@ -198,14 +206,16 @@ class SleepManager: if insomnia_reason: self.context.current_state = SleepState.INSOMNIA - + # 设置失眠的持续时间 duration_minutes_range = global_config.sleep_system.insomnia_duration_minutes duration_minutes = random.randint(*duration_minutes_range) self.context.sleep_buffer_end_time = now + timedelta(minutes=duration_minutes) - + # 发送失眠通知 - asyncio.create_task(NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason)) + asyncio.create_task( + NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason) + ) logger.info(f"进入失眠状态 (原因: {insomnia_reason}),将持续 {duration_minutes} 分钟。") else: # 睡眠压力正常,不触发失眠,清除检查时间点 diff --git a/src/chat/message_manager/sleep_manager/sleep_state.py b/src/chat/message_manager/sleep_manager/sleep_state.py index d59f1f3d6..105302169 100644 --- a/src/chat/message_manager/sleep_manager/sleep_state.py +++ b/src/chat/message_manager/sleep_manager/sleep_state.py @@ -25,6 +25,7 @@ class SleepContext: """ 睡眠上下文,负责封装和管理所有与睡眠相关的状态,并处理其持久化。 """ + def __init__(self): """初始化睡眠上下文,并从本地存储加载初始状态。""" self.current_state: SleepState = SleepState.AWAKE @@ -83,4 +84,4 @@ class SleepContext: logger.info(f"成功从本地存储加载睡眠上下文: {state}") except Exception as e: - logger.warning(f"加载睡眠上下文失败,将使用默认值: {e}") \ No newline at end of file + logger.warning(f"加载睡眠上下文失败,将使用默认值: {e}") diff --git a/src/chat/message_manager/sleep_manager/time_checker.py b/src/chat/message_manager/sleep_manager/time_checker.py index 47376ac35..773830c3a 100644 --- a/src/chat/message_manager/sleep_manager/time_checker.py +++ b/src/chat/message_manager/sleep_manager/time_checker.py @@ -15,23 +15,25 @@ class TimeChecker: self._daily_sleep_offset: int = 0 self._daily_wake_offset: int = 0 self._offset_date = None - + def _get_daily_offsets(self): """获取当天的睡眠和起床时间偏移量,每天生成一次""" today = datetime.now().date() - + # 如果是新的一天,重新生成偏移量 if self._offset_date != today: sleep_offset_range = global_config.sleep_system.sleep_time_offset_minutes wake_offset_range = global_config.sleep_system.wake_up_time_offset_minutes - + # 生成 ±offset_range 范围内的随机偏移量 self._daily_sleep_offset = random.randint(-sleep_offset_range, sleep_offset_range) self._daily_wake_offset = random.randint(-wake_offset_range, wake_offset_range) self._offset_date = today - - logger.debug(f"生成新的每日偏移量 - 睡觉时间偏移: {self._daily_sleep_offset}分钟, 起床时间偏移: {self._daily_wake_offset}分钟") - + + logger.debug( + f"生成新的每日偏移量 - 睡觉时间偏移: {self._daily_sleep_offset}分钟, 起床时间偏移: {self._daily_wake_offset}分钟" + ) + return self._daily_sleep_offset, self._daily_wake_offset @staticmethod @@ -82,28 +84,36 @@ class TimeChecker: try: start_time_str = global_config.sleep_system.fixed_sleep_time end_time_str = global_config.sleep_system.fixed_wake_up_time - + # 获取当天的偏移量 sleep_offset, wake_offset = self._get_daily_offsets() - + # 解析基础时间 base_start_time = datetime.strptime(start_time_str, "%H:%M") base_end_time = datetime.strptime(end_time_str, "%H:%M") - + # 应用偏移量 actual_start_time = (base_start_time + timedelta(minutes=sleep_offset)).time() actual_end_time = (base_end_time + timedelta(minutes=wake_offset)).time() - - logger.debug(f"固定睡眠时间检查 - 基础时间: {start_time_str}-{end_time_str}, " - f"偏移后时间: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')}, " - f"当前时间: {now_time.strftime('%H:%M')}") + + logger.debug( + f"固定睡眠时间检查 - 基础时间: {start_time_str}-{end_time_str}, " + f"偏移后时间: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')}, " + f"当前时间: {now_time.strftime('%H:%M')}" + ) if actual_start_time <= actual_end_time: if actual_start_time <= now_time < actual_end_time: - return True, f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})" + return ( + True, + f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})", + ) else: if now_time >= actual_start_time or now_time < actual_end_time: - return True, f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})" + return ( + True, + f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})", + ) except ValueError as e: logger.error(f"固定的睡眠时间格式不正确,请使用 HH:MM 格式: {e}") - return False, None \ No newline at end of file + return False, None diff --git a/src/chat/message_manager/sleep_manager/wakeup_context.py b/src/chat/message_manager/sleep_manager/wakeup_context.py index bfa1a62dd..efb59e392 100644 --- a/src/chat/message_manager/sleep_manager/wakeup_context.py +++ b/src/chat/message_manager/sleep_manager/wakeup_context.py @@ -1,4 +1,3 @@ -import time from src.common.logger import get_logger from src.manager.local_store_manager import local_storage @@ -9,6 +8,7 @@ class WakeUpContext: """ 唤醒上下文,负责封装和管理所有与唤醒相关的状态,并处理其持久化。 """ + def __init__(self): """初始化唤醒上下文,并从本地存储加载初始状态。""" self.wakeup_value: float = 0.0 @@ -42,4 +42,4 @@ class WakeUpContext: "sleep_pressure": self.sleep_pressure, } local_storage[self._get_storage_key()] = state - logger.debug(f"已将唤醒上下文保存到本地存储: {state}") \ No newline at end of file + logger.debug(f"已将唤醒上下文保存到本地存储: {state}") diff --git a/src/chat/message_manager/sleep_manager/wakeup_manager.py b/src/chat/message_manager/sleep_manager/wakeup_manager.py index 51ab80bb1..5fc68ff41 100644 --- a/src/chat/message_manager/sleep_manager/wakeup_manager.py +++ b/src/chat/message_manager/sleep_manager/wakeup_manager.py @@ -3,7 +3,6 @@ import time from typing import Optional, TYPE_CHECKING from src.common.logger import get_logger from src.config.config import global_config -from src.manager.local_store_manager import local_storage from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext if TYPE_CHECKING: @@ -51,7 +50,7 @@ class WakeUpManager: if not self.enabled: logger.info("唤醒度系统已禁用,跳过启动") return - + self.is_running = True if not self._decay_task or self._decay_task.done(): self._decay_task = asyncio.create_task(self._decay_loop()) @@ -88,6 +87,7 @@ class WakeUpManager: self.context.is_angry = False # 通知情绪管理系统清除愤怒状态 from src.mood.mood_manager import mood_manager + if self.angry_chat_id: mood_manager.clear_angry_from_wakeup(self.angry_chat_id) self.angry_chat_id = None @@ -104,7 +104,9 @@ class WakeUpManager: logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}") self.context.save() - def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None) -> bool: + def add_wakeup_value( + self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None + ) -> bool: """ 增加唤醒度值 @@ -173,6 +175,7 @@ class WakeUpManager: # 通知情绪管理系统进入愤怒状态 from src.mood.mood_manager import mood_manager + mood_manager.set_angry_from_wakeup(chat_id) # 通知SleepManager重置睡眠状态 @@ -194,6 +197,7 @@ class WakeUpManager: self.context.is_angry = False # 通知情绪管理系统清除愤怒状态 from src.mood.mood_manager import mood_manager + if self.angry_chat_id: mood_manager.clear_angry_from_wakeup(self.angry_chat_id) self.angry_chat_id = None diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index d30600b70..47d1f26e2 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -190,7 +190,7 @@ class ChatBot: try: # 检查聊天类型限制 if not plus_command_instance.is_chat_type_allowed(): - is_group = message.message_info.group_info + is_group = message.message_info.group_info logger.info( f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" ) @@ -420,7 +420,9 @@ class ChatBot: await message.process() # 在这里打印[所见]日志,确保在所有处理和过滤之前记录 - logger.info(f"\u001b[38;5;118m{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m") + logger.info( + f"\u001b[38;5;118m{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m" + ) # 过滤检查 if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore @@ -452,7 +454,7 @@ class ChatBot: result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message) if not result.all_continue_process(): raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理") - + # TODO:暂不可用 # 确认从接口发来的message是否有自定义的prompt模板信息 if message.message_info.template_info and not message.message_info.template_info.template_default: @@ -469,14 +471,14 @@ class ChatBot: async def preprocess(): # 存储消息到数据库 from .storage import MessageStorage - + try: await MessageStorage.store_message(message, message.chat_stream) logger.debug(f"消息已存储到数据库: {message.message_info.message_id}") except Exception as e: logger.error(f"存储消息到数据库失败: {e}") traceback.print_exc() - + # 使用消息管理器处理消息(保持原有功能) from src.common.data_models.database_data_model import DatabaseMessages diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index d31daeaee..559490694 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -89,7 +89,7 @@ class ChatStream: # 复制 stream_context,但跳过 processing_task new_stream.stream_context = copy.deepcopy(self.stream_context, memo) - if hasattr(new_stream.stream_context, 'processing_task'): + if hasattr(new_stream.stream_context, "processing_task"): new_stream.stream_context.processing_task = None # 复制 context_manager @@ -377,6 +377,7 @@ class ChatStream: # 默认基础分 return 0.3 + class ChatManager: """聊天管理器,管理所有聊天流""" @@ -563,9 +564,8 @@ class ChatManager: if not hasattr(stream, "context_manager"): # 创建新的单流上下文管理器 from src.chat.message_manager.context_manager import SingleStreamContextManager - stream.context_manager = SingleStreamContextManager( - stream_id=stream_id, context=stream.stream_context - ) + + stream.context_manager = SingleStreamContextManager(stream_id=stream_id, context=stream.stream_context) # 保存到内存和数据库 self.streams[stream_id] = stream @@ -721,6 +721,7 @@ class ChatManager: # 确保 ChatStream 有自己的 context_manager if not hasattr(stream, "context_manager"): from src.chat.message_manager.context_manager import SingleStreamContextManager + stream.context_manager = SingleStreamContextManager( stream_id=stream.stream_id, context=stream.stream_context ) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index f6041b7d4..fee932b62 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -108,7 +108,7 @@ class MessageRecv(Message): self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) self.raw_message = message_dict.get("raw_message") - + self.chat_stream = None self.reply = None self.processed_plain_text = message_dict.get("processed_plain_text", "") diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index a6aaa00eb..5a654e867 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -52,7 +52,11 @@ class MessageStorage: filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) else: # 如果没有设置display_message,使用processed_plain_text作为显示消息 - filtered_display_message = re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL) if message.processed_plain_text else "" + filtered_display_message = ( + re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL) + if message.processed_plain_text + else "" + ) interest_value = 0 is_mentioned = False reply_to = message.reply_to @@ -175,9 +179,11 @@ class MessageStorage: from src.common.database.sqlalchemy_models import get_db_session async with get_db_session() as session: - matched_message = (await session.execute( - select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) - )).scalar() + matched_message = ( + await session.execute( + select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) + ) + ).scalar() if matched_message: await session.execute( @@ -211,9 +217,11 @@ class MessageStorage: from src.common.database.sqlalchemy_models import get_db_session async with get_db_session() as session: - image_record = (await session.execute( - select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) - )).scalar() + image_record = ( + await session.execute( + select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) + ) + ).scalar() return f"[picid:{image_record.image_id}]" if image_record else match.group(0) except Exception: return match.group(0) @@ -294,15 +302,19 @@ class MessageStorage: from src.common.database.sqlalchemy_models import Messages # 查找需要修复的记录:interest_value为0、null或很小的值 - query = select(Messages).where( - (Messages.chat_id == chat_id) & - (Messages.time >= since_time) & - ( - (Messages.interest_value == 0) | - (Messages.interest_value.is_(None)) | - (Messages.interest_value < 0.1) + query = ( + select(Messages) + .where( + (Messages.chat_id == chat_id) + & (Messages.time >= since_time) + & ( + (Messages.interest_value == 0) + | (Messages.interest_value.is_(None)) + | (Messages.interest_value < 0.1) + ) ) - ).limit(50) # 限制每次修复的数量,避免性能问题 + .limit(50) + ) # 限制每次修复的数量,避免性能问题 result = await session.execute(query) messages_to_fix = result.scalars().all() @@ -314,7 +326,7 @@ class MessageStorage: default_interest = 0.3 # 默认中等兴趣度 # 如果消息内容较长,可能是重要消息,兴趣度稍高 - if hasattr(msg, 'processed_plain_text') and msg.processed_plain_text: + if hasattr(msg, "processed_plain_text") and msg.processed_plain_text: text_length = len(msg.processed_plain_text) if text_length > 50: # 长消息 default_interest = 0.4 @@ -322,13 +334,15 @@ class MessageStorage: default_interest = 0.35 # 如果是被@的消息,兴趣度更高 - if getattr(msg, 'is_mentioned', False): + if getattr(msg, "is_mentioned", False): default_interest = min(default_interest + 0.2, 0.8) # 执行更新 - update_stmt = update(Messages).where( - Messages.message_id == msg.message_id - ).values(interest_value=default_interest) + update_stmt = ( + update(Messages) + .where(Messages.message_id == msg.message_id) + .values(interest_value=default_interest) + ) result = await session.execute(update_stmt) if result.rowcount > 0: diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index f92956bc2..21a00ee52 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -40,7 +40,7 @@ class ChatterActionManager: @staticmethod def create_action( - action_name: str, + action_name: str, action_data: dict, reasoning: str, cycle_timers: dict, @@ -162,7 +162,7 @@ class ChatterActionManager: Returns: 执行结果 """ - from src.chat.message_manager.message_manager import message_manager + try: logger.debug(f"🎯 [ActionManager] execute_action接收到 target_message: {target_message}") # 通过chat_id获取chat_stream @@ -309,9 +309,7 @@ class ChatterActionManager: # 通过message_manager更新消息的动作记录并刷新focus_energy await message_manager.add_action( - stream_id=chat_stream.stream_id, - message_id=target_message_id, - action=action_name + stream_id=chat_stream.stream_id, message_id=target_message_id, action=action_name ) logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy") @@ -321,9 +319,10 @@ class ChatterActionManager: async def _reset_interruption_count_after_action(self, stream_id: str): """在动作执行成功后重置打断计数""" - from src.chat.message_manager.message_manager import message_manager + try: from src.plugin_system.apis.chat_api import get_chat_manager + chat_manager = get_chat_manager() chat_stream = chat_manager.get_stream(stream_id) if chat_stream: @@ -332,7 +331,9 @@ class ChatterActionManager: old_count = context.context.interruption_count old_afc_adjustment = context.context.get_afc_threshold_adjustment() context.context.reset_interruption_count() - logger.debug(f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0") + logger.debug( + f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0" + ) except Exception as e: logger.warning(f"重置打断计数时出错: {e}") @@ -531,7 +532,7 @@ class ChatterActionManager: # 根据新消息数量决定是否需要引用回复 reply_text = "" is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True - + logger.debug(f"[send_response] message_data: {message_data}") first_replied = False @@ -558,7 +559,9 @@ class ChatterActionManager: # 发送第一段回复 if not first_replied: set_reply_flag = bool(message_data) - logger.debug(f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}") + logger.debug( + f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}" + ) await send_api.text_to_stream( text=data, stream_id=chat_stream.stream_id, @@ -577,4 +580,4 @@ class ChatterActionManager: typing=True, ) - return reply_text \ No newline at end of file + return reply_text diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index ecd57639b..063fc1bf1 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -29,6 +29,7 @@ from src.chat.utils.chat_message_builder import ( replace_user_references_sync, ) from src.chat.express.expression_selector import expression_selector + # 旧记忆系统已被移除 # 旧记忆系统已被移除 from src.mood.mood_manager import mood_manager @@ -562,7 +563,9 @@ class DefaultReplyer: memory_context["user_aliases"] = memory_aliases if group_info_obj is not None: - group_name = getattr(group_info_obj, "group_name", None) or getattr(group_info_obj, "group_nickname", None) + group_name = getattr(group_info_obj, "group_name", None) or getattr( + group_info_obj, "group_nickname", None + ) if group_name: memory_context["group_name"] = str(group_name) group_id = getattr(group_info_obj, "group_id", None) @@ -576,11 +579,7 @@ class DefaultReplyer: # 检索相关记忆 enhanced_memories = await memory_system.retrieve_relevant_memories( - query=target, - user_id=memory_user_id, - scope_id=stream.stream_id, - context=memory_context, - limit=10 + query=target, user_id=memory_user_id, scope_id=stream.stream_id, context=memory_context, limit=10 ) # 注意:记忆存储已迁移到回复生成完成后进行,不在查询阶段执行 @@ -591,23 +590,27 @@ class DefaultReplyer: logger.debug(f"[记忆转换] 收到 {len(enhanced_memories)} 条原始记忆") for idx, memory_chunk in enumerate(enhanced_memories, 1): # 获取结构化内容的字符串表示 - structure_display = str(memory_chunk.content) if hasattr(memory_chunk, 'content') else "unknown" - + structure_display = str(memory_chunk.content) if hasattr(memory_chunk, "content") else "unknown" + # 获取记忆内容,优先使用display content = memory_chunk.display or memory_chunk.text_content or "" - + # 调试:记录每条记忆的内容获取情况 - logger.debug(f"[记忆转换] 第{idx}条: display={repr(memory_chunk.display)[:80]}, text_content={repr(memory_chunk.text_content)[:80]}, final_content={repr(content)[:80]}") - - running_memories.append({ - "content": content, - "memory_type": memory_chunk.memory_type.value, - "confidence": memory_chunk.metadata.confidence.value, - "importance": memory_chunk.metadata.importance.value, - "relevance": getattr(memory_chunk.metadata, 'relevance_score', 0.5), - "source": memory_chunk.metadata.source, - "structure": structure_display, - }) + logger.debug( + f"[记忆转换] 第{idx}条: display={repr(memory_chunk.display)[:80]}, text_content={repr(memory_chunk.text_content)[:80]}, final_content={repr(content)[:80]}" + ) + + running_memories.append( + { + "content": content, + "memory_type": memory_chunk.memory_type.value, + "confidence": memory_chunk.metadata.confidence.value, + "importance": memory_chunk.metadata.importance.value, + "relevance": getattr(memory_chunk.metadata, "relevance_score", 0.5), + "source": memory_chunk.metadata.source, + "structure": structure_display, + } + ) # 构建瞬时记忆字符串 if running_memories: @@ -615,7 +618,9 @@ class DefaultReplyer: if top_memory: instant_memory = top_memory[0].get("content", "") - logger.info(f"增强记忆系统检索到 {len(enhanced_memories)} 条原始记忆,转换为 {len(running_memories)} 条可用记忆") + logger.info( + f"增强记忆系统检索到 {len(enhanced_memories)} 条原始记忆,转换为 {len(running_memories)} 条可用记忆" + ) except Exception as e: logger.warning(f"增强记忆系统检索失败: {e}") @@ -632,17 +637,17 @@ class DefaultReplyer: memory_parts = ["### 🧠 相关记忆 (Relevant Memories)", ""] # 按相关度排序,并记录相关度信息用于调试 - sorted_memories = sorted(running_memories, key=lambda x: x.get('relevance', 0.0), reverse=True) + sorted_memories = sorted(running_memories, key=lambda x: x.get("relevance", 0.0), reverse=True) # 调试相关度信息 - relevance_info = [(m.get('memory_type', 'unknown'), m.get('relevance', 0.0)) for m in sorted_memories] + relevance_info = [(m.get("memory_type", "unknown"), m.get("relevance", 0.0)) for m in sorted_memories] logger.debug(f"记忆相关度信息: {relevance_info}") logger.debug(f"[记忆构建] 准备将 {len(sorted_memories)} 条记忆添加到提示词") for idx, running_memory in enumerate(sorted_memories, 1): - content = running_memory.get('content', '') - memory_type = running_memory.get('memory_type', 'unknown') - + content = running_memory.get("content", "") + memory_type = running_memory.get("memory_type", "unknown") + # 跳过空内容 if not content or not content.strip(): logger.warning(f"[记忆构建] 跳过第 {idx} 条记忆:内容为空 (type={memory_type})") @@ -801,10 +806,10 @@ class DefaultReplyer: """ try: # 从message_manager获取真实的已读/未读消息 - from src.chat.message_manager.message_manager import message_manager # 获取聊天流的上下文 from src.plugin_system.apis.chat_api import get_chat_manager + chat_manager = get_chat_manager() chat_stream = chat_manager.get_stream(chat_id) if chat_stream: @@ -1000,7 +1005,9 @@ class DefaultReplyer: interest_scores = {} try: - from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system as interest_scoring_system + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ( + chatter_interest_scoring_system as interest_scoring_system, + ) from src.common.data_models.database_data_model import DatabaseMessages # 转换消息格式 @@ -1145,7 +1152,7 @@ class DefaultReplyer: platform, # type: ignore reply_message.get("user_id"), # type: ignore reply_message.get("user_nickname"), - reply_message.get("user_cardname") + reply_message.get("user_cardname"), ) # 检查是否是bot自己的名字,如果是则替换为"(你)" @@ -1657,6 +1664,7 @@ class DefaultReplyer: # 创建关系追踪器实例 from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system + relationship_tracker = ChatterRelationshipTracker(chatter_interest_scoring_system) if relationship_tracker: # 获取用户信息以获取真实的user_id @@ -1699,7 +1707,7 @@ class DefaultReplyer: async def _store_chat_memory_async(self, reply_to: str, reply_message: Optional[Dict[str, Any]] = None): """ 异步存储聊天记忆(从build_memory_block迁移而来) - + Args: reply_to: 回复对象 reply_message: 回复的原始消息 @@ -1768,9 +1776,7 @@ class DefaultReplyer: memory_aliases.append(stripped) alias_values = ( - user_info_dict.get("aliases") - or user_info_dict.get("alias_names") - or user_info_dict.get("alias") + user_info_dict.get("aliases") or user_info_dict.get("alias_names") or user_info_dict.get("alias") ) if isinstance(alias_values, (list, tuple, set)): for alias in alias_values: @@ -1794,7 +1800,9 @@ class DefaultReplyer: memory_context["user_aliases"] = memory_aliases if group_info_obj is not None: - group_name = getattr(group_info_obj, "group_name", None) or getattr(group_info_obj, "group_nickname", None) + group_name = getattr(group_info_obj, "group_name", None) or getattr( + group_info_obj, "group_nickname", None + ) if group_name: memory_context["group_name"] = str(group_name) group_id = getattr(group_info_obj, "group_id", None) @@ -1826,11 +1834,11 @@ class DefaultReplyer: "conversation_text": chat_history, "user_id": memory_user_id, "scope_id": stream.stream_id, - **memory_context + **memory_context, } ) ) - + logger.debug(f"已启动记忆存储任务,用户: {memory_user_display or memory_user_id}") except Exception as e: diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 277ba3d23..8503e369a 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -13,6 +13,7 @@ from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_m from src.common.database.sqlalchemy_database_api import get_db_session from sqlalchemy import select, and_ from src.common.logger import get_logger + logger = get_logger("chat_message_builder") install(extra_lines=3) @@ -277,21 +278,52 @@ async def get_actions_by_timestamp_with_chat( async with get_db_session() as session: if limit > 0: + result = await session.execute( + select(ActionRecords) + .where( + and_( + ActionRecords.chat_id == chat_id, + ActionRecords.time >= timestamp_start, + ActionRecords.time <= timestamp_end, + ) + ) + .order_by(ActionRecords.time.desc()) + .limit(limit) + ) + actions = list(result.scalars()) + actions_result = [] + for action in reversed(actions): + action_dict = { + "id": action.id, + "action_id": action.action_id, + "time": action.time, + "action_name": action.action_name, + "action_data": action.action_data, + "action_done": action.action_done, + "action_build_into_prompt": action.action_build_into_prompt, + "action_prompt_display": action.action_prompt_display, + "chat_id": action.chat_id, + "chat_info_stream_id": action.chat_info_stream_id, + "chat_info_platform": action.chat_info_platform, + } + actions_result.append(action_dict) + actions_result.append(action_dict) + else: # earliest result = await session.execute( select(ActionRecords) .where( and_( ActionRecords.chat_id == chat_id, - ActionRecords.time >= timestamp_start, - ActionRecords.time <= timestamp_end, + ActionRecords.time > timestamp_start, + ActionRecords.time < timestamp_end, ) ) - .order_by(ActionRecords.time.desc()) + .order_by(ActionRecords.time.asc()) .limit(limit) ) actions = list(result.scalars()) actions_result = [] - for action in reversed(actions): + for action in actions: action_dict = { "id": action.id, "action_id": action.action_id, @@ -306,37 +338,6 @@ async def get_actions_by_timestamp_with_chat( "chat_info_platform": action.chat_info_platform, } actions_result.append(action_dict) - actions_result.append(action_dict) - else: # earliest - result = await session.execute( - select(ActionRecords) - .where( - and_( - ActionRecords.chat_id == chat_id, - ActionRecords.time > timestamp_start, - ActionRecords.time < timestamp_end, - ) - ) - .order_by(ActionRecords.time.asc()) - .limit(limit) - ) - actions = list(result.scalars()) - actions_result = [] - for action in actions: - action_dict = { - "id": action.id, - "action_id": action.action_id, - "time": action.time, - "action_name": action.action_name, - "action_data": action.action_data, - "action_done": action.action_done, - "action_build_into_prompt": action.action_build_into_prompt, - "action_prompt_display": action.action_prompt_display, - "chat_id": action.chat_id, - "chat_info_stream_id": action.chat_info_stream_id, - "chat_info_platform": action.chat_info_platform, - } - actions_result.append(action_dict) else: result = await session.execute( select(ActionRecords) @@ -460,7 +461,9 @@ async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp_with_chat( + chat_id: str, timestamp: float, limit: int = 0 +) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ @@ -469,7 +472,9 @@ async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -async def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp_with_users( + timestamp: float, person_ids: list, limit: int = 0 +) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ @@ -478,7 +483,9 @@ async def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: +async def num_new_messages_since( + chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None +) -> int: """ 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。 @@ -828,7 +835,7 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: async with get_db_session() as session: result = await session.execute(select(Images).where(Images.image_id == pic_id)) image = result.scalar_one_or_none() - if image and hasattr(image, 'description') and image.description: + if image and hasattr(image, "description") and image.description: description = image.description except Exception as e: # 如果查询失败,保持默认描述 @@ -1015,23 +1022,29 @@ async def build_readable_messages( async with get_db_session() as session: # 获取这个时间范围内的动作记录,并匹配chat_id - actions_in_range = (await session.execute( - select(ActionRecords) - .where( - and_( - ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id + actions_in_range = ( + await session.execute( + select(ActionRecords) + .where( + and_( + ActionRecords.time >= min_time, + ActionRecords.time <= max_time, + ActionRecords.chat_id == chat_id, + ) ) + .order_by(ActionRecords.time) ) - .order_by(ActionRecords.time) - )).scalars() + ).scalars() # 获取最新消息之后的第一个动作记录 - action_after_latest = (await session.execute( - select(ActionRecords) - .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id)) - .order_by(ActionRecords.time) - .limit(1) - )).scalars() + action_after_latest = ( + await session.execute( + select(ActionRecords) + .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id)) + .order_by(ActionRecords.time) + .limit(1) + ) + ).scalars() # 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError actions = [ @@ -1222,9 +1235,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: except Exception: return "?" - content = await replace_user_references_async( - content, platform, anon_name_resolver, replace_bot_name=False - ) + content = await replace_user_references_async(content, platform, anon_name_resolver, replace_bot_name=False) header = f"{anon_name}说 " output_lines.append(header) diff --git a/src/chat/utils/memory_mappings.py b/src/chat/utils/memory_mappings.py index 79ce50ade..4da20fdb5 100644 --- a/src/chat/utils/memory_mappings.py +++ b/src/chat/utils/memory_mappings.py @@ -17,7 +17,7 @@ MEMORY_TYPE_CHINESE_MAPPING = { "goal": "目标计划", "experience": "经验教训", "contextual": "上下文信息", - "unknown": "未知" + "unknown": "未知", } # 置信度等级到中文标签的映射表 @@ -30,7 +30,7 @@ CONFIDENCE_LEVEL_CHINESE_MAPPING = { "MEDIUM": "中等置信度", "HIGH": "高置信度", "VERIFIED": "已验证", - "unknown": "未知" + "unknown": "未知", } # 重要性等级到中文标签的映射表 @@ -43,7 +43,7 @@ IMPORTANCE_LEVEL_CHINESE_MAPPING = { "NORMAL": "一般重要性", "HIGH": "高重要性", "CRITICAL": "关键重要性", - "unknown": "未知" + "unknown": "未知", } @@ -69,7 +69,7 @@ def get_confidence_level_chinese_label(level) -> str: str: 对应的中文标签,如果找不到则返回"未知" """ # 处理枚举实例 - if hasattr(level, 'value'): + if hasattr(level, "value"): level = level.value # 处理数字 @@ -94,7 +94,7 @@ def get_importance_level_chinese_label(level) -> str: str: 对应的中文标签,如果找不到则返回"未知" """ # 处理枚举实例 - if hasattr(level, 'value'): + if hasattr(level, "value"): level = level.value # 处理数字 @@ -106,4 +106,4 @@ def get_importance_level_chinese_label(level) -> str: level_upper = level.upper() return IMPORTANCE_LEVEL_CHINESE_MAPPING.get(level_upper, "未知") - return "未知" \ No newline at end of file + return "未知" diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 3e011bb15..baf77a143 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -377,12 +377,12 @@ class Prompt: # 性能优化 - 为不同任务设置不同的超时时间 task_timeouts = { - "memory_block": 15.0, # 记忆系统 - 降低超时时间,鼓励预构建 - "tool_info": 15.0, # 工具信息 - "relation_info": 10.0, # 关系信息 - "knowledge_info": 10.0, # 知识库查询 - "cross_context": 10.0, # 上下文处理 - "expression_habits": 10.0, # 表达习惯 + "memory_block": 15.0, # 记忆系统 - 降低超时时间,鼓励预构建 + "tool_info": 15.0, # 工具信息 + "relation_info": 10.0, # 关系信息 + "knowledge_info": 10.0, # 知识库查询 + "cross_context": 10.0, # 上下文处理 + "expression_habits": 10.0, # 表达习惯 } # 分别处理每个任务,避免慢任务影响快任务 @@ -559,7 +559,7 @@ class Prompt: ), enhanced_memory_activator.get_instant_memory( target_message=self.parameters.target, chat_id=self.parameters.chat_id - ) + ), ] try: @@ -602,26 +602,27 @@ class Prompt: "opinion": "opinion", "personal_fact": "personal_fact", "preference": "preference", - "event": "event" + "event": "event", } mapped_type = memory_type_mapping.get(topic, "personal_fact") - formatted_memories.append({ - "display": display_text, - "memory_type": mapped_type, - "metadata": { - "confidence": memory.get("confidence", "未知"), - "importance": memory.get("importance", "一般"), - "timestamp": memory.get("timestamp", ""), - "source": memory.get("source", "unknown"), - "relevance_score": memory.get("relevance_score", 0.0) + formatted_memories.append( + { + "display": display_text, + "memory_type": mapped_type, + "metadata": { + "confidence": memory.get("confidence", "未知"), + "importance": memory.get("importance", "一般"), + "timestamp": memory.get("timestamp", ""), + "source": memory.get("source", "unknown"), + "relevance_score": memory.get("relevance_score", 0.0), + }, } - }) + ) # 使用方括号格式格式化记忆 memory_block = format_memories_bracket_style( - formatted_memories, - query_context=self.parameters.target + formatted_memories, query_context=self.parameters.target ) except Exception as e: logger.warning(f"记忆格式化失败,使用简化格式: {e}") @@ -829,7 +830,8 @@ class Prompt: "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), "safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), - "chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", + "chat_scene": self.parameters.chat_scene + or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", } def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: @@ -856,7 +858,8 @@ class Prompt: "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), "safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), - "chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", + "chat_scene": self.parameters.chat_scene + or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", } def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index ed8530387..1c879a01b 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -293,11 +293,14 @@ class StatisticOutputTask(AsyncTask): # 以最早的时间戳为起始时间获取记录 query_start_time = collect_period[-1][1] - records = await db_get( - model_class=LLMUsage, - filters={"timestamp": {"$gte": query_start_time}}, - order_by="-timestamp", - ) or [] + records = ( + await db_get( + model_class=LLMUsage, + filters={"timestamp": {"$gte": query_start_time}}, + order_by="-timestamp", + ) + or [] + ) for record in records: if not isinstance(record, dict): @@ -389,7 +392,9 @@ class StatisticOutputTask(AsyncTask): return stats @staticmethod - async def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]: + async def _collect_online_time_for_period( + collect_period: List[Tuple[str, datetime]], now: datetime + ) -> Dict[str, Any]: """ 收集指定时间段的在线时间统计数据 @@ -408,11 +413,14 @@ class StatisticOutputTask(AsyncTask): } query_start_time = collect_period[-1][1] - records = await db_get( - model_class=OnlineTime, - filters={"end_timestamp": {"$gte": query_start_time}}, - order_by="-end_timestamp", - ) or [] + records = ( + await db_get( + model_class=OnlineTime, + filters={"end_timestamp": {"$gte": query_start_time}}, + order_by="-end_timestamp", + ) + or [] + ) for record in records: if not isinstance(record, dict): @@ -464,11 +472,14 @@ class StatisticOutputTask(AsyncTask): } query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) - records = await db_get( - model_class=Messages, - filters={"time": {"$gte": query_start_timestamp}}, - order_by="-time", - ) or [] + records = ( + await db_get( + model_class=Messages, + filters={"time": {"$gte": query_start_timestamp}}, + order_by="-time", + ) + or [] + ) for message in records: if not isinstance(message, dict): @@ -1026,11 +1037,14 @@ class StatisticOutputTask(AsyncTask): interval_seconds = interval_minutes * 60 # 单次查询 LLMUsage - llm_records = await db_get( - model_class=LLMUsage, - filters={"timestamp": {"$gte": start_time}}, - order_by="-timestamp", - ) or [] + llm_records = ( + await db_get( + model_class=LLMUsage, + filters={"timestamp": {"$gte": start_time}}, + order_by="-timestamp", + ) + or [] + ) for record in llm_records: if not isinstance(record, dict) or not record.get("timestamp"): continue @@ -1056,11 +1070,14 @@ class StatisticOutputTask(AsyncTask): cost_by_module[module_name][idx] += cost # 单次查询 Messages - msg_records = await db_get( - model_class=Messages, - filters={"time": {"$gte": start_time.timestamp()}}, - order_by="-time", - ) or [] + msg_records = ( + await db_get( + model_class=Messages, + filters={"time": {"$gte": start_time.timestamp()}}, + order_by="-time", + ) + or [] + ) for msg in msg_records: if not isinstance(msg, dict) or not msg.get("time"): continue @@ -1363,4 +1380,4 @@ class StatisticOutputTask(AsyncTask): }}); - """ \ No newline at end of file + """ diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 989428677..ea3bdc89f 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -670,7 +670,6 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: if loop.is_running(): # 如果事件循环在运行,从其他线程提交并等待结果 try: - from concurrent.futures import TimeoutError fut = asyncio.run_coroutine_threadsafe( person_info_manager.get_value(person_id, "person_name"), loop diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 948b6f237..ab0915842 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -81,14 +81,16 @@ class ImageManager: """ try: async with get_db_session() as session: - record = (await session.execute( - select(ImageDescriptions).where( - and_( - ImageDescriptions.image_description_hash == image_hash, - ImageDescriptions.type == description_type, + record = ( + await session.execute( + select(ImageDescriptions).where( + and_( + ImageDescriptions.image_description_hash == image_hash, + ImageDescriptions.type == description_type, + ) ) ) - )).scalar() + ).scalar() return record.description if record else None except Exception as e: logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}") @@ -107,14 +109,16 @@ class ImageManager: current_timestamp = time.time() async with get_db_session() as session: # 查找现有记录 - existing = (await session.execute( - select(ImageDescriptions).where( - and_( - ImageDescriptions.image_description_hash == image_hash, - ImageDescriptions.type == description_type, + existing = ( + await session.execute( + select(ImageDescriptions).where( + and_( + ImageDescriptions.image_description_hash == image_hash, + ImageDescriptions.type == description_type, + ) ) ) - )).scalar() + ).scalar() if existing: # 更新现有记录 @@ -262,9 +266,11 @@ class ImageManager: from src.common.database.sqlalchemy_models import get_db_session async with get_db_session() as session: - existing_img = (await session.execute( - select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji")) - )).scalar() + existing_img = ( + await session.execute( + select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji")) + ) + ).scalar() if existing_img: existing_img.path = file_path diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index bbe489336..19ec72cb6 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -29,7 +29,7 @@ logger = get_logger("utils_video") # Rust模块可用性检测 RUST_VIDEO_AVAILABLE = False try: - import rust_video # pyright: ignore[reportMissingImports] + import rust_video # pyright: ignore[reportMissingImports] RUST_VIDEO_AVAILABLE = True logger.info("✅ Rust 视频处理模块加载成功") @@ -220,7 +220,7 @@ class VideoAnalyzer: return None async def _store_video_result( - self, video_hash: str, description: str, metadata: Optional[Dict] = None + self, video_hash: str, description: str, metadata: Optional[Dict] = None ) -> Optional[Videos]: """存储视频分析结果到数据库""" # 检查描述是否为错误信息,如果是则不保存 diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index 6ebe5c3d8..c77d9e8bd 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -77,7 +77,10 @@ class CacheManager: embedding_array = embedding_array.flatten() # 检查维度是否符合预期 - expected_dim = getattr(CacheManager, "embedding_dimension", None) or global_config.lpmm_knowledge.embedding_dimension + expected_dim = ( + getattr(CacheManager, "embedding_dimension", None) + or global_config.lpmm_knowledge.embedding_dimension + ) if embedding_array.shape[0] != expected_dim: logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}") return None diff --git a/src/common/database/db_migration.py b/src/common/database/db_migration.py index 085c277a3..2ab7ba13e 100644 --- a/src/common/database/db_migration.py +++ b/src/common/database/db_migration.py @@ -82,22 +82,26 @@ async def check_and_migrate_database(): if dialect.name == "sqlite" and isinstance(default_arg, bool): # SQLite 将布尔值存储为 0 或 1 default_value = "1" if default_arg else "0" - elif hasattr(compiler, 'render_literal_value'): + elif hasattr(compiler, "render_literal_value"): try: # 尝试使用 render_literal_value default_value = compiler.render_literal_value(default_arg, column.type) except AttributeError: # 如果失败,则回退到简单的字符串转换 - default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) + default_value = ( + f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) + ) else: # 对于没有 render_literal_value 的旧版或特定方言 - default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) - + default_value = ( + f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) + ) + sql += f" DEFAULT {default_value}" if not column.nullable: sql += " NOT NULL" - + conn.execute(text(sql)) logger.info(f"成功向表 '{table_name}' 添加列 '{column_name}'。") @@ -131,4 +135,3 @@ async def check_and_migrate_database(): continue logger.info("数据库结构检查与自动迁移完成。") - diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 210882dc8..330846983 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -423,4 +423,4 @@ async def store_action_info( except Exception as e: logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}") traceback.print_exc() - return None \ No newline at end of file + return None diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 64c1fd66a..2f78e56d0 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -781,6 +781,7 @@ async def get_db_session() -> AsyncGenerator[Optional[AsyncSession], None]: session = SessionLocal() # 对于 SQLite,在会话开始时设置 PRAGMA from src.config.config import global_config + if global_config.database.database_type == "sqlite": await session.execute(text("PRAGMA busy_timeout = 60000")) await session.execute(text("PRAGMA foreign_keys = ON")) diff --git a/src/common/logger.py b/src/common/logger.py index fa5e27d04..2830c127d 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -781,23 +781,23 @@ class ModuleColoredConsoleRenderer: thought_color = "\033[38;5;218m" # 分割消息内容 prefix, thought = event_content.split("内心思考:", 1) - + # 前缀部分(“决定进行回复,”)使用模块颜色 if module_color: prefix_colored = f"{module_color}{prefix.strip()}{RESET_COLOR}" else: prefix_colored = prefix.strip() - + # “内心思考”部分换行并使用专属颜色 thought_colored = f"\n\n{thought_color}内心思考:{thought.strip()}{RESET_COLOR}\n" - + # 重新组合 # parts.append(prefix_colored + thought_colored) # 将前缀和思考内容作为独立的part添加,避免它们之间出现多余的空格 if prefix_colored: parts.append(prefix_colored) parts.append(thought_colored) - + elif module_color: event_content = f"{module_color}{event_content}{RESET_COLOR}" parts.append(event_content) diff --git a/src/common/vector_db/chromadb_impl.py b/src/common/vector_db/chromadb_impl.py index d50ddf0de..a0267dfed 100644 --- a/src/common/vector_db/chromadb_impl.py +++ b/src/common/vector_db/chromadb_impl.py @@ -98,13 +98,13 @@ class ChromaDBImpl(VectorDBBase): "n_results": n_results, **kwargs, } - + # 修复ChromaDB的where条件格式 if where: processed_where = self._process_where_condition(where) if processed_where: query_params["where"] = processed_where - + return collection.query(**query_params) except Exception as e: logger.error(f"查询集合 '{collection_name}' 失败: {e}") @@ -114,7 +114,7 @@ class ChromaDBImpl(VectorDBBase): "query_embeddings": query_embeddings, "n_results": n_results, } - logger.warning(f"使用回退查询模式(无where条件)") + logger.warning("使用回退查询模式(无where条件)") return collection.query(**fallback_params) except Exception as fallback_e: logger.error(f"回退查询也失败: {fallback_e}") @@ -124,19 +124,19 @@ class ChromaDBImpl(VectorDBBase): """ 处理where条件,转换为ChromaDB支持的格式 ChromaDB支持的格式: - - 简单条件: {"field": "value"} + - 简单条件: {"field": "value"} - 操作符条件: {"field": {"$op": "value"}} - AND条件: {"$and": [condition1, condition2]} - OR条件: {"$or": [condition1, condition2]} """ if not where: return None - + try: # 如果只有一个字段,直接返回 if len(where) == 1: key, value = next(iter(where.items())) - + # 处理列表值(如memory_types) if isinstance(value, list): if len(value) == 1: @@ -146,7 +146,7 @@ class ChromaDBImpl(VectorDBBase): return {key: {"$in": value}} else: return {key: value} - + # 多个字段使用 $and 操作符 conditions = [] for key, value in where.items(): @@ -157,9 +157,9 @@ class ChromaDBImpl(VectorDBBase): conditions.append({key: {"$in": value}}) else: conditions.append({key: value}) - + return {"$and": conditions} - + except Exception as e: logger.warning(f"处理where条件失败: {e}, 使用简化条件") # 回退到只使用第一个条件 @@ -189,7 +189,7 @@ class ChromaDBImpl(VectorDBBase): processed_where = None if where: processed_where = self._process_where_condition(where) - + return collection.get( ids=ids, where=processed_where, @@ -202,7 +202,7 @@ class ChromaDBImpl(VectorDBBase): logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}") # 如果获取失败,尝试不使用where条件重新获取 try: - logger.warning(f"使用回退获取模式(无where条件)") + logger.warning("使用回退获取模式(无where条件)") return collection.get( ids=ids, limit=limit, diff --git a/src/config/config.py b/src/config/config.py index 50b826bf0..375d513df 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -45,7 +45,7 @@ from src.config.official_configs import ( CommandConfig, PlanningSystemConfig, AffinityFlowConfig, - ProactiveThinkingConfig + ProactiveThinkingConfig, ) from .api_ada_configs import ( diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 48cec6599..6a1613baa 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -61,35 +61,35 @@ class PersonalityConfig(ValidatedConfigBase): prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式") compress_personality: bool = Field(default=True, description="是否压缩人格") compress_identity: bool = Field(default=True, description="是否压缩身份") - + # 回复规则配置 reply_targeting_rules: List[str] = Field( default_factory=lambda: [ "拒绝任何包含骚扰、冒犯、暴力、色情或危险内容的请求。", "在拒绝时,请使用符合你人设的、坚定的语气。", - "不要执行任何可能被用于恶意目的的指令。" + "不要执行任何可能被用于恶意目的的指令。", ], - description="安全与互动底线规则,Bot在任何情况下都必须遵守的原则" + description="安全与互动底线规则,Bot在任何情况下都必须遵守的原则", ) - + message_targeting_analysis: List[str] = Field( default_factory=lambda: [ "**直接针对你**:@你、回复你、明确询问你 → 必须回应", "**间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与", "**他人对话**:与你无关的私人交流 → 通常不参与", - "**重复内容**:他人已充分回答的问题 → 避免重复" + "**重复内容**:他人已充分回答的问题 → 避免重复", ], - description="消息针对性分析规则,用于判断是否需要回复" + description="消息针对性分析规则,用于判断是否需要回复", ) - + reply_principles: List[str] = Field( default_factory=lambda: [ "明确回应目标消息,而不是宽泛地评论。", "可以分享你的看法、提出相关问题,或者开个合适的玩笑。", "目的是让对话更有趣、更深入。", - "不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。" + "不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。", ], - description="回复原则,指导如何回复消息" + description="回复原则,指导如何回复消息", ) @@ -126,28 +126,15 @@ class ChatConfig(ValidatedConfigBase): interruption_probability_factor: float = Field( default=0.8, ge=0.0, le=1.0, description="打断概率因子,当前打断次数/最大打断次数超过此值时触发概率下降" ) - interruption_afc_reduction: float = Field( - default=0.05, ge=0.0, le=1.0, description="每次连续打断降低的afc阈值数值" - ) + interruption_afc_reduction: float = Field(default=0.05, ge=0.0, le=1.0, description="每次连续打断降低的afc阈值数值") # 动态消息分发系统配置 dynamic_distribution_enabled: bool = Field(default=True, description="是否启用动态消息分发周期调整") - dynamic_distribution_base_interval: float = Field( - default=5.0, ge=1.0, le=60.0, description="基础分发间隔(秒)" - ) - dynamic_distribution_min_interval: float = Field( - default=1.0, ge=0.5, le=10.0, description="最小分发间隔(秒)" - ) - dynamic_distribution_max_interval: float = Field( - default=30.0, ge=5.0, le=300.0, description="最大分发间隔(秒)" - ) - dynamic_distribution_jitter_factor: float = Field( - default=0.2, ge=0.0, le=0.5, description="分发间隔随机扰动因子" - ) - max_concurrent_distributions: int = Field( - default=10, ge=1, le=100, description="最大并发处理的消息流数量" - ) - + dynamic_distribution_base_interval: float = Field(default=5.0, ge=1.0, le=60.0, description="基础分发间隔(秒)") + dynamic_distribution_min_interval: float = Field(default=1.0, ge=0.5, le=10.0, description="最小分发间隔(秒)") + dynamic_distribution_max_interval: float = Field(default=30.0, ge=5.0, le=300.0, description="最大分发间隔(秒)") + dynamic_distribution_jitter_factor: float = Field(default=0.2, ge=0.0, le=0.5, description="分发间隔随机扰动因子") + max_concurrent_distributions: int = Field(default=10, ge=1, le=100, description="最大并发处理的消息流数量") class MessageReceiveConfig(ValidatedConfigBase): @@ -309,11 +296,13 @@ class MemoryConfig(ValidatedConfigBase): enable_vector_memory_storage: bool = Field(default=True, description="启用Vector DB记忆存储") enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆") enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆") - + # Vector DB配置 vector_db_memory_collection: str = Field(default="unified_memory_v2", description="Vector DB记忆集合名称") vector_db_metadata_collection: str = Field(default="memory_metadata_v2", description="Vector DB元数据集合名称") - vector_db_similarity_threshold: float = Field(default=0.5, description="Vector DB相似度阈值(推荐0.5-0.6,过高会导致检索不到结果)") + vector_db_similarity_threshold: float = Field( + default=0.5, description="Vector DB相似度阈值(推荐0.5-0.6,过高会导致检索不到结果)" + ) vector_db_search_limit: int = Field(default=20, description="Vector DB搜索限制") vector_db_batch_size: int = Field(default=100, description="批处理大小") vector_db_enable_caching: bool = Field(default=True, description="启用内存缓存") @@ -327,23 +316,23 @@ class MemoryConfig(ValidatedConfigBase): base_forgetting_days: float = Field(default=30.0, description="基础遗忘天数") min_forgetting_days: float = Field(default=7.0, description="最小遗忘天数") max_forgetting_days: float = Field(default=365.0, description="最大遗忘天数") - + # 重要程度权重 critical_importance_bonus: float = Field(default=45.0, description="关键重要性额外天数") high_importance_bonus: float = Field(default=30.0, description="高重要性额外天数") normal_importance_bonus: float = Field(default=15.0, description="一般重要性额外天数") low_importance_bonus: float = Field(default=0.0, description="低重要性额外天数") - + # 置信度权重 verified_confidence_bonus: float = Field(default=30.0, description="已验证置信度额外天数") high_confidence_bonus: float = Field(default=20.0, description="高置信度额外天数") medium_confidence_bonus: float = Field(default=10.0, description="中等置信度额外天数") low_confidence_bonus: float = Field(default=0.0, description="低置信度额外天数") - + # 激活频率权重 activation_frequency_weight: float = Field(default=0.5, description="每次激活增加的天数权重") max_frequency_bonus: float = Field(default=10.0, description="最大激活频率奖励天数") - + # 休眠机制 dormant_threshold_days: int = Field(default=90, description="休眠状态判定天数") @@ -596,6 +585,7 @@ class SleepSystemConfig(ValidatedConfigBase): default="我准备睡觉了,请生成一句简短自然的晚安问候。", description="用于生成睡前消息的提示" ) + class ContextGroup(ValidatedConfigBase): """上下文共享组配置""" @@ -658,6 +648,7 @@ class AffinityFlowConfig(ValidatedConfigBase): mention_bot_interest_score: float = Field(default=0.6, description="提及bot的兴趣分") base_relationship_score: float = Field(default=0.5, description="基础人物关系分") + class ProactiveThinkingConfig(ValidatedConfigBase): """主动思考(主动发起对话)功能配置""" @@ -666,24 +657,26 @@ class ProactiveThinkingConfig(ValidatedConfigBase): # --- 触发时机 --- interval: int = Field(default=1500, description="基础触发间隔(秒),AI会围绕这个时间点主动发起对话") - interval_sigma: int = Field(default=120, description="间隔随机化标准差(秒),让触发时间更自然。设为0则为固定间隔。") + interval_sigma: int = Field( + default=120, description="间隔随机化标准差(秒),让触发时间更自然。设为0则为固定间隔。" + ) talk_frequency_adjust: list[list[str]] = Field( - default_factory=lambda: [['', '8:00,1', '12:00,1.2', '18:00,1.5', '01:00,0.6']], - description='每日活跃度调整,格式:[["", "HH:MM,factor", ...], ["stream_id", ...]]' + default_factory=lambda: [["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"]], + description='每日活跃度调整,格式:[["", "HH:MM,factor", ...], ["stream_id", ...]]', ) # --- 作用范围 --- enable_in_private: bool = Field(default=True, description="是否允许在私聊中主动发起对话") enable_in_group: bool = Field(default=True, description="是否允许在群聊中主动发起对话") enabled_private_chats: List[str] = Field( - default_factory=list, - description='私聊白名单,为空则对所有私聊生效。格式: ["platform:user_id", ...]' + default_factory=list, description='私聊白名单,为空则对所有私聊生效。格式: ["platform:user_id", ...]' ) enabled_group_chats: List[str] = Field( - default_factory=list, - description='群聊白名单,为空则对所有群聊生效。格式: ["platform:group_id", ...]' + default_factory=list, description='群聊白名单,为空则对所有群聊生效。格式: ["platform:group_id", ...]' ) # --- 冷启动配置 (针对私聊) --- enable_cold_start: bool = Field(default=True, description="对于白名单中不活跃的私聊,是否允许进行一次“冷启动”问候") - cold_start_cooldown: int = Field(default=86400, description="冷启动后,该私聊的下一次主动思考需要等待的最小时间(秒)") + cold_start_cooldown: int = Field( + default=86400, description="冷启动后,该私聊的下一次主动思考需要等待的最小时间(秒)" + ) diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index a4a106387..4716921f9 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -84,7 +84,9 @@ class Individuality: full_personality = f"{personality_result},{identity_result}" # 获取全局兴趣评分系统实例 - from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system as interest_scoring_system + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ( + chatter_interest_scoring_system as interest_scoring_system, + ) # 初始化智能兴趣系统 await interest_scoring_system.initialize_smart_interests( @@ -114,7 +116,7 @@ class Individuality: @staticmethod def _get_config_hash( - bot_nickname: str, personality_core: str, personality_side: str, identity: str + bot_nickname: str, personality_core: str, personality_side: str, identity: str ) -> tuple[str, str]: """获取personality和identity配置的哈希值 diff --git a/src/llm_models/model_client/aiohttp_gemini_client.py b/src/llm_models/model_client/aiohttp_gemini_client.py index 7b997b680..3d4dd8ca1 100644 --- a/src/llm_models/model_client/aiohttp_gemini_client.py +++ b/src/llm_models/model_client/aiohttp_gemini_client.py @@ -453,9 +453,7 @@ class AiohttpGeminiClient(BaseClient): # 构建请求体 request_data = { "contents": contents, - "generationConfig": _build_generation_config( - max_tokens, temperature, tb, response_format, extra_params - ), + "generationConfig": _build_generation_config(max_tokens, temperature, tb, response_format, extra_params), } # 添加系统指令 diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py index 4253efab0..17d1fa30b 100644 --- a/src/llm_models/payload_content/message.py +++ b/src/llm_models/payload_content/message.py @@ -58,7 +58,7 @@ class MessageBuilder: self, image_format: str, image_base64: str, - support_formats=None, # 默认支持格式 + support_formats=None, # 默认支持格式 ) -> "MessageBuilder": """ 添加图片内容 diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index bd895693b..a8a68c2fb 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -18,6 +18,7 @@ - **LLMRequest (主接口)**: 作为模块的统一入口(Facade),为上层业务逻辑提供了简洁的接口来发起文本、图像、语音等不同类型的LLM请求。 """ + import re import asyncio import time @@ -26,14 +27,13 @@ import string from enum import Enum from rich.traceback import install -from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator +from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine from src.common.logger import get_logger from src.config.config import model_config from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig from .payload_content.message import MessageBuilder, Message -from .payload_content.resp_format import RespFormat -from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType +from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord from .utils import compress_messages, llm_usage_recorder from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException @@ -46,6 +46,7 @@ logger = get_logger("model_utils") # Standalone Utility Functions # ============================================================================== + def _normalize_image_format(image_format: str) -> str: """ 标准化图片格式名称,确保与各种API的兼容性 @@ -57,17 +58,26 @@ def _normalize_image_format(image_format: str) -> str: str: 标准化后的图片格式 """ format_mapping = { - "jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg", - "png": "png", "PNG": "png", - "webp": "webp", "WEBP": "webp", - "gif": "gif", "GIF": "gif", - "heic": "heic", "HEIC": "heic", - "heif": "heif", "HEIF": "heif", + "jpg": "jpeg", + "JPG": "jpeg", + "JPEG": "jpeg", + "jpeg": "jpeg", + "png": "png", + "PNG": "png", + "webp": "webp", + "WEBP": "webp", + "gif": "gif", + "GIF": "gif", + "heic": "heic", + "HEIC": "heic", + "heif": "heif", + "HEIF": "heif", } normalized = format_mapping.get(image_format, image_format.lower()) logger.debug(f"图片格式标准化: {image_format} -> {normalized}") return normalized + async def execute_concurrently( coro_callable: Callable[..., Coroutine[Any, Any, Any]], concurrency_count: int, @@ -103,25 +113,29 @@ async def execute_concurrently( for i, res in enumerate(results): if isinstance(res, Exception): logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}") - + first_exception = next((res for res in results if isinstance(res, Exception)), None) if first_exception: raise first_exception raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息") + class RequestType(Enum): """请求类型枚举""" + RESPONSE = "response" EMBEDDING = "embedding" AUDIO = "audio" + # ============================================================================== # Helper Classes for LLMRequest Refactoring # ============================================================================== + class _ModelSelector: """负责模型选择、负载均衡和动态故障切换的策略。""" - + CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数 DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量 @@ -168,16 +182,18 @@ class _ModelSelector: # - `usage_penalty * 1000`: 短期使用惩罚项。每次被选中后会增加,完成后会减少。高权重确保在多个模型都健康的情况下,请求会均匀分布(轮询)。 least_used_model_name = min( candidate_models_usage, - key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000, + key=lambda k: candidate_models_usage[k][0] + + candidate_models_usage[k][1] * 300 + + candidate_models_usage[k][2] * 1000, ) - + model_info = model_config.get_model_info(least_used_model_name) api_provider = model_config.get_provider(model_info.api_provider) # 特殊处理:对于 embedding 任务,强制创建新的 aiohttp.ClientSession。 # 这是为了避免在某些高并发场景下,共享的ClientSession可能引发的事件循环相关问题。 force_new_client = request_type == "embedding" client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - + logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}") # 增加所选模型的请求使用惩罚值,以实现动态负载均衡。 self.update_usage_penalty(model_info.name, increase=True) @@ -214,26 +230,32 @@ class _ModelSelector: if isinstance(e, (NetworkConnectionError, ReqAbortException)): # 网络连接错误或请求被中断,通常是基础设施问题,应重罚 penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER - logger.warning(f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}") + logger.warning( + f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}" + ) elif isinstance(e, RespNotOkException): # 对于HTTP响应错误,重点关注服务器端错误 if e.status_code >= 500: # 5xx 错误表明服务器端出现问题,应重罚 penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER - logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}") + logger.warning( + f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}" + ) else: # 4xx 客户端错误通常不代表模型本身不可用,给予基础惩罚 - logger.warning(f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") + logger.warning( + f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}" + ) else: # 其他未知异常,给予基础惩罚 logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}") - + self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty) class _PromptProcessor: """封装所有与提示词和响应内容的预处理和后处理逻辑。""" - + def __init__(self): """ 初始化提示处理器。 @@ -276,18 +298,18 @@ class _PromptProcessor: """ # 步骤1: 根据API提供商的配置应用内容混淆 processed_prompt = self._apply_content_obfuscation(prompt, api_provider) - + # 步骤2: 检查模型是否需要注入反截断指令 if getattr(model_info, "use_anti_truncation", False): processed_prompt += self.anti_truncation_instruction logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。") - + return processed_prompt def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]: """ 处理响应内容,提取思维链并检查截断。 - + Returns: Tuple[str, str, bool]: (处理后的内容, 思维链内容, 是否被截断) """ @@ -317,14 +339,14 @@ class _PromptProcessor: # 检查当前API提供商是否启用了内容混淆功能 if not getattr(api_provider, "enable_content_obfuscation", False): return text - + # 获取混淆强度,默认为1 intensity = getattr(api_provider, "obfuscation_intensity", 1) logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") - + # 将抗审查指令和原始文本拼接 processed_text = self.noise_instruction + "\n\n" + text - + # 在拼接后的文本中注入随机噪音 return self._inject_random_noise(processed_text, intensity) @@ -346,12 +368,12 @@ class _PromptProcessor: # 定义不同强度级别的噪音参数:概率和长度范围 params = { 1: {"probability": 15, "length": (3, 6)}, # 低强度 - 2: {"probability": 25, "length": (5, 10)}, # 中强度 - 3: {"probability": 35, "length": (8, 15)}, # 高强度 + 2: {"probability": 25, "length": (5, 10)}, # 中强度 + 3: {"probability": 35, "length": (8, 15)}, # 高强度 } # 根据传入的强度选择配置,如果强度无效则使用默认值 config = params.get(intensity, params[1]) - + words = text.split() result = [] # 遍历每个单词 @@ -366,7 +388,7 @@ class _PromptProcessor: # 生成噪音字符串 noise = "".join(random.choice(chars) for _ in range(noise_length)) result.append(noise) - + # 将处理后的单词列表重新组合成字符串 return " ".join(result) @@ -396,7 +418,7 @@ class _PromptProcessor: else: reasoning = "" clean_content = content.strip() - + return clean_content, reasoning @@ -441,7 +463,7 @@ class _RequestExecutor: """ retry_remain = api_provider.max_retry compressed_messages: Optional[List[Message]] = None - + while retry_remain > 0: try: # 优先使用压缩后的消息列表 @@ -451,11 +473,11 @@ class _RequestExecutor: # 根据请求类型调用不同的客户端方法 if request_type == RequestType.RESPONSE: assert current_messages is not None, "message_list cannot be None for response requests" - + # 修复: 防止 'message_list' 在 kwargs 中重复传递 request_params = kwargs.copy() request_params.pop("message_list", None) - + return await client.get_response( model_info=model_info, message_list=current_messages, **request_params ) @@ -463,15 +485,19 @@ class _RequestExecutor: return await client.get_embedding(model_info=model_info, **kwargs) elif request_type == RequestType.AUDIO: return await client.get_audio_transcriptions(model_info=model_info, **kwargs) - + except Exception as e: logger.debug(f"请求失败: {str(e)}") # 记录失败并更新模型的惩罚值 self.model_selector.update_failure_penalty(model_info.name, e) - + # 处理异常,决定是否重试以及等待多久 wait_interval, new_compressed_messages = self._handle_exception( - e, model_info, api_provider, retry_remain, (kwargs.get("message_list"), compressed_messages is not None) + e, + model_info, + api_provider, + retry_remain, + (kwargs.get("message_list"), compressed_messages is not None), ) if new_compressed_messages: compressed_messages = new_compressed_messages # 更新为压缩后的消息 @@ -482,7 +508,7 @@ class _RequestExecutor: await asyncio.sleep(wait_interval) # 等待指定时间后重试 finally: retry_remain -= 1 - + logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") raise RuntimeError("请求失败,已达到最大重试次数") @@ -491,7 +517,7 @@ class _RequestExecutor: ) -> Tuple[int, Optional[List[Message]]]: """ 默认异常处理函数,决定是否重试。 - + Returns: (等待间隔(-1表示不再重试), 新的消息列表(适用于压缩消息)) """ @@ -534,7 +560,9 @@ class _RequestExecutor: model_name = model_info.name # 处理客户端错误 (400-404),这些错误通常是请求本身有问题,不应重试 if e.status_code in [400, 401, 402, 403, 404]: - logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。") + logger.warning( + f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。" + ) return -1, None # 处理请求体过大的情况 elif e.status_code == 413: @@ -570,9 +598,11 @@ class _RequestExecutor: """ # 只有在剩余重试次数大于1时才进行下一次重试(因为当前这次失败已经消耗掉一次) if remain_try > 1: - logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。") + logger.warning( + f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。" + ) return interval, None - + # 如果已无剩余重试次数,则记录错误并返回-1表示放弃 logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},已达最大重试次数,放弃。") return -1, None @@ -585,7 +615,14 @@ class _RequestStrategy: 即使在单个模型或API端点失败的情况下也能正常工作。 """ - def __init__(self, model_selector: _ModelSelector, prompt_processor: _PromptProcessor, executor: _RequestExecutor, model_list: List[str], task_name: str): + def __init__( + self, + model_selector: _ModelSelector, + prompt_processor: _PromptProcessor, + executor: _RequestExecutor, + model_list: List[str], + task_name: str, + ): """ 初始化请求策略。 @@ -616,11 +653,13 @@ class _RequestStrategy: last_exception: Optional[Exception] = None for attempt in range(max_attempts): - selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request, str(request_type.value)) + selection_result = self.model_selector.select_best_available_model( + failed_models_in_this_request, str(request_type.value) + ) if selection_result is None: logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。") break - + model_info, api_provider, client = selection_result logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_info.name}'...") @@ -637,32 +676,36 @@ class _RequestStrategy: # 合并模型特定的额外参数 if model_info.extra_params: - request_kwargs["extra_params"] = {**model_info.extra_params, **request_kwargs.get("extra_params", {})} + request_kwargs["extra_params"] = { + **model_info.extra_params, + **request_kwargs.get("extra_params", {}), + } + + response = await self._try_model_request( + model_info, api_provider, client, request_type, **request_kwargs + ) - response = await self._try_model_request(model_info, api_provider, client, request_type, **request_kwargs) - # 成功,立即返回 logger.debug(f"模型 '{model_info.name}' 成功生成了回复。") self.model_selector.update_usage_penalty(model_info.name, increase=False) return response, model_info - + except Exception as e: logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。") failed_models_in_this_request.add(model_info.name) last_exception = e # 使用惩罚值已在 select 时增加,失败后不减少,以降低其后续被选中的概率 - + logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。") if raise_when_empty: if last_exception: raise RuntimeError("所有模型均未能生成响应。") from last_exception raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。") - + # 如果不抛出异常,返回一个备用响应 fallback_model_info = model_config.get_model_info(self.model_list[0]) return APIResponse(content="所有模型都请求失败"), fallback_model_info - async def _try_model_request( self, model_info: ModelInfo, api_provider: APIProvider, client: BaseClient, request_type: RequestType, **kwargs ) -> APIResponse: @@ -684,46 +727,49 @@ class _RequestStrategy: RuntimeError: 如果在达到最大重试次数后仍然收到空回复或截断的响应。 """ max_empty_retry = api_provider.max_retry - + for i in range(max_empty_retry + 1): - response = await self.executor.execute_request( - api_provider, client, request_type, model_info, **kwargs - ) + response = await self.executor.execute_request(api_provider, client, request_type, model_info, **kwargs) if request_type != RequestType.RESPONSE: - return response # 对于非响应类型,直接返回 + return response # 对于非响应类型,直接返回 # --- 响应内容处理和空回复/截断检查 --- content = response.content or "" use_anti_truncation = getattr(model_info, "use_anti_truncation", False) - processed_content, reasoning, is_truncated = self.prompt_processor.process_response(content, use_anti_truncation) - + processed_content, reasoning, is_truncated = self.prompt_processor.process_response( + content, use_anti_truncation + ) + # 更新响应对象 response.content = processed_content response.reasoning_content = response.reasoning_content or reasoning is_empty_reply = not response.tool_calls and not (response.content and response.content.strip()) - + if not is_empty_reply and not is_truncated: - return response # 成功获取有效响应 + return response # 成功获取有效响应 if i < max_empty_retry: reason = "空回复" if is_empty_reply else "截断" - logger.warning(f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})...") + logger.warning( + f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})..." + ) if api_provider.retry_interval > 0: await asyncio.sleep(api_provider.retry_interval) else: reason = "空回复" if is_empty_reply else "截断" logger.error(f"模型 '{model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。") raise RuntimeError(f"模型 '{model_info.name}' 已达到空回复/截断的最大内部重试次数。") - - raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里 + + raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里 # ============================================================================== # Main Facade Class # ============================================================================== + class LLMRequest: """ LLM请求协调器。 @@ -745,7 +791,7 @@ class LLMRequest: model: (0, 0, 0) for model in self.model_for_task.model_list } """模型使用量记录,(total_tokens, penalty, usage_penalty)""" - + # 初始化辅助类 self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage) self._prompt_processor = _PromptProcessor() @@ -769,36 +815,44 @@ class LLMRequest: prompt (str): 提示词 image_base64 (str): 图像的Base64编码字符串 image_format (str): 图像格式(如 'png', 'jpeg' 等) - + Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ start_time = time.time() - + # 图像请求目前不使用复杂的故障转移策略,直接选择模型并执行 selection_result = self._model_selector.select_best_available_model(set(), "response") if not selection_result: raise RuntimeError("无法为图像响应选择可用模型。") model_info, api_provider, client = selection_result - + normalized_format = _normalize_image_format(image_format) - message = MessageBuilder().add_text_content(prompt).add_image_content( - image_base64=image_base64, - image_format=normalized_format, - support_formats=client.get_support_image_formats(), - ).build() + message = ( + MessageBuilder() + .add_text_content(prompt) + .add_image_content( + image_base64=image_base64, + image_format=normalized_format, + support_formats=client.get_support_image_formats(), + ) + .build() + ) response = await self._executor.execute_request( - api_provider, client, RequestType.RESPONSE, model_info, + api_provider, + client, + RequestType.RESPONSE, + model_info, message_list=[message], temperature=temperature, max_tokens=max_tokens, ) - + await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions") content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False) reasoning = response.reasoning_content or reasoning - + return content, (reasoning, model_info.name, response.tool_calls) async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: @@ -812,9 +866,7 @@ class LLMRequest: Returns: Optional[str]: 语音转换后的文本内容,如果所有模型都失败则返回None。 """ - response, _ = await self._strategy.execute_with_failover( - RequestType.AUDIO, audio_base64=voice_base64 - ) + response, _ = await self._strategy.execute_with_failover(RequestType.AUDIO, audio_base64=voice_base64) return response.content or None async def generate_response_async( @@ -834,7 +886,7 @@ class LLMRequest: max_tokens (int, optional): 最大token数 tools: 工具配置 raise_when_empty (bool): 是否在空回复时抛出异常 - + Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ @@ -842,12 +894,16 @@ class LLMRequest: if concurrency_count <= 1: return await self._execute_single_text_request(prompt, temperature, max_tokens, tools, raise_when_empty) - + try: return await execute_concurrently( self._execute_single_text_request, concurrency_count, - prompt, temperature, max_tokens, tools, raise_when_empty=False + prompt, + temperature, + max_tokens, + tools, + raise_when_empty=False, ) except Exception as e: logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}") @@ -885,7 +941,7 @@ class LLMRequest: response, model_info = await self._strategy.execute_with_failover( RequestType.RESPONSE, raise_when_empty=raise_when_empty, - prompt=prompt, # 传递原始prompt,由strategy处理 + prompt=prompt, # 传递原始prompt,由strategy处理 tool_options=tool_options, temperature=self.model_for_task.temperature if temperature is None else temperature, max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, @@ -906,21 +962,20 @@ class LLMRequest: Args: embedding_input (str): 获取嵌入的目标 - + Returns: (Tuple[List[float], str]): (嵌入向量,使用的模型名称) """ start_time = time.time() response, model_info = await self._strategy.execute_with_failover( - RequestType.EMBEDDING, - embedding_input=embedding_input + RequestType.EMBEDDING, embedding_input=embedding_input ) - + await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings") - + if not response.embedding: raise RuntimeError("获取embedding失败") - + return response.embedding, model_info.name async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str): @@ -940,16 +995,18 @@ class LLMRequest: # 步骤1: 更新内存中的token计数,用于负载均衡 total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty) - + # 步骤2: 创建一个后台任务,将用量数据异步写入数据库 - asyncio.create_task(llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=usage, - user_id="system", # 此处可根据业务需求修改 - time_cost=time_cost, - request_type=self.task_name, - endpoint=endpoint, - )) + asyncio.create_task( + llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, + user_id="system", # 此处可根据业务需求修改 + time_cost=time_cost, + request_type=self.task_name, + endpoint=endpoint, + ) + ) @staticmethod def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: @@ -970,14 +1027,14 @@ class LLMRequest: # 如果没有提供工具,直接返回 None if not tools: return None - + tool_options: List[ToolOption] = [] # 遍历每个工具定义 for tool in tools: try: # 使用建造者模式创建 ToolOption builder = ToolOptionBuilder().set_name(tool["name"]).set_description(tool.get("description", "")) - + # 遍历工具的参数 for param in tool.get("parameters", []): # 严格验证参数格式是否为包含5个元素的元组 @@ -994,6 +1051,6 @@ class LLMRequest: except (KeyError, IndexError, TypeError, AssertionError) as e: # 如果构建过程中出现任何错误,记录日志并跳过该工具 logger.error(f"构建工具 '{tool.get('name', 'N/A')}' 失败: {e}") - + # 如果列表非空则返回列表,否则返回 None return tool_options or None diff --git a/src/main.py b/src/main.py index 44084509a..4e91f1419 100644 --- a/src/main.py +++ b/src/main.py @@ -78,6 +78,7 @@ class MainSystem: logger.info("收到退出信号,正在优雅关闭系统...") import asyncio + try: loop = asyncio.get_event_loop() if loop.is_running(): @@ -106,6 +107,7 @@ class MainSystem: # 停止消息管理器 try: from src.chat.message_manager import message_manager + await message_manager.stop() logger.info("🛑 消息管理器已停止") except Exception as e: @@ -241,7 +243,6 @@ MoFox_Bot(第三方修改版) # 处理所有缓存的事件订阅(插件加载完成后) event_manager.process_all_pending_subscriptions() - # 初始化表情管理器 get_emoji_manager().initialize() logger.info("表情包管理器初始化成功") diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py index e0d753842..ba8ee54eb 100644 --- a/src/mais4u/mais4u_chat/s4u_msg_processor.py +++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py @@ -50,13 +50,13 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: query_text=message.processed_plain_text, user_id=str(message.user_info.user_id), scope_id=message.chat_id, - limit=5 + limit=5, ) # 基于检索结果计算兴趣度 if enhanced_memories: # 有相关记忆,兴趣度基于相似度计算 - max_score = max(getattr(memory, 'relevance_score', 0.5) for memory in enhanced_memories) + max_score = max(getattr(memory, "relevance_score", 0.5) for memory in enhanced_memories) interested_rate = min(max_score, 1.0) # 限制在0-1之间 else: # 没有相关记忆,给予基础兴趣度 diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index a43d1186d..1c8782d23 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -4,6 +4,7 @@ from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat import time from src.chat.utils.utils import get_recent_group_speaker + # 旧的Hippocampus系统已被移除,现在使用增强记忆系统 # from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager import random @@ -181,14 +182,16 @@ class PromptBuilder: query_text=text, user_id="system", # 系统查询 scope_id="system", - limit=5 + limit=5, ) related_memory_info = "" if enhanced_memories: for memory_chunk in enhanced_memories: related_memory_info += memory_chunk.display or memory_chunk.text_content or "" - return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info.strip()) + return await global_prompt_manager.format_prompt( + "memory_prompt", memory_info=related_memory_info.strip() + ) return "" except Exception as e: diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index dc7f0f24b..66fcee96f 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -149,7 +149,7 @@ class ChatMood: self.mood_state = response self.last_change_time = message_time - + async def regress_mood(self): message_time = time.time() message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 4cfb75f94..478d4c9fb 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -1,4 +1,3 @@ -import asyncio import copy import datetime import hashlib @@ -145,7 +144,7 @@ class PersonInfoManager: except Exception as e: logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}") return "" - + @staticmethod async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str): """判断是否认识某人""" @@ -166,7 +165,7 @@ class PersonInfoManager: await person_info_manager.update_one_field( person_id=person_id, field_name="nickname", value=user_nickname, data=data ) - + @staticmethod async def create_person_info(person_id: str, data: Optional[dict] = None): """创建一个项""" @@ -491,7 +490,9 @@ class PersonInfoManager: async def _db_check_name_exists_async(name_to_check): async with get_db_session() as session: - result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)) + result = await session.execute( + select(PersonInfo).where(PersonInfo.person_name == name_to_check) + ) record = result.scalar() return record is not None @@ -552,7 +553,6 @@ class PersonInfoManager: else: logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行") - @staticmethod async def get_value(person_id: str, field_name: str) -> Any: """获取单个字段值(同步版本)""" @@ -623,6 +623,7 @@ class PersonInfoManager: result[field_name] = copy.deepcopy(person_info_default.get(field_name)) return result + @staticmethod async def get_specific_value_list( field_name: str, @@ -694,7 +695,7 @@ class PersonInfoManager: return record, False # 其他协程已创建,返回现有记录 # 如果仍然失败,重新抛出异常 raise e - + unique_nickname = await self._generate_unique_person_name(nickname) initial_data = { "person_id": person_id, diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index df0b25e77..4dc478f6c 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -303,12 +303,14 @@ class RelationshipBuilder: if not self.person_engaged_cache: return f"{self.log_prefix} 关系缓存为空" - status_lines = [f"{self.log_prefix} 关系缓存状态:", - f"最后处理消息时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_processed_message_time)) if self.last_processed_message_time > 0 else '未设置'}", - f"最后清理时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_cleanup_time)) if self.last_cleanup_time > 0 else '未执行'}", - f"总用户数:{len(self.person_engaged_cache)}", - f"清理配置:{'启用' if SEGMENT_CLEANUP_CONFIG['enable_cleanup'] else '禁用'} (最大保存{SEGMENT_CLEANUP_CONFIG['max_segment_age_days']}天, 每用户最多{SEGMENT_CLEANUP_CONFIG['max_segments_per_user']}段)", - ""] + status_lines = [ + f"{self.log_prefix} 关系缓存状态:", + f"最后处理消息时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_processed_message_time)) if self.last_processed_message_time > 0 else '未设置'}", + f"最后清理时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_cleanup_time)) if self.last_cleanup_time > 0 else '未执行'}", + f"总用户数:{len(self.person_engaged_cache)}", + f"清理配置:{'启用' if SEGMENT_CLEANUP_CONFIG['enable_cleanup'] else '禁用'} (最大保存{SEGMENT_CLEANUP_CONFIG['max_segment_age_days']}天, 每用户最多{SEGMENT_CLEANUP_CONFIG['max_segments_per_user']}段)", + "", + ] for person_id, segments in self.person_engaged_cache.items(): total_count = self._get_total_message_count(person_id) @@ -369,7 +371,7 @@ class RelationshipBuilder: for person_id, segments in self.person_engaged_cache.items(): total_message_count = self._get_total_message_count(person_id) person_name = get_person_info_manager().get_value(person_id, "person_name") or person_id - + if total_message_count >= max_build_threshold or ( total_message_count >= 5 and (immediate_build == person_id or immediate_build == "all") ): @@ -428,7 +430,9 @@ class RelationshipBuilder: start_date = time.strftime("%Y-%m-%d %H:%M", time.localtime(start_time)) # 获取该段的消息(包含边界) - segment_messages = await get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time) + segment_messages = await get_raw_msg_by_timestamp_with_chat_inclusive( + self.chat_id, start_time, end_time + ) logger.debug( f"消息段: {start_date} - {time.strftime('%Y-%m-%d %H:%M', time.localtime(end_time))}, 消息数: {len(segment_messages)}" ) diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 93269123b..90a353291 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -1,7 +1,6 @@ import time import traceback import orjson -import random from typing import List, Dict, Any from json_repair import repair_json diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py index 411a2d326..c80c5942c 100644 --- a/src/plugin_system/apis/__init__.py +++ b/src/plugin_system/apis/__init__.py @@ -19,7 +19,7 @@ from src.plugin_system.apis import ( send_api, tool_api, permission_api, - schedule_api + schedule_api, ) from src.plugin_system.apis.chat_api import ChatManager as context_api from .logging_api import get_logger diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index 612c243a3..baf6418dd 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -8,7 +8,7 @@ readable_text = message_api.build_readable_messages(messages) """ -from typing import List, Dict, Any, Tuple, Optional, Coroutine +from typing import List, Dict, Any, Tuple, Optional from src.config.config import global_config import time from src.chat.utils.chat_message_builder import ( @@ -181,9 +181,7 @@ async def get_messages_by_time_in_chat_for_users( raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - return await get_raw_msg_by_timestamp_with_chat_users( - chat_id, start_time, end_time, person_ids, limit, limit_mode - ) + return await get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode) async def get_random_chat_messages( @@ -384,9 +382,7 @@ async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Op return await num_new_messages_since(chat_id, start_time, end_time) -async def count_new_messages_for_users( - chat_id: str, start_time: float, end_time: float, person_ids: List[str] -) -> int: +async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int: """ 计算指定聊天中指定用户从开始时间到结束时间的新消息数量 diff --git a/src/plugin_system/apis/permission_api.py b/src/plugin_system/apis/permission_api.py index 97fde236c..61b4ca40f 100644 --- a/src/plugin_system/apis/permission_api.py +++ b/src/plugin_system/apis/permission_api.py @@ -61,8 +61,7 @@ class PermissionAPI: def __init__(self): self._permission_manager: Optional[IPermissionManager] = None # 需要保留的前缀(视为绝对节点名,不再自动加 plugins.. 前缀) - self.RESERVED_PREFIXES: tuple[str, ...] = ( - "system.") + self.RESERVED_PREFIXES: tuple[str, ...] = "system." # 系统节点列表 (name, description, default_granted) self._SYSTEM_NODES: list[tuple[str, str, bool]] = [ ("system.superuser", "系统超级管理员:拥有所有权限", False), diff --git a/src/plugin_system/apis/schedule_api.py b/src/plugin_system/apis/schedule_api.py index f1d81049e..e3e759968 100644 --- a/src/plugin_system/apis/schedule_api.py +++ b/src/plugin_system/apis/schedule_api.py @@ -28,6 +28,7 @@ asyncio.run(main()) """ + from datetime import datetime from typing import List, Dict, Any, Optional @@ -176,4 +177,4 @@ async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool: async def archive_monthly_plans(target_month: Optional[str] = None) -> bool: """(异步) 归档指定月份的月度计划的便捷函数""" - return await ScheduleAPI.archive_monthly_plans(target_month) \ No newline at end of file + return await ScheduleAPI.archive_monthly_plans(target_month) diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index ad6621b3a..c770db78b 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -80,7 +80,9 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa message_info = { "platform": message_dict.get("chat_info_platform", ""), - "message_id": message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id"), + "message_id": message_dict.get("message_id") + or message_dict.get("chat_info_message_id") + or message_dict.get("id"), "time": message_dict.get("time"), "group_info": group_info, "user_info": user_info, diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 26b79d4df..d3f012be5 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -2,7 +2,7 @@ import time import asyncio from abc import ABC, abstractmethod -from typing import Tuple, Optional, List, Dict, Any +from typing import Tuple, Optional, List, Dict from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream @@ -27,7 +27,7 @@ class BaseAction(ABC): - parallel_action: 是否允许并行执行 - random_activation_probability: 随机激活概率 - llm_judge_prompt: LLM判断提示词 - + 二步Action相关属性: - is_two_step_action: 是否为二步Action - step_one_description: 第一步的描述 @@ -434,7 +434,9 @@ class BaseAction(ABC): # 确保获取的是Action组件 if component_info.component_type != ComponentType.ACTION: - logger.error(f"{log_prefix} 尝试调用的组件 '{action_name}' 不是一个Action,而是一个 '{component_info.component_type.value}'") + logger.error( + f"{log_prefix} 尝试调用的组件 '{action_name}' 不是一个Action,而是一个 '{component_info.component_type.value}'" + ) return False, f"组件 '{action_name}' 不是一个有效的Action" plugin_config = component_registry.get_plugin_config(component_info.plugin_name) @@ -527,20 +529,20 @@ class BaseAction(ABC): # 第一步:展示可用的子Action available_actions = [sub_action[0] for sub_action in self.sub_actions] description = self.step_one_description or f"{self.action_name}支持以下操作" - + actions_list = "\n".join([f"- {action}: {desc}" for action, desc, _ in self.sub_actions]) response = f"{description}\n\n可用操作:\n{actions_list}\n\n请选择要执行的操作。" - + return True, response else: # 验证选择的子Action是否有效 valid_actions = [sub_action[0] for sub_action in self.sub_actions] if selected_action not in valid_actions: return False, f"无效的操作选择: {selected_action}。可用操作: {valid_actions}" - + # 保存选择的子Action self._selected_sub_action = selected_action - + # 调用第二步执行 return await self.execute_step_two(selected_action) @@ -571,7 +573,7 @@ class BaseAction(ABC): # 如果是二步Action,自动处理第一步 if self.is_two_step_action: return await self.handle_step_one() - + # 普通Action由子类实现 pass diff --git a/src/plugin_system/base/base_chatter.py b/src/plugin_system/base/base_chatter.py index 1bdb79c31..1dd225252 100644 --- a/src/plugin_system/base/base_chatter.py +++ b/src/plugin_system/base/base_chatter.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod -from typing import List, Optional, TYPE_CHECKING +from typing import List, TYPE_CHECKING from src.common.data_models.message_manager_data_model import StreamContext from .component_types import ChatType from src.plugin_system.base.component_types import ChatterInfo, ComponentType if TYPE_CHECKING: from src.chat.planner_actions.action_manager import ChatterActionManager - from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner + class BaseChatter(ABC): chatter_name: str = "" @@ -15,7 +15,7 @@ class BaseChatter(ABC): """Chatter组件的描述""" chat_types: List[ChatType] = [ChatType.PRIVATE, ChatType.GROUP] - def __init__(self, stream_id: str, action_manager: 'ChatterActionManager'): + def __init__(self, stream_id: str, action_manager: "ChatterActionManager"): """ 初始化聊天处理器 @@ -45,11 +45,10 @@ class BaseChatter(ABC): Returns: ChatterInfo对象 """ - + return ChatterInfo( name=cls.chatter_name, description=cls.chatter_description or "No description provided.", chat_type_allow=cls.chat_types[0], component_type=ComponentType.CHATTER, ) - diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py index 84dc8b150..229cadb63 100644 --- a/src/plugin_system/base/base_tool.py +++ b/src/plugin_system/base/base_tool.py @@ -64,7 +64,15 @@ class BaseTool(ABC): return { "name": cls.name, "description": cls.step_one_description or cls.description, - "parameters": [("action", ToolParamType.STRING, "选择要执行的操作", True, [sub_tool[0] for sub_tool in cls.sub_tools])] + "parameters": [ + ( + "action", + ToolParamType.STRING, + "选择要执行的操作", + True, + [sub_tool[0] for sub_tool in cls.sub_tools], + ) + ], } else: # 普通工具需要parameters @@ -88,12 +96,8 @@ class BaseTool(ABC): # 查找对应的子工具 for sub_name, sub_desc, sub_params in cls.sub_tools: if sub_name == sub_tool_name: - return { - "name": f"{cls.name}_{sub_tool_name}", - "description": sub_desc, - "parameters": sub_params - } - + return {"name": f"{cls.name}_{sub_tool_name}", "description": sub_desc, "parameters": sub_params} + raise ValueError(f"未找到子工具: {sub_tool_name}") @classmethod @@ -105,14 +109,10 @@ class BaseTool(ABC): """ if not cls.is_two_step_tool: return [] - + definitions = [] for sub_name, sub_desc, sub_params in cls.sub_tools: - definitions.append({ - "name": f"{cls.name}_{sub_name}", - "description": sub_desc, - "parameters": sub_params - }) + definitions.append({"name": f"{cls.name}_{sub_name}", "description": sub_desc, "parameters": sub_params}) return definitions @classmethod @@ -144,7 +144,7 @@ class BaseTool(ABC): # 如果是二步工具,处理第一步调用 if self.is_two_step_tool and "action" in function_args: return await self._handle_step_one(function_args) - + raise NotImplementedError("子类必须实现execute方法") async def _handle_step_one(self, function_args: dict[str, Any]) -> dict[str, Any]: @@ -174,17 +174,13 @@ class BaseTool(ABC): sub_name, sub_desc, sub_params = sub_tool_found # 返回第二步工具定义 - step_two_definition = { - "name": f"{self.name}_{sub_name}", - "description": sub_desc, - "parameters": sub_params - } + step_two_definition = {"name": f"{self.name}_{sub_name}", "description": sub_desc, "parameters": sub_params} return { "type": "two_step_tool_step_one", "content": f"已选择操作: {action}。请使用以下工具进行具体调用:", "next_tool_definition": step_two_definition, - "selected_action": action + "selected_action": action, } async def execute_step_two(self, sub_tool_name: str, function_args: dict[str, Any]) -> dict[str, Any]: diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 6d0590d43..2b1122b9f 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -40,7 +40,7 @@ class ActionActivationType(Enum): # 聊天模式枚举 class ChatMode(Enum): """聊天模式枚举""" - + FOCUS = "focus" # 专注模式 NORMAL = "normal" # Normal聊天模式 PROACTIVE = "proactive" # 主动思考模式 diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index 12797bafd..a61b8e04c 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -294,9 +294,7 @@ class PluginBase(ABC): changed = False # 内部递归函数 - def _sync_dicts( - schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = "" - ) -> Dict[str, Any]: + def _sync_dicts(schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, Any]: nonlocal changed synced_dict = schema_dict.copy() diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 63594c53e..9c82553f8 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -1,7 +1,7 @@ from pathlib import Path import re -from typing import TYPE_CHECKING, Dict, List, Optional, Any, Pattern, Tuple, Union, Type +from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type from src.common.logger import get_logger from src.plugin_system.base.component_types import ( @@ -34,44 +34,46 @@ class ComponentRegistry: def __init__(self): # 命名空间式组件名构成法 f"{component_type}.{component_name}" - self._components: Dict[str, 'ComponentInfo'] = {} + self._components: Dict[str, "ComponentInfo"] = {} """组件注册表 命名空间式组件名 -> 组件信息""" - self._components_by_type: Dict['ComponentType', Dict[str, 'ComponentInfo']] = {types: {} for types in ComponentType} + self._components_by_type: Dict["ComponentType", Dict[str, "ComponentInfo"]] = { + types: {} for types in ComponentType + } """类型 -> 组件原名称 -> 组件信息""" self._components_classes: Dict[ - str, Type[Union['BaseCommand', 'BaseAction', 'BaseTool', 'BaseEventHandler', 'PlusCommand', 'BaseChatter']] + str, Type[Union["BaseCommand", "BaseAction", "BaseTool", "BaseEventHandler", "PlusCommand", "BaseChatter"]] ] = {} """命名空间式组件名 -> 组件类""" # 插件注册表 - self._plugins: Dict[str, 'PluginInfo'] = {} + self._plugins: Dict[str, "PluginInfo"] = {} """插件名 -> 插件信息""" # Action特定注册表 - self._action_registry: Dict[str, Type['BaseAction']] = {} + self._action_registry: Dict[str, Type["BaseAction"]] = {} """Action注册表 action名 -> action类""" - self._default_actions: Dict[str, 'ActionInfo'] = {} + self._default_actions: Dict[str, "ActionInfo"] = {} """默认动作集,即启用的Action集,用于重置ActionManager状态""" # Command特定注册表 - self._command_registry: Dict[str, Type['BaseCommand']] = {} + self._command_registry: Dict[str, Type["BaseCommand"]] = {} """Command类注册表 command名 -> command类""" self._command_patterns: Dict[Pattern, str] = {} """编译后的正则 -> command名""" # 工具特定注册表 - self._tool_registry: Dict[str, Type['BaseTool']] = {} # 工具名 -> 工具类 - self._llm_available_tools: Dict[str, Type['BaseTool']] = {} # llm可用的工具名 -> 工具类 + self._tool_registry: Dict[str, Type["BaseTool"]] = {} # 工具名 -> 工具类 + self._llm_available_tools: Dict[str, Type["BaseTool"]] = {} # llm可用的工具名 -> 工具类 # EventHandler特定注册表 - self._event_handler_registry: Dict[str, Type['BaseEventHandler']] = {} + self._event_handler_registry: Dict[str, Type["BaseEventHandler"]] = {} """event_handler名 -> event_handler类""" - self._enabled_event_handlers: Dict[str, Type['BaseEventHandler']] = {} + self._enabled_event_handlers: Dict[str, Type["BaseEventHandler"]] = {} """启用的事件处理器 event_handler名 -> event_handler类""" - self._chatter_registry: Dict[str, Type['BaseChatter']] = {} + self._chatter_registry: Dict[str, Type["BaseChatter"]] = {} """chatter名 -> chatter类""" - self._enabled_chatter_registry: Dict[str, Type['BaseChatter']] = {} + self._enabled_chatter_registry: Dict[str, Type["BaseChatter"]] = {} """启用的chatter名 -> chatter类""" logger.info("组件注册中心初始化完成") @@ -99,7 +101,7 @@ class ComponentRegistry: def register_component( self, component_info: ComponentInfo, - component_class: Type[Union['BaseCommand', 'BaseAction', 'BaseEventHandler', 'BaseTool', 'BaseChatter']], + component_class: Type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]], ) -> bool: """注册组件 @@ -172,7 +174,7 @@ class ComponentRegistry: ) return True - def _register_action_component(self, action_info: 'ActionInfo', action_class: Type['BaseAction']) -> bool: + def _register_action_component(self, action_info: "ActionInfo", action_class: Type["BaseAction"]) -> bool: """注册Action组件到Action特定注册表""" if not (action_name := action_info.name): logger.error(f"Action组件 {action_class.__name__} 必须指定名称") @@ -192,7 +194,7 @@ class ComponentRegistry: return True - def _register_command_component(self, command_info: 'CommandInfo', command_class: Type['BaseCommand']) -> bool: + def _register_command_component(self, command_info: "CommandInfo", command_class: Type["BaseCommand"]) -> bool: """注册Command组件到Command特定注册表""" if not (command_name := command_info.name): logger.error(f"Command组件 {command_class.__name__} 必须指定名称") @@ -219,7 +221,7 @@ class ComponentRegistry: return True def _register_plus_command_component( - self, plus_command_info: 'PlusCommandInfo', plus_command_class: Type['PlusCommand'] + self, plus_command_info: "PlusCommandInfo", plus_command_class: Type["PlusCommand"] ) -> bool: """注册PlusCommand组件到特定注册表""" plus_command_name = plus_command_info.name @@ -233,7 +235,7 @@ class ComponentRegistry: # 创建专门的PlusCommand注册表(如果还没有) if not hasattr(self, "_plus_command_registry"): - self._plus_command_registry: Dict[str, Type['PlusCommand']] = {} + self._plus_command_registry: Dict[str, Type["PlusCommand"]] = {} plus_command_class.plugin_name = plus_command_info.plugin_name # 设置插件配置 @@ -243,7 +245,7 @@ class ComponentRegistry: logger.debug(f"已注册PlusCommand组件: {plus_command_name}") return True - def _register_tool_component(self, tool_info: 'ToolInfo', tool_class: Type['BaseTool']) -> bool: + def _register_tool_component(self, tool_info: "ToolInfo", tool_class: Type["BaseTool"]) -> bool: """注册Tool组件到Tool特定注册表""" tool_name = tool_info.name @@ -259,7 +261,7 @@ class ComponentRegistry: return True def _register_event_handler_component( - self, handler_info: 'EventHandlerInfo', handler_class: Type['BaseEventHandler'] + self, handler_info: "EventHandlerInfo", handler_class: Type["BaseEventHandler"] ) -> bool: if not (handler_name := handler_info.name): logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称") @@ -285,7 +287,7 @@ class ComponentRegistry: handler_class, self.get_plugin_config(handler_info.plugin_name) or {} ) - def _register_chatter_component(self, chatter_info: 'ChatterInfo', chatter_class: Type['BaseChatter']) -> bool: + def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: Type["BaseChatter"]) -> bool: """注册Chatter组件到Chatter特定注册表""" chatter_name = chatter_info.name @@ -312,7 +314,7 @@ class ComponentRegistry: # === 组件移除相关 === - async def remove_component(self, component_name: str, component_type: 'ComponentType', plugin_name: str) -> bool: + async def remove_component(self, component_name: str, component_type: "ComponentType", plugin_name: str) -> bool: target_component_class = self.get_component_class(component_name, component_type) if not target_component_class: logger.warning(f"组件 {component_name} 未注册,无法移除") @@ -362,7 +364,7 @@ class ComponentRegistry: case ComponentType.CHATTER: # 移除Chatter注册 - if hasattr(self, '_chatter_registry'): + if hasattr(self, "_chatter_registry"): self._chatter_registry.pop(component_name, None) logger.debug(f"已移除Chatter组件: {component_name}") @@ -484,8 +486,8 @@ class ComponentRegistry: # === 组件查询方法 === def get_component_info( - self, component_name: str, component_type: Optional['ComponentType'] = None - ) -> Optional['ComponentInfo']: + self, component_name: str, component_type: Optional["ComponentType"] = None + ) -> Optional["ComponentInfo"]: # sourcery skip: class-extract-method """获取组件信息,支持自动命名空间解析 @@ -529,8 +531,8 @@ class ComponentRegistry: def get_component_class( self, component_name: str, - component_type: Optional['ComponentType'] = None, - ) -> Optional[Union[Type['BaseCommand'], Type['BaseAction'], Type['BaseEventHandler'], Type['BaseTool']]]: + component_type: Optional["ComponentType"] = None, + ) -> Optional[Union[Type["BaseCommand"], Type["BaseAction"], Type["BaseEventHandler"], Type["BaseTool"]]]: """获取组件类,支持自动命名空间解析 Args: @@ -572,22 +574,22 @@ class ComponentRegistry: # 4. 都没找到 return None - def get_components_by_type(self, component_type: 'ComponentType') -> Dict[str, 'ComponentInfo']: + def get_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]: """获取指定类型的所有组件""" return self._components_by_type.get(component_type, {}).copy() - def get_enabled_components_by_type(self, component_type: 'ComponentType') -> Dict[str, 'ComponentInfo']: + def get_enabled_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]: """获取指定类型的所有启用组件""" components = self.get_components_by_type(component_type) return {name: info for name, info in components.items() if info.enabled} # === Action特定查询方法 === - def get_action_registry(self) -> Dict[str, Type['BaseAction']]: + def get_action_registry(self) -> Dict[str, Type["BaseAction"]]: """获取Action注册表""" return self._action_registry.copy() - def get_registered_action_info(self, action_name: str) -> Optional['ActionInfo']: + def get_registered_action_info(self, action_name: str) -> Optional["ActionInfo"]: """获取Action信息""" info = self.get_component_info(action_name, ComponentType.ACTION) return info if isinstance(info, ActionInfo) else None @@ -598,11 +600,11 @@ class ComponentRegistry: # === Command特定查询方法 === - def get_command_registry(self) -> Dict[str, Type['BaseCommand']]: + def get_command_registry(self) -> Dict[str, Type["BaseCommand"]]: """获取Command注册表""" return self._command_registry.copy() - def get_registered_command_info(self, command_name: str) -> Optional['CommandInfo']: + def get_registered_command_info(self, command_name: str) -> Optional["CommandInfo"]: """获取Command信息""" info = self.get_component_info(command_name, ComponentType.COMMAND) return info if isinstance(info, CommandInfo) else None @@ -611,7 +613,7 @@ class ComponentRegistry: """获取Command模式注册表""" return self._command_patterns.copy() - def find_command_by_text(self, text: str) -> Optional[Tuple[Type['BaseCommand'], dict, 'CommandInfo']]: + def find_command_by_text(self, text: str) -> Optional[Tuple[Type["BaseCommand"], dict, "CommandInfo"]]: # sourcery skip: use-named-expression, use-next """根据文本查找匹配的命令 @@ -638,15 +640,15 @@ class ComponentRegistry: return None # === Tool 特定查询方法 === - def get_tool_registry(self) -> Dict[str, Type['BaseTool']]: + def get_tool_registry(self) -> Dict[str, Type["BaseTool"]]: """获取Tool注册表""" return self._tool_registry.copy() - def get_llm_available_tools(self) -> Dict[str, Type['BaseTool']]: + def get_llm_available_tools(self) -> Dict[str, Type["BaseTool"]]: """获取LLM可用的Tool列表""" return self._llm_available_tools.copy() - def get_registered_tool_info(self, tool_name: str) -> Optional['ToolInfo']: + def get_registered_tool_info(self, tool_name: str) -> Optional["ToolInfo"]: """获取Tool信息 Args: @@ -659,13 +661,13 @@ class ComponentRegistry: return info if isinstance(info, ToolInfo) else None # === PlusCommand 特定查询方法 === - def get_plus_command_registry(self) -> Dict[str, Type['PlusCommand']]: + def get_plus_command_registry(self) -> Dict[str, Type["PlusCommand"]]: """获取PlusCommand注册表""" if not hasattr(self, "_plus_command_registry"): self._plus_command_registry: Dict[str, Type[PlusCommand]] = {} return self._plus_command_registry.copy() - def get_registered_plus_command_info(self, command_name: str) -> Optional['PlusCommandInfo']: + def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]: """获取PlusCommand信息 Args: @@ -679,44 +681,44 @@ class ComponentRegistry: # === EventHandler 特定查询方法 === - def get_event_handler_registry(self) -> Dict[str, Type['BaseEventHandler']]: + def get_event_handler_registry(self) -> Dict[str, Type["BaseEventHandler"]]: """获取事件处理器注册表""" return self._event_handler_registry.copy() - def get_registered_event_handler_info(self, handler_name: str) -> Optional['EventHandlerInfo']: + def get_registered_event_handler_info(self, handler_name: str) -> Optional["EventHandlerInfo"]: """获取事件处理器信息""" info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER) return info if isinstance(info, EventHandlerInfo) else None - def get_enabled_event_handlers(self) -> Dict[str, Type['BaseEventHandler']]: + def get_enabled_event_handlers(self) -> Dict[str, Type["BaseEventHandler"]]: """获取启用的事件处理器""" return self._enabled_event_handlers.copy() # === Chatter 特定查询方法 === - def get_chatter_registry(self) -> Dict[str, Type['BaseChatter']]: + def get_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]: """获取Chatter注册表""" - if not hasattr(self, '_chatter_registry'): + if not hasattr(self, "_chatter_registry"): self._chatter_registry: Dict[str, Type[BaseChatter]] = {} return self._chatter_registry.copy() - - def get_enabled_chatter_registry(self) -> Dict[str, Type['BaseChatter']]: + + def get_enabled_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]: """获取启用的Chatter注册表""" - if not hasattr(self, '_enabled_chatter_registry'): + if not hasattr(self, "_enabled_chatter_registry"): self._enabled_chatter_registry: Dict[str, Type[BaseChatter]] = {} return self._enabled_chatter_registry.copy() - - def get_registered_chatter_info(self, chatter_name: str) -> Optional['ChatterInfo']: + + def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]: """获取Chatter信息""" info = self.get_component_info(chatter_name, ComponentType.CHATTER) return info if isinstance(info, ChatterInfo) else None - + # === 插件查询方法 === - def get_plugin_info(self, plugin_name: str) -> Optional['PluginInfo']: + def get_plugin_info(self, plugin_name: str) -> Optional["PluginInfo"]: """获取插件信息""" return self._plugins.get(plugin_name) - def get_all_plugins(self) -> Dict[str, 'PluginInfo']: + def get_all_plugins(self) -> Dict[str, "PluginInfo"]: """获取所有插件""" return self._plugins.copy() @@ -724,7 +726,7 @@ class ComponentRegistry: # """获取所有启用的插件""" # return {name: info for name, info in self._plugins.items() if info.enabled} - def get_plugin_components(self, plugin_name: str) -> List['ComponentInfo']: + def get_plugin_components(self, plugin_name: str) -> List["ComponentInfo"]: """获取插件的所有组件""" plugin_info = self.get_plugin_info(plugin_name) return plugin_info.components if plugin_info else [] diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index c79012e13..0bb22afdf 100644 --- a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -95,17 +95,16 @@ class PermissionManager(IPermissionManager): # 检查用户是否有明确的权限设置 result = await session.execute( - select(UserPermissions) - .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) + select(UserPermissions).filter_by( + platform=user.platform, user_id=user.user_id, permission_node=permission_node + ) ) user_perm = result.scalar_one_or_none() if user_perm: # 有明确设置,返回设置的值 res = user_perm.granted - logger.debug( - f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {res}" - ) + logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {res}") return res else: # 没有明确设置,使用默认值 @@ -191,8 +190,9 @@ class PermissionManager(IPermissionManager): # 检查是否已有权限记录 result = await session.execute( - select(UserPermissions) - .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) + select(UserPermissions).filter_by( + platform=user.platform, user_id=user.user_id, permission_node=permission_node + ) ) existing_perm = result.scalar_one_or_none() @@ -244,8 +244,9 @@ class PermissionManager(IPermissionManager): # 检查是否已有权限记录 result = await session.execute( - select(UserPermissions) - .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) + select(UserPermissions).filter_by( + platform=user.platform, user_id=user.user_id, permission_node=permission_node + ) ) existing_perm = result.scalar_one_or_none() @@ -303,8 +304,9 @@ class PermissionManager(IPermissionManager): for node in all_nodes: # 检查用户是否有明确的权限设置 result = await session.execute( - select(UserPermissions) - .filter_by(platform=user.platform, user_id=user.user_id, permission_node=node.node_name) + select(UserPermissions).filter_by( + platform=user.platform, user_id=user.user_id, permission_node=node.node_name + ) ) user_perm = result.scalar_one_or_none() @@ -408,8 +410,7 @@ class PermissionManager(IPermissionManager): # 删除用户权限记录 result = await session.execute( - delete(UserPermissions) - .where(UserPermissions.permission_node.in_(node_names)) + delete(UserPermissions).where(UserPermissions.permission_node.in_(node_names)) ) deleted_user_perms = result.rowcount diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 9ec216a9b..2950101a9 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -1,7 +1,5 @@ import asyncio import os -import shutil -import hashlib import traceback import importlib @@ -106,7 +104,6 @@ class PluginManager: if not plugin_dir: return False, 1 - plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败) if not plugin_instance: logger.error(f"插件 {plugin_name} 实例化失败") @@ -545,9 +542,7 @@ class PluginManager: try: loop = asyncio.get_event_loop() if loop.is_running(): - fut = asyncio.run_coroutine_threadsafe( - component_registry.unregister_plugin(plugin_name), loop - ) + fut = asyncio.run_coroutine_threadsafe(component_registry.unregister_plugin(plugin_name), loop) fut.result(timeout=5) else: asyncio.run(component_registry.unregister_plugin(plugin_name)) diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index daa8244cf..e666e32d4 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -116,17 +116,17 @@ class ToolExecutor: def _get_tool_definitions(self) -> List[Dict[str, Any]]: all_tools = get_llm_available_tool_definitions() user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) - + # 获取基础工具定义(包括二步工具的第一步) tool_definitions = [definition for name, definition in all_tools if name not in user_disabled_tools] - + # 检查是否有待处理的二步工具第二步调用 - pending_step_two = getattr(self, '_pending_step_two_tools', {}) + pending_step_two = getattr(self, "_pending_step_two_tools", {}) if pending_step_two: # 添加第二步工具定义 for tool_name, step_two_def in pending_step_two.items(): tool_definitions.append(step_two_def) - + return tool_definitions async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]: @@ -266,7 +266,7 @@ class ToolExecutor: f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}" ) function_args["llm_called"] = True # 标记为LLM调用 - + # 检查是否是二步工具的第二步调用 if "_" in function_name and function_name.count("_") >= 1: # 可能是二步工具的第二步调用,格式为 "tool_name_sub_tool_name" @@ -274,14 +274,14 @@ class ToolExecutor: if len(parts) == 2: base_tool_name, sub_tool_name = parts base_tool_instance = get_tool_instance(base_tool_name) - + if base_tool_instance and base_tool_instance.is_two_step_tool: logger.info(f"{self.log_prefix}执行二步工具第二步: {base_tool_name}.{sub_tool_name}") result = await base_tool_instance.execute_step_two(sub_tool_name, function_args) - + # 清理待处理的第二步工具 self._pending_step_two_tools.pop(base_tool_name, None) - + if result: logger.debug(f"{self.log_prefix}二步工具第二步 {function_name} 执行成功") return { @@ -291,7 +291,7 @@ class ToolExecutor: "type": "function", "content": result.get("content", ""), } - + # 获取对应工具实例 tool_instance = tool_instance or get_tool_instance(function_name) if not tool_instance: @@ -301,7 +301,7 @@ class ToolExecutor: # 执行工具并记录日志 logger.debug(f"{self.log_prefix}执行工具 {function_name},参数: {function_args}") result = await tool_instance.execute(function_args) - + # 检查是否是二步工具的第一步结果 if result and result.get("type") == "two_step_tool_step_one": logger.info(f"{self.log_prefix}二步工具第一步完成: {function_name}") @@ -310,7 +310,7 @@ class ToolExecutor: if next_tool_def: self._pending_step_two_tools[function_name] = next_tool_def logger.debug(f"{self.log_prefix}已保存第二步工具定义: {next_tool_def['name']}") - + if result: logger.debug(f"{self.log_prefix}工具 {function_name} 执行成功,结果: {result}") return { diff --git a/src/plugin_system/utils/permission_decorators.py b/src/plugin_system/utils/permission_decorators.py index 990f1c91c..278ab2068 100644 --- a/src/plugin_system/utils/permission_decorators.py +++ b/src/plugin_system/utils/permission_decorators.py @@ -79,9 +79,11 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None) if not iscoroutinefunction(func): logger.warning(f"函数 {func.__name__} 使用 require_permission 但非异步,已强制阻止执行") + async def blocked(*_a, **_k): logger.error("同步函数不再支持权限装饰器,请改为 async def") return None + return blocked return async_wrapper @@ -146,9 +148,11 @@ def require_master(deny_message: Optional[str] = None): if not iscoroutinefunction(func): logger.warning(f"函数 {func.__name__} 使用 require_master 但非异步,已强制阻止执行") + async def blocked(*_a, **_k): logger.error("同步函数不再支持 require_master,请改为 async def") return None + return blocked return async_wrapper @@ -164,7 +168,9 @@ class PermissionChecker: @staticmethod def check_permission(chat_stream: ChatStream, permission_node: str) -> bool: - raise RuntimeError("PermissionChecker.check_permission 已移除同步支持,请直接 await permission_api.check_permission") + raise RuntimeError( + "PermissionChecker.check_permission 已移除同步支持,请直接 await permission_api.check_permission" + ) @staticmethod def is_master(chat_stream: ChatStream) -> bool: diff --git a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py index 8a331f99e..1bb60146b 100644 --- a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py +++ b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py @@ -12,7 +12,7 @@ from src.common.data_models.info_data_model import InterestScore from src.chat.interest_system import bot_interest_manager from src.common.logger import get_logger from src.config.config import global_config -from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker + logger = get_logger("chatter_interest_scoring") # 定义颜色 @@ -45,7 +45,7 @@ class ChatterInterestScoringSystem: self.probability_boost_per_no_reply = ( affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count ) # 每次不回复增加的概率 - + # 用户关系数据 self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py index cf14c221e..8d322c880 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py @@ -65,7 +65,7 @@ class ChatterPlanExecutor: if not plan.decided_actions: logger.info("没有需要执行的动作。") return {"executed_count": 0, "results": []} - + # 像hfc一样,提前打印将要执行的动作 action_types = [action.action_type for action in plan.decided_actions] logger.info(f"选择动作: {', '.join(action_types) if action_types else '无'}") @@ -150,17 +150,19 @@ class ChatterPlanExecutor: for i, action_info in enumerate(unique_actions): is_last_action = i == total_actions - 1 if total_actions > 1: - logger.info(f"[多重回复] 正在执行第 {i+1}/{total_actions} 个回复...") + logger.info(f"[多重回复] 正在执行第 {i + 1}/{total_actions} 个回复...") # 传递 clear_unread 参数 result = await self._execute_single_reply_action(action_info, plan, clear_unread=is_last_action) results.append(result) if total_actions > 1: - logger.info(f"[多重回复] 所有回复任务执行完毕。") + logger.info("[多重回复] 所有回复任务执行完毕。") return {"results": results} - async def _execute_single_reply_action(self, action_info: ActionPlannerInfo, plan: Plan, clear_unread: bool = True) -> Dict[str, any]: + async def _execute_single_reply_action( + self, action_info: ActionPlannerInfo, plan: Plan, clear_unread: bool = True + ) -> Dict[str, any]: """执行单个回复动作""" start_time = time.time() success = False @@ -201,7 +203,7 @@ class ChatterPlanExecutor: execution_result = await self.action_manager.execute_action( action_name=action_info.action_type, **action_params ) - + # 从返回结果中提取真正的回复文本 if isinstance(execution_result, dict): reply_content = execution_result.get("reply_text", "") @@ -233,7 +235,9 @@ class ChatterPlanExecutor: "error_message": error_message, "execution_time": execution_time, "reasoning": action_info.reasoning, - "reply_content": reply_content[:200] + "..." if reply_content and len(reply_content) > 200 else reply_content, + "reply_content": reply_content[:200] + "..." + if reply_content and len(reply_content) > 200 + else reply_content, } async def _execute_other_actions(self, other_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]: diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py index f84bdec46..1bc153fad 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py @@ -100,7 +100,7 @@ class ChatterPlanFilter: # 预解析 action_type 来进行判断 thinking = item.get("thinking", "未提供思考过程") actions_obj = item.get("actions", {}) - + # 处理actions字段可能是字典或列表的情况 if isinstance(actions_obj, dict): action_type = actions_obj.get("action_type", "no_action") @@ -116,14 +116,12 @@ class ChatterPlanFilter: if action_type in reply_action_types: if not reply_action_added: - final_actions.extend( - await self._parse_single_action(item, used_message_id_list, plan) - ) + final_actions.extend(await self._parse_single_action(item, used_message_id_list, plan)) reply_action_added = True else: # 非回复类动作直接添加 final_actions.extend(await self._parse_single_action(item, used_message_id_list, plan)) - + if thinking and thinking != "未提供思考过程": logger.info(f"\n{SAKURA_PINK}思考: {thinking}{RESET_COLOR}\n") plan.decided_actions = self._filter_no_actions(final_actions) @@ -154,6 +152,7 @@ class ChatterPlanFilter: schedule_block = "" # 优先检查是否被吵醒 from src.chat.message_manager.message_manager import message_manager + angry_prompt_addition = "" wakeup_mgr = message_manager.wakeup_manager @@ -161,7 +160,7 @@ class ChatterPlanFilter: # 检查1: 直接从 wakeup_manager 获取 if wakeup_mgr.is_in_angry_state(): angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition() - + # 检查2: 如果上面没获取到,再从 mood_manager 确认 if not angry_prompt_addition: chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id) @@ -274,7 +273,9 @@ class ChatterPlanFilter: is_group_chat = plan.chat_type == ChatType.GROUP chat_context_description = "你现在正在一个群聊中" if not is_group_chat and plan.target_info: - chat_target_name = plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方" + chat_target_name = ( + plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方" + ) chat_context_description = f"你正在和 {chat_target_name} 私聊" action_options_block = await self._build_action_options(plan.available_actions) @@ -315,12 +316,12 @@ class ChatterPlanFilter: """构建已读/未读历史消息块""" try: # 从message_manager获取真实的已读/未读消息 - from src.chat.message_manager.message_manager import message_manager from src.chat.utils.utils import assign_message_ids from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat # 获取聊天流的上下文 from src.plugin_system.apis.chat_api import get_chat_manager + chat_manager = get_chat_manager() chat_stream = chat_manager.get_stream(plan.chat_id) if not chat_stream: @@ -333,6 +334,7 @@ class ChatterPlanFilter: read_messages = stream_context.context.history_messages # 已读消息存储在history_messages中 if not read_messages: from src.common.data_models.database_data_model import DatabaseMessages + # 如果内存中没有已读消息(比如刚启动),则从数据库加载最近的上下文 fallback_messages_dicts = await get_raw_msg_before_timestamp_with_chat( chat_id=plan.chat_id, @@ -414,7 +416,7 @@ class ChatterPlanFilter: processed_plain_text=msg_dict.get("processed_plain_text", ""), key_words=msg_dict.get("key_words", "[]"), is_mentioned=msg_dict.get("is_mentioned", False), - **{"user_info": user_info_dict} # 通过kwargs传入user_info + **{"user_info": user_info_dict}, # 通过kwargs传入user_info ) else: # 如果没有user_info字段,使用平铺的字段(flatten()方法返回的格式) @@ -425,13 +427,12 @@ class ChatterPlanFilter: user_platform=msg_dict.get("user_platform", ""), processed_plain_text=msg_dict.get("processed_plain_text", ""), key_words=msg_dict.get("key_words", "[]"), - is_mentioned=msg_dict.get("is_mentioned", False) + is_mentioned=msg_dict.get("is_mentioned", False), ) # 计算消息兴趣度 interest_score_obj = await chatter_interest_scoring_system._calculate_single_message_score( - message=db_message, - bot_nickname=global_config.bot.nickname + message=db_message, bot_nickname=global_config.bot.nickname ) interest_score = interest_score_obj.total_score @@ -454,7 +455,7 @@ class ChatterPlanFilter: try: # 从新的actions结构中获取动作信息 actions_obj = action_json.get("actions", {}) - + # 处理actions字段可能是字典或列表的情况 actions_to_process = [] if isinstance(actions_obj, dict): @@ -463,19 +464,23 @@ class ChatterPlanFilter: actions_to_process.extend(actions_obj) if not actions_to_process: - actions_to_process.append({"action_type": "no_action", "reason": "actions格式错误"}) + actions_to_process.append({"action_type": "no_action", "reason": "actions格式错误"}) for single_action_obj in actions_to_process: if not isinstance(single_action_obj, dict): continue action = single_action_obj.get("action_type", "no_action") - reasoning = single_action_obj.get("reasoning", "未提供原因") # 兼容旧的reason字段 + reasoning = single_action_obj.get("reasoning", "未提供原因") # 兼容旧的reason字段 action_data = single_action_obj.get("action_data", {}) - + # 为了向后兼容,如果action_data不存在,则从顶层字段获取 if not action_data: - action_data = {k: v for k, v in single_action_obj.items() if k not in ["action_type", "reason", "reasoning", "thinking"]} + action_data = { + k: v + for k, v in single_action_obj.items() + if k not in ["action_type", "reason", "reasoning", "thinking"] + } # 保留原始的thinking字段(如果有) thinking = action_json.get("thinking", "") @@ -501,7 +506,9 @@ class ChatterPlanFilter: # reply动作必须有目标消息,使用最新消息作为兜底 target_message_dict = self._get_latest_message(message_id_list) if target_message_dict: - logger.info(f"[{action}] 使用最新消息作为目标: {target_message_dict.get('message_id')}") + logger.info( + f"[{action}] 使用最新消息作为目标: {target_message_dict.get('message_id')}" + ) else: logger.error(f"[{action}] 无法找到任何目标消息,降级为no_action") action = "no_action" @@ -509,15 +516,21 @@ class ChatterPlanFilter: elif action in ["poke_user", "set_emoji_like"]: # 这些动作可以尝试其他策略 - target_message_dict = self._find_poke_notice(message_id_list) or self._get_latest_message(message_id_list) + target_message_dict = self._find_poke_notice( + message_id_list + ) or self._get_latest_message(message_id_list) if target_message_dict: - logger.info(f"[{action}] 使用替代消息作为目标: {target_message_dict.get('message_id')}") + logger.info( + f"[{action}] 使用替代消息作为目标: {target_message_dict.get('message_id')}" + ) else: # 其他动作使用最新消息或跳过 target_message_dict = self._get_latest_message(message_id_list) if target_message_dict: - logger.info(f"[{action}] 使用最新消息作为目标: {target_message_dict.get('message_id')}") + logger.info( + f"[{action}] 使用最新消息作为目标: {target_message_dict.get('message_id')}" + ) else: # 如果LLM没有指定target_message_id,进行特殊处理 if action == "poke_user": @@ -614,7 +627,7 @@ class ChatterPlanFilter: query_text=query, user_id="system", # 系统查询 scope_id="system", - limit=5 + limit=5, ) if not enhanced_memories: @@ -627,7 +640,9 @@ class ChatterPlanFilter: memory_type = memory_chunk.memory_type.value if memory_chunk.memory_type else "unknown" retrieved_memories.append((memory_type, content)) - memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories] + memory_statements = [ + f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories + ] except Exception as e: logger.warning(f"增强记忆系统检索失败,使用默认回复: {e}") @@ -648,12 +663,17 @@ class ChatterPlanFilter: if action_name == "set_emoji_like" and p_name == "emoji": # 特殊处理set_emoji_like的emoji参数 from src.plugins.built_in.social_toolkit_plugin.qq_emoji_list import qq_face - emoji_options = [re.search(r"\[表情:(.+?)\]", name).group(1) for name in qq_face.values() if re.search(r"\[表情:(.+?)\]", name)] + + emoji_options = [ + re.search(r"\[表情:(.+?)\]", name).group(1) + for name in qq_face.values() + if re.search(r"\[表情:(.+?)\]", name) + ] example_value = f"<从'{', '.join(emoji_options[:10])}...'中选择一个>" else: example_value = f"<{p_desc}>" params_json_list.append(f' "{p_name}": "{example_value}"') - + # 基础动作信息 action_description = action_info.description action_require = "\n".join(f"- {req}" for req in action_info.action_require) @@ -666,11 +686,11 @@ class ChatterPlanFilter: # 将参数列表合并到JSON示例中 if params_json_list: # 移除最后一行的逗号 - json_example_lines.extend([line.rstrip(',') for line in params_json_list]) + json_example_lines.extend([line.rstrip(",") for line in params_json_list]) json_example_lines.append(' "reason": "<执行该动作的详细原因>"') json_example_lines.append(" }") - + # 使用逗号连接内部元素,除了最后一个 json_parts = [] for i, line in enumerate(json_example_lines): @@ -678,14 +698,14 @@ class ChatterPlanFilter: if line.strip() in ["{", "}"]: json_parts.append(line) continue - + # 检查是否是最后一个需要逗号的元素 is_last_item = True - for next_line in json_example_lines[i+1:]: + for next_line in json_example_lines[i + 1 :]: if next_line.strip() not in ["}"]: is_last_item = False break - + if not is_last_item: json_parts.append(f"{line},") else: @@ -713,7 +733,7 @@ class ChatterPlanFilter: # 1. 标准化处理:去除可能的格式干扰 original_id = str(message_id).strip() - normalized_id = original_id.strip('<>"\'').strip() + normalized_id = original_id.strip("<>\"'").strip() if not normalized_id: return None @@ -731,12 +751,13 @@ class ChatterPlanFilter: # 处理包含在文本中的ID格式 (如 "消息m123" -> 提取 m123) import re + # 尝试提取各种格式的ID id_patterns = [ - r'm\d+', # m123格式 - r'\d+', # 纯数字格式 - r'buffered-[a-f0-9-]+', # buffered-xxxx格式 - r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', # UUID格式 + r"m\d+", # m123格式 + r"\d+", # 纯数字格式 + r"buffered-[a-f0-9-]+", # buffered-xxxx格式 + r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", # UUID格式 ] for pattern in id_patterns: @@ -771,12 +792,12 @@ class ChatterPlanFilter: # 4. 尝试模糊匹配(数字部分匹配) for candidate in candidate_ids: # 提取数字部分进行模糊匹配 - number_part = re.sub(r'[^0-9]', '', candidate) + number_part = re.sub(r"[^0-9]", "", candidate) if number_part: for item in message_id_list: if isinstance(item, dict): item_id = item.get("id", "") - item_number = re.sub(r'[^0-9]', '', item_id) + item_number = re.sub(r"[^0-9]", "", item_id) # 数字部分匹配 if item_number == number_part: @@ -787,7 +808,7 @@ class ChatterPlanFilter: message_obj = item.get("message") if isinstance(message_obj, dict): orig_mid = message_obj.get("message_id") or message_obj.get("id") - orig_number = re.sub(r'[^0-9]', '', str(orig_mid)) if orig_mid else "" + orig_number = re.sub(r"[^0-9]", "", str(orig_mid)) if orig_mid else "" if orig_number == number_part: logger.debug(f"模糊匹配成功(消息对象): {candidate} -> {orig_mid}") return message_obj diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py index 420487d4f..b6be512b9 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner.py @@ -4,7 +4,6 @@ """ from dataclasses import asdict -import time from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor @@ -90,15 +89,16 @@ class ChatterActionPlanner: try: # 在规划前,先进行动作修改 from src.chat.planner_actions.action_modifier import ActionModifier + action_modifier = ActionModifier(self.action_manager, self.chat_id) await action_modifier.modify_actions() - + # 1. 生成初始 Plan initial_plan = await self.generator.generate(context.chat_mode) # 确保Plan中包含所有当前可用的动作 initial_plan.available_actions = self.action_manager.get_using_actions() - + unread_messages = context.get_unread_messages() if context else [] # 2. 使用新的兴趣度管理系统进行评分 score = 0.0 @@ -117,7 +117,9 @@ class ChatterActionPlanner: message_interest = interest_score.total_score message.interest_value = message_interest - message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold + message.should_reply = ( + message_interest > global_config.affinity_flow.non_reply_action_interest_threshold + ) interest_updates.append( { diff --git a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py index 66b3ca31f..2320670a0 100644 --- a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py +++ b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py @@ -596,7 +596,7 @@ class ChatterRelationshipTracker: quality = response_data.get("interaction_quality", "medium") # 更新数据库 - await self._update_user_relationship_in_db(user_id, new_text, new_score) + await self._update_user_relationship_in_db(user_id, new_text, new_score) # 更新缓存 self.user_relationship_cache[user_id] = { diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index f043e4f49..a477fdf0a 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -268,7 +268,7 @@ class EmojiAction(BaseAction): if not success: logger.error(f"{self.log_prefix} 表情包发送失败") await self.store_action_info( - action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包,但失败了", action_done=False + action_build_into_prompt=True, action_prompt_display="发送了一个表情包,但失败了", action_done=False ) return False, "表情包发送失败" @@ -279,7 +279,7 @@ class EmojiAction(BaseAction): logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}") await self.store_action_info( - action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包", action_done=True + action_build_into_prompt=True, action_prompt_display="发送了一个表情包", action_done=True ) return True, f"发送表情包: {emoji_description}" diff --git a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py index 91dea3105..194a2c5ef 100644 --- a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py +++ b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py @@ -47,9 +47,9 @@ class SearchKnowledgeFromLPMMTool(BaseTool): knowledge_parts = [] for i, item in enumerate(knowledge_info["knowledge_items"]): knowledge_parts.append(f"- {item.get('content', 'N/A')}") - + knowledge_text = "\n".join(knowledge_parts) - summary = knowledge_info.get('summary', '无总结') + summary = knowledge_info.get("summary", "无总结") content = f"关于 '{query}', 你知道以下信息:\n{knowledge_text}\n\n总结: {summary}" else: content = f"关于 '{query}',你的知识库里好像没有相关的信息呢" diff --git a/src/plugins/built_in/maizone_refactored/plugin.py b/src/plugins/built_in/maizone_refactored/plugin.py index 4ea543f68..e8259b5cb 100644 --- a/src/plugins/built_in/maizone_refactored/plugin.py +++ b/src/plugins/built_in/maizone_refactored/plugin.py @@ -87,6 +87,7 @@ class MaiZoneRefactoredPlugin(BasePlugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + async def on_plugin_loaded(self): """插件加载完成后的回调,初始化服务并启动后台任务""" # --- 注册权限节点 --- diff --git a/src/plugins/built_in/maizone_refactored/services/cookie_service.py b/src/plugins/built_in/maizone_refactored/services/cookie_service.py index e3692f883..9da05582c 100644 --- a/src/plugins/built_in/maizone_refactored/services/cookie_service.py +++ b/src/plugins/built_in/maizone_refactored/services/cookie_service.py @@ -140,5 +140,7 @@ class CookieService: self._save_cookies_to_file(qq_account, cookies) return cookies - logger.error(f"为 {qq_account} 获取Cookie的所有方法均失败。请确保Napcat HTTP服务或Adapter连接至少有一个正常工作,或存在有效的本地Cookie文件。") + logger.error( + f"为 {qq_account} 获取Cookie的所有方法均失败。请确保Napcat HTTP服务或Adapter连接至少有一个正常工作,或存在有效的本地Cookie文件。" + ) return None diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index 186079965..c0e00b80d 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -290,9 +290,7 @@ class QZoneService: comment_content = comment.get("content", "") try: - reply_content = await self.content_service.generate_comment_reply( - content, comment_content, nickname - ) + reply_content = await self.content_service.generate_comment_reply(content, comment_content, nickname) if reply_content: success = await api_client["reply"](fid, qq_account, nickname, reply_content, comment_tid) if success: @@ -532,7 +530,9 @@ class QZoneService: async def _get_api_client(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict]: cookies = await self.cookie_service.get_cookies(qq_account, stream_id) if not cookies: - logger.error("获取API客户端失败:未能获取到Cookie。请检查Napcat连接是否正常,或是否存在有效的本地Cookie文件。") + logger.error( + "获取API客户端失败:未能获取到Cookie。请检查Napcat连接是否正常,或是否存在有效的本地Cookie文件。" + ) return None p_skey = cookies.get("p_skey") or cookies.get("p_skey".upper()) @@ -726,7 +726,8 @@ class QZoneService: return {"pic_bo": picbo, "richval": richval} except Exception as e: logger.error( - f"从上传结果中提取图片参数失败: {e}, 上传结果: {upload_result}", exc_info=True + f"从上传结果中提取图片参数失败: {e}, 上传结果: {upload_result}", + exc_info=True, ) return None else: @@ -764,7 +765,9 @@ class QZoneService: json_data = orjson.loads(res_text) if json_data.get("code") != 0: - logger.warning(f"获取说说列表API返回错误: code={json_data.get('code')}, message={json_data.get('message')}") + logger.warning( + f"获取说说列表API返回错误: code={json_data.get('code')}, message={json_data.get('message')}" + ) return [] feeds_list = [] @@ -797,7 +800,7 @@ class QZoneService: for c in commentlist: if not isinstance(c, dict): continue - + # 添加主评论 comments.append( { @@ -822,7 +825,7 @@ class QZoneService: "parent_tid": c.get("tid"), # 父ID是主评论的ID } ) - + feeds_list.append( { "tid": msg.get("tid", ""), diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py index 8725edeff..2c1c4d861 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -835,7 +835,7 @@ class MessageHandler: if music: tag = music.get("tag", "未知来源") logger.debug(f"检测到【{tag}】音乐分享消息 (music view),开始提取信息") - + title = music.get("title", "未知歌曲") desc = music.get("desc", "未知艺术家") jump_url = music.get("jumpUrl", "") @@ -853,7 +853,7 @@ class MessageHandler: artist = parts[1] else: artist = desc - + formatted_content = ( f"这是一张来自【{tag}】的音乐分享卡片:\n" f"歌曲: {song_title}\n" @@ -870,12 +870,12 @@ class MessageHandler: if news and "网易云音乐" in news.get("tag", ""): tag = news.get("tag") logger.debug(f"检测到【{tag}】音乐分享消息 (news view),开始提取信息") - + title = news.get("title", "未知歌曲") desc = news.get("desc", "未知艺术家") jump_url = news.get("jumpUrl", "") preview_url = news.get("preview", "") - + formatted_content = ( f"这是一张来自【{tag}】的音乐分享卡片:\n" f"标题: {title}\n" diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py index 411b957dc..11e4275f4 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py @@ -3,7 +3,6 @@ import time import random import websockets as Server import uuid -import asyncio from maim_message import ( UserInfo, GroupInfo, @@ -96,7 +95,9 @@ class SendHandler: logger.error("无法识别的消息类型") return logger.info("尝试发送到napcat") - logger.debug(f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'") + logger.debug( + f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'" + ) response = await self.send_message_to_napcat( action, { diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py index 1a8c3a95d..3ec4ca181 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py @@ -6,7 +6,6 @@ import urllib3 import ssl import io -import asyncio import time from asyncio import Lock @@ -75,7 +74,7 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d except Exception as e: logger.error(f"获取群信息失败: {e}") return None - + data = socket_response.get("data") if data: await set_to_cache(cache_key, data) @@ -114,7 +113,7 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use cached_data = await get_from_cache(cache_key) if cached_data: return cached_data - + logger.debug(f"获取群成员信息中 (无缓存): group={group_id}, user={user_id}") request_uuid = str(uuid.uuid4()) payload = json.dumps( @@ -133,7 +132,7 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use except Exception as e: logger.error(f"获取成员信息失败: {e}") return None - + data = socket_response.get("data") if data: await set_to_cache(cache_key, data) @@ -203,7 +202,7 @@ async def get_self_info(websocket: Server.ServerConnection) -> dict | None: except Exception as e: logger.error(f"获取自身信息失败: {e}") return None - + data = response.get("data") if data: await set_to_cache(cache_key, data) diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index 0f575e13f..5061cf496 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -260,7 +260,6 @@ class ManagementCommand(PlusCommand): except Exception as e: await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}") - async def _add_dir(self, dir_path: str): """添加插件目录""" await self.send_text(f"📁 正在添加插件目录: `{dir_path}`") @@ -501,13 +500,13 @@ class PluginManagementPlugin(BasePlugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # 注册权限节点 - + async def on_plugin_loaded(self): await permission_api.register_permission_node( - "plugin.management.admin", - "插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作", - "plugin_management", - False, + "plugin.management.admin", + "插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作", + "plugin_management", + False, ) def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]: diff --git a/src/plugins/built_in/proactive_thinker/plugin.py b/src/plugins/built_in/proactive_thinker/plugin.py index c04f82927..5e55e9101 100644 --- a/src/plugins/built_in/proactive_thinker/plugin.py +++ b/src/plugins/built_in/proactive_thinker/plugin.py @@ -1,27 +1,23 @@ -from typing import List, Tuple, Union, Type, Optional +from typing import List, Tuple, Type from src.common.logger import get_logger -from src.config.official_configs import AffinityFlowConfig from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system import ( BasePlugin, ConfigField, register_plugin, - plugin_manage_api, - component_manage_api, - ComponentInfo, - ComponentType, EventHandlerInfo, - EventType, BaseEventHandler, ) from .proacive_thinker_event import ProactiveThinkerEventHandler logger = get_logger(__name__) + @register_plugin class ProactiveThinkerPlugin(BasePlugin): """一个主动思考的插件,但现在还只是个空壳子""" + plugin_name: str = "proactive_thinker" enable_plugin: bool = False dependencies: list[str] = [] @@ -33,6 +29,7 @@ class ProactiveThinkerPlugin(BasePlugin): "config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"), }, } + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -42,4 +39,3 @@ class ProactiveThinkerPlugin(BasePlugin): (ProactiveThinkerEventHandler.get_handler_info(), ProactiveThinkerEventHandler) ] return components - diff --git a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py index 55492c7f9..5ad560243 100644 --- a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py +++ b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py @@ -2,7 +2,7 @@ import asyncio import random import time from datetime import datetime -from typing import List, Union, Type, Optional +from typing import List, Union from maim_message import UserInfo @@ -69,7 +69,7 @@ class ColdStartTask(AsyncTask): # 创建 UserInfo 对象,这是创建聊天流的必要信息 user_info = UserInfo(platform=platform, user_id=str(user_id), user_nickname=user_nickname) - + # 【关键步骤】主动创建聊天流。 # 创建后,该用户就进入了机器人的“好友列表”,后续将由 ProactiveThinkingTask 接管 stream = await self.chat_manager.get_or_create_stream(platform, user_info) @@ -175,10 +175,12 @@ class ProactiveThinkingTask(AsyncTask): # 2. 【核心逻辑】检查聊天冷却时间是否足够长 time_since_last_active = time.time() - stream.last_active_time if time_since_last_active > next_interval: - logger.info(f"【日常唤醒】聊天流 {stream.stream_id} 已冷却 {time_since_last_active:.2f} 秒,触发主动对话。") - + logger.info( + f"【日常唤醒】聊天流 {stream.stream_id} 已冷却 {time_since_last_active:.2f} 秒,触发主动对话。" + ) + await self.executor.execute(stream_id=stream.stream_id, start_mode="wake_up") - + # 【关键步骤】在触发后,立刻更新活跃时间并保存。 # 这可以防止在同一个检查周期内,对同一个目标因为意外的延迟而发送多条消息。 stream.update_active_time() diff --git a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py index b5f280fee..ab3631450 100644 --- a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py +++ b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py @@ -3,7 +3,16 @@ from typing import Optional, Dict, Any from datetime import datetime from src.common.logger import get_logger -from src.plugin_system.apis import chat_api, person_api, schedule_api, send_api, llm_api, message_api, generator_api, database_api +from src.plugin_system.apis import ( + chat_api, + person_api, + schedule_api, + send_api, + llm_api, + message_api, + generator_api, + database_api, +) from src.config.config import global_config, model_config from src.person_info.person_info import get_person_info_manager @@ -39,17 +48,16 @@ class ProactiveThinkerExecutor: # 2. 决策阶段 decision_result = await self._make_decision(context, start_mode) - if not decision_result or not decision_result.get("should_reply"): reason = decision_result.get("reason", "未提供") if decision_result else "决策过程返回None" logger.info(f"决策结果为:不回复。原因: {reason}") await database_api.store_action_info( - chat_stream=self._get_stream_from_id(stream_id), - action_name="proactive_decision", - action_prompt_display=f"主动思考决定不回复,原因: {reason}", - action_done = True, - action_data=decision_result - ) + chat_stream=self._get_stream_from_id(stream_id), + action_name="proactive_decision", + action_prompt_display=f"主动思考决定不回复,原因: {reason}", + action_done=True, + action_data=decision_result, + ) return # 3. 规划与执行阶段 @@ -59,15 +67,17 @@ class ProactiveThinkerExecutor: chat_stream=self._get_stream_from_id(stream_id), action_name="proactive_decision", action_prompt_display=f"主动思考决定回复,原因: {reason},话题:{topic}", - action_done = True, - action_data=decision_result + action_done=True, + action_data=decision_result, ) logger.info(f"决策结果为:回复。话题: {topic}") - + plan_prompt = self._build_plan_prompt(context, start_mode, topic, reason) - - is_success, response, _, _ = await llm_api.generate_with_model(prompt=plan_prompt, model_config=model_config.model_task_config.utils) - + + is_success, response, _, _ = await llm_api.generate_with_model( + prompt=plan_prompt, model_config=model_config.model_task_config.utils + ) + if is_success and response: stream = self._get_stream_from_id(stream_id) if stream: @@ -104,33 +114,41 @@ class ProactiveThinkerExecutor: if not user_info or not user_info.platform or not user_info.user_id: logger.warning(f"Stream {stream_id} 的 user_info 不完整") return None - + person_id = person_api.get_person_id(user_info.platform, int(user_info.user_id)) person_info_manager = get_person_info_manager() # 获取日程 schedules = await schedule_api.ScheduleAPI.get_today_schedule() - schedule_context = "\n".join([f"- {s['title']} ({s['start_time']}-{s['end_time']})" for s in schedules]) if schedules else "今天没有日程安排。" + schedule_context = ( + "\n".join([f"- {s['title']} ({s['start_time']}-{s['end_time']})" for s in schedules]) + if schedules + else "今天没有日程安排。" + ) # 获取关系信息 short_impression = await person_info_manager.get_value(person_id, "short_impression") or "无" impression = await person_info_manager.get_value(person_id, "impression") or "无" attitude = await person_info_manager.get_value(person_id, "attitude") or 50 - + # 获取最近聊天记录 recent_messages = await message_api.get_recent_messages(stream_id, limit=10) - recent_chat_history = await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else "无" - + recent_chat_history = ( + await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else "无" + ) + # 获取最近的动作历史 action_history = await database_api.db_query( database_api.MODEL_MAPPING["ActionRecords"], filters={"chat_id": stream_id, "action_name": "proactive_decision"}, limit=3, - order_by=["-time"] + order_by=["-time"], ) action_history_context = "无" if isinstance(action_history, list): - action_history_context = "\n".join([f"- {a['action_data']}" for a in action_history if isinstance(a, dict)]) or "无" + action_history_context = ( + "\n".join([f"- {a['action_data']}" for a in action_history if isinstance(a, dict)]) or "无" + ) return { "person_id": person_id, @@ -138,47 +156,43 @@ class ProactiveThinkerExecutor: "schedule_context": schedule_context, "recent_chat_history": recent_chat_history, "action_history_context": action_history_context, - "relationship": { - "short_impression": short_impression, - "impression": impression, - "attitude": attitude - }, + "relationship": {"short_impression": short_impression, "impression": impression, "attitude": attitude}, "persona": { "core": global_config.personality.personality_core, "side": global_config.personality.personality_side, "identity": global_config.personality.identity, }, - "current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), } async def _make_decision(self, context: Dict[str, Any], start_mode: str) -> Optional[Dict[str, Any]]: """ 决策模块:判断是否应该主动发起对话,以及聊什么话题 """ - persona = context['persona'] - user_info = context['user_info'] - relationship = context['relationship'] + persona = context["persona"] + user_info = context["user_info"] + relationship = context["relationship"] prompt = f""" # 角色 你的名字是{global_config.bot.nickname},你的人设如下: -- 核心人设: {persona['core']} -- 侧面人设: {persona['side']} -- 身份: {persona['identity']} +- 核心人设: {persona["core"]} +- 侧面人设: {persona["side"]} +- 身份: {persona["identity"]} # 任务 -现在是 {context['current_time']},你需要根据当前的情境,决定是否要主动向用户 '{user_info.user_nickname}' 发起对话。 +现在是 {context["current_time"]},你需要根据当前的情境,决定是否要主动向用户 '{user_info.user_nickname}' 发起对话。 # 情境分析 -1. **启动模式**: {start_mode} ({'初次见面/很久未见' if start_mode == 'cold_start' else '日常唤醒'}) +1. **启动模式**: {start_mode} ({"初次见面/很久未见" if start_mode == "cold_start" else "日常唤醒"}) 2. **你的日程**: -{context['schedule_context']} +{context["schedule_context"]} 3. **你和Ta的关系**: - - 简短印象: {relationship['short_impression']} - - 详细印象: {relationship['impression']} - - 好感度: {relationship['attitude']}/100 + - 简短印象: {relationship["short_impression"]} + - 详细印象: {relationship["impression"]} + - 好感度: {relationship["attitude"]}/100 4. **最近的聊天摘要**: -{context['recent_chat_history']} +{context["recent_chat_history"]} # 决策指令 请综合以上所有信息,做出决策。你的决策需要以JSON格式输出,包含以下字段: @@ -204,9 +218,11 @@ class ProactiveThinkerExecutor: 请输出你的决策: """ - - is_success, response, _, _ = await llm_api.generate_with_model(prompt=prompt, model_config=model_config.model_task_config.utils) - + + is_success, response, _, _ = await llm_api.generate_with_model( + prompt=prompt, model_config=model_config.model_task_config.utils + ) + if not is_success: return {"should_reply": False, "reason": "决策模型生成失败"} @@ -222,17 +238,17 @@ class ProactiveThinkerExecutor: """ 根据启动模式和决策话题,构建最终的规划提示词 """ - persona = context['persona'] - user_info = context['user_info'] - relationship = context['relationship'] + persona = context["persona"] + user_info = context["user_info"] + relationship = context["relationship"] if start_mode == "cold_start": prompt = f""" # 角色 你的名字是{global_config.bot.nickname},你的人设如下: -- 核心人设: {persona['core']} -- 侧面人设: {persona['side']} -- 身份: {persona['identity']} +- 核心人设: {persona["core"]} +- 侧面人设: {persona["side"]} +- 身份: {persona["identity"]} # 任务 你需要主动向一个新朋友 '{user_info.user_nickname}' 发起对话。这是你们的第一次交流,或者很久没聊了。 @@ -240,9 +256,9 @@ class ProactiveThinkerExecutor: # 决策上下文 - **决策理由**: {reason} - **你和Ta的关系**: - - 简短印象: {relationship['short_impression']} - - 详细印象: {relationship['impression']} - - 好感度: {relationship['attitude']}/100 + - 简短印象: {relationship["short_impression"]} + - 详细印象: {relationship["impression"]} + - 好感度: {relationship["attitude"]}/100 # 对话指引 - 你的目标是“破冰”,让对话自然地开始。 @@ -254,26 +270,26 @@ class ProactiveThinkerExecutor: prompt = f""" # 角色 你的名字是{global_config.bot.nickname},你的人设如下: -- 核心人设: {persona['core']} -- 侧面人设: {persona['side']} -- 身份: {persona['identity']} +- 核心人设: {persona["core"]} +- 侧面人设: {persona["side"]} +- 身份: {persona["identity"]} # 任务 -现在是 {context['current_time']},你需要主动向你的朋友 '{user_info.user_nickname}' 发起对话。 +现在是 {context["current_time"]},你需要主动向你的朋友 '{user_info.user_nickname}' 发起对话。 # 决策上下文 - **决策理由**: {reason} # 情境分析 1. **你的日程**: -{context['schedule_context']} +{context["schedule_context"]} 2. **你和Ta的关系**: - - 详细印象: {relationship['impression']} - - 好感度: {relationship['attitude']}/100 + - 详细印象: {relationship["impression"]} + - 好感度: {relationship["attitude"]}/100 3. **最近的聊天摘要**: -{context['recent_chat_history']} +{context["recent_chat_history"]} 4. **你最近的相关动作**: -{context['action_history_context']} +{context["action_history_context"]} # 对话指引 - 你决定和Ta聊聊关于“{topic}”的话题。 diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index 1dcd5c08e..a26879da7 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -16,8 +16,6 @@ from src.person_info.person_info import get_person_info_manager from dateutil.parser import parse as parse_datetime from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api, llm_api, generator_api -from src.plugin_system.base.component_types import ComponentType -from typing import Optional from src.chat.message_receive.chat_stream import ChatStream import asyncio import datetime diff --git a/src/schedule/database.py b/src/schedule/database.py index 9117b9586..b33bfb953 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -95,9 +95,7 @@ async def mark_plans_completed(plan_ids: List[int]): plan_details = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans_to_mark)]) logger.info(f"以下 {len(plans_to_mark)} 条月度计划将被标记为已完成:\n{plan_details}") - await session.execute( - update(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)).values(status="completed") - ) + await session.execute(update(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)).values(status="completed")) await session.commit() except Exception as e: logger.error(f"标记月度计划为完成时发生错误: {e}") @@ -184,9 +182,7 @@ async def update_plan_usage(plan_ids: List[int], used_date: str): raise -async def get_smart_plans_for_daily_schedule( - month: str, max_count: int = 3, avoid_days: int = 7 -) -> List[MonthlyPlan]: +async def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_days: int = 7) -> List[MonthlyPlan]: """ 智能抽取月度计划用于每日日程生成。 @@ -208,14 +204,10 @@ async def get_smart_plans_for_daily_schedule( avoid_date = (datetime.now() - timedelta(days=avoid_days)).strftime("%Y-%m-%d") # 查询符合条件的计划 - query = select(MonthlyPlan).where( - MonthlyPlan.target_month == month, MonthlyPlan.status == "active" - ) + query = select(MonthlyPlan).where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") # 排除最近使用过的计划 - query = query.where( - (MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date) - ) + query = query.where((MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date)) # 按使用次数升序排列,优先选择使用次数少的 result = await session.execute(query.order_by(MonthlyPlan.usage_count.asc())) @@ -274,9 +266,7 @@ async def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]: async with get_db_session() as session: try: result = await session.execute( - select(MonthlyPlan).where( - MonthlyPlan.target_month == month, MonthlyPlan.status == "archived" - ) + select(MonthlyPlan).where(MonthlyPlan.target_month == month, MonthlyPlan.status == "archived") ) return result.scalars().all() except Exception as e: