From ecb02cae316b8f2381a1b9593478ed2fba2ad34e Mon Sep 17 00:00:00 2001 From: John Richard Date: Thu, 2 Oct 2025 19:38:39 +0800 Subject: [PATCH 1/3] =?UTF-8?q?style:=20=E6=A0=BC=E5=BC=8F=E5=8C=96?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 3 +- scripts/lpmm_learning_tool.py | 10 +- scripts/rebuild_metadata_index.py | 64 +-- scripts/run_multi_stage_smoke.py | 5 +- .../management/statistics.py | 16 +- src/chat/chatter_manager.py | 12 +- src/chat/emoji_system/emoji_history.py | 1 - src/chat/emoji_system/emoji_manager.py | 9 +- src/chat/energy_system/__init__.py | 6 +- src/chat/energy_system/energy_manager.py | 51 +- src/chat/express/expression_learner.py | 9 +- .../interest_system/bot_interest_manager.py | 52 +- src/chat/knowledge/embedding_store.py | 2 +- src/chat/knowledge/qa_manager.py | 27 +- src/chat/memory_system/__init__.py | 44 +- .../enhanced_memory_adapter.py | 72 +-- .../enhanced_memory_hooks.py | 23 +- .../enhanced_memory_integration.py | 61 +-- .../deprecated_backup/enhanced_reranker.py | 218 +++++--- .../deprecated_backup/integration_layer.py | 29 +- .../memory_integration_hooks.py | 67 +-- .../deprecated_backup/metadata_index.py | 126 ++--- .../multi_stage_retrieval.py | 314 ++++++----- .../deprecated_backup/vector_storage.py | 155 +++--- .../enhanced_memory_activator.py | 30 +- .../memory_system/memory_activator_new.py | 30 +- src/chat/memory_system/memory_builder.py | 149 +++--- src/chat/memory_system/memory_chunk.py | 135 +++-- .../memory_system/memory_forgetting_engine.py | 57 +- src/chat/memory_system/memory_fusion.py | 75 +-- src/chat/memory_system/memory_manager.py | 57 +- .../memory_system/memory_metadata_index.py | 137 +++-- .../memory_system/memory_query_planner.py | 19 +- src/chat/memory_system/memory_system.py | 343 ++++++------ .../memory_system/vector_memory_storage_v2.py | 489 +++++++++--------- src/chat/message_manager/__init__.py | 9 +- src/chat/message_manager/context_manager.py | 2 + .../message_manager/distribution_manager.py | 16 +- src/chat/message_manager/message_manager.py | 39 +- .../sleep_manager/notification_sender.py | 40 +- .../sleep_manager/sleep_manager.py | 32 +- .../sleep_manager/sleep_state.py | 3 +- .../sleep_manager/time_checker.py | 42 +- .../sleep_manager/wakeup_context.py | 4 +- .../sleep_manager/wakeup_manager.py | 10 +- src/chat/message_receive/bot.py | 12 +- src/chat/message_receive/chat_stream.py | 9 +- src/chat/message_receive/message.py | 2 +- src/chat/message_receive/storage.py | 54 +- src/chat/planner_actions/action_manager.py | 23 +- src/chat/replyer/default_generator.py | 80 +-- src/chat/utils/chat_message_builder.py | 121 +++-- src/chat/utils/memory_mappings.py | 12 +- src/chat/utils/prompt.py | 47 +- src/chat/utils/statistic.py | 71 ++- src/chat/utils/utils.py | 1 - src/chat/utils/utils_image.py | 36 +- src/chat/utils/utils_video.py | 4 +- src/common/cache_manager.py | 5 +- src/common/database/db_migration.py | 15 +- .../database/sqlalchemy_database_api.py | 2 +- src/common/database/sqlalchemy_models.py | 1 + src/common/logger.py | 8 +- src/common/vector_db/chromadb_impl.py | 22 +- src/config/config.py | 2 +- src/config/official_configs.py | 77 ++- src/individuality/individuality.py | 6 +- .../model_client/aiohttp_gemini_client.py | 4 +- src/llm_models/payload_content/message.py | 2 +- src/llm_models/utils_model.py | 257 +++++---- src/main.py | 3 +- src/mais4u/mais4u_chat/s4u_msg_processor.py | 4 +- src/mais4u/mais4u_chat/s4u_prompt.py | 
7 +- src/mood/mood_manager.py | 2 +- src/person_info/person_info.py | 13 +- src/person_info/relationship_builder.py | 20 +- src/person_info/relationship_fetcher.py | 1 - src/plugin_system/apis/__init__.py | 2 +- src/plugin_system/apis/message_api.py | 10 +- src/plugin_system/apis/permission_api.py | 3 +- src/plugin_system/apis/schedule_api.py | 3 +- src/plugin_system/apis/send_api.py | 4 +- src/plugin_system/base/base_action.py | 18 +- src/plugin_system/base/base_chatter.py | 9 +- src/plugin_system/base/base_tool.py | 36 +- src/plugin_system/base/component_types.py | 2 +- src/plugin_system/base/plugin_base.py | 4 +- src/plugin_system/core/component_registry.py | 110 ++-- src/plugin_system/core/permission_manager.py | 27 +- src/plugin_system/core/plugin_manager.py | 7 +- src/plugin_system/core/tool_use.py | 22 +- .../utils/permission_decorators.py | 8 +- .../affinity_flow_chatter/interest_scoring.py | 4 +- .../affinity_flow_chatter/plan_executor.py | 16 +- .../affinity_flow_chatter/plan_filter.py | 97 ++-- .../built_in/affinity_flow_chatter/planner.py | 10 +- .../relationship_tracker.py | 2 +- src/plugins/built_in/core_actions/emoji.py | 4 +- .../built_in/knowledge/lpmm_get_knowledge.py | 4 +- .../built_in/maizone_refactored/plugin.py | 1 + .../services/cookie_service.py | 4 +- .../services/qzone_service.py | 19 +- .../src/recv_handler/message_handler.py | 8 +- .../napcat_adapter_plugin/src/send_handler.py | 5 +- .../napcat_adapter_plugin/src/utils.py | 9 +- .../built_in/plugin_management/plugin.py | 11 +- .../built_in/proactive_thinker/plugin.py | 12 +- .../proacive_thinker_event.py | 12 +- .../proactive_thinker_executor.py | 138 ++--- .../built_in/social_toolkit_plugin/plugin.py | 2 - src/schedule/database.py | 20 +- 111 files changed, 2344 insertions(+), 2316 deletions(-) diff --git a/bot.py b/bot.py index 0399b0d19..798247c96 100644 --- a/bot.py +++ b/bot.py @@ -21,7 +21,7 @@ initialize_logging() from src.main import MainSystem # noqa from src import BaseMain # noqa from src.manager.async_task_manager import async_task_manager # noqa -from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge # noqa +from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge # noqa from src.config.config import global_config # noqa from src.common.database.database import initialize_sql_database # noqa from src.common.database.sqlalchemy_models import initialize_database as init_db # noqa @@ -269,4 +269,3 @@ if __name__ == "__main__": # 在程序退出前暂停,让你有机会看到输出 # input("按 Enter 键退出...") # <--- 添加这行 sys.exit(exit_code) # <--- 使用记录的退出码 - \ No newline at end of file diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index f0888d552..9caafc7fd 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -40,6 +40,7 @@ file_lock = Lock() # --- 缓存清理 --- + def clear_cache(): """清理 lpmm_learning_tool.py 生成的缓存文件""" logger.info("--- 开始清理缓存 ---") @@ -53,6 +54,7 @@ def clear_cache(): logger.info("缓存目录不存在,无需清理。") logger.info("--- 缓存清理完成 ---") + # --- 模块一:数据预处理 --- @@ -108,7 +110,7 @@ def _parse_and_repair_json(json_string: str) -> Optional[dict]: cleaned_string = cleaned_string[7:].strip() elif cleaned_string.startswith("```"): cleaned_string = cleaned_string[3:].strip() - + if cleaned_string.endswith("```"): cleaned_string = cleaned_string[:-3].strip() @@ -117,7 +119,7 @@ def _parse_and_repair_json(json_string: str) -> Optional[dict]: return orjson.loads(cleaned_string) except orjson.JSONDecodeError: logger.warning("直接解析JSON失败,将尝试修复...") - + # 3. 
修复与最终解析 repaired_json_str = "" try: @@ -164,10 +166,10 @@ async def extract_info_async(pg_hash, paragraph, llm_api): content = None try: content, (_, _, _) = await llm_api.generate_response_async(prompt) - + # 改进点:调用封装好的函数处理JSON解析和修复 extracted_data = _parse_and_repair_json(content) - + if extracted_data is None: # 如果解析失败,抛出异常以触发统一的错误处理逻辑 raise ValueError("无法从LLM输出中解析有效的JSON数据") diff --git a/scripts/rebuild_metadata_index.py b/scripts/rebuild_metadata_index.py index f5cf652ab..d1990fecc 100644 --- a/scripts/rebuild_metadata_index.py +++ b/scripts/rebuild_metadata_index.py @@ -3,6 +3,7 @@ """ 从现有ChromaDB数据重建JSON元数据索引 """ + import asyncio import sys import os @@ -15,53 +16,53 @@ from src.common.logger import get_logger logger = get_logger(__name__) + async def rebuild_metadata_index(): """从ChromaDB重建元数据索引""" - print("="*80) + print("=" * 80) print("重建JSON元数据索引") - print("="*80) - + print("=" * 80) + # 初始化记忆系统 print("\n🔧 初始化记忆系统...") ms = MemorySystem() await ms.initialize() print("✅ 记忆系统已初始化") - - if not hasattr(ms.unified_storage, 'metadata_index'): + + if not hasattr(ms.unified_storage, "metadata_index"): print("❌ 元数据索引管理器未初始化") return - + # 获取所有记忆 print("\n📥 从ChromaDB获取所有记忆...") from src.common.vector_db import vector_db_service - + try: # 获取集合中的所有记忆ID collection_name = ms.unified_storage.config.memory_collection result = vector_db_service.get( - collection_name=collection_name, - include=["documents", "metadatas", "embeddings"] + collection_name=collection_name, include=["documents", "metadatas", "embeddings"] ) - + if not result or not result.get("ids"): print("❌ ChromaDB中没有找到记忆数据") return - + ids = result["ids"] metadatas = result.get("metadatas", []) - + print(f"✅ 找到 {len(ids)} 条记忆") - + # 重建元数据索引 print("\n🔨 开始重建元数据索引...") entries = [] success_count = 0 - - for i, (memory_id, metadata) in enumerate(zip(ids, metadatas), 1): + + for i, (memory_id, metadata) in enumerate(zip(ids, metadatas, strict=False), 1): try: # 从ChromaDB元数据重建索引条目 import orjson - + entry = MemoryMetadataIndexEntry( memory_id=memory_id, user_id=metadata.get("user_id", "unknown"), @@ -75,9 +76,9 @@ async def rebuild_metadata_index(): created_at=metadata.get("created_at", 0.0), access_count=metadata.get("access_count", 0), chat_id=metadata.get("chat_id"), - content_preview=None + content_preview=None, ) - + # 尝试解析importance和confidence的枚举名称 if "importance" in metadata: imp_str = metadata["importance"] @@ -89,7 +90,7 @@ async def rebuild_metadata_index(): entry.importance = 3 elif imp_str == "CRITICAL": entry.importance = 4 - + if "confidence" in metadata: conf_str = metadata["confidence"] if conf_str == "LOW": @@ -100,40 +101,41 @@ async def rebuild_metadata_index(): entry.confidence = 3 elif conf_str == "VERIFIED": entry.confidence = 4 - + entries.append(entry) success_count += 1 - + if i % 100 == 0: print(f" 处理进度: {i}/{len(ids)} ({success_count} 成功)") - + except Exception as e: logger.warning(f"处理记忆 {memory_id} 失败: {e}") continue - + print(f"\n✅ 成功解析 {success_count}/{len(ids)} 条记忆元数据") - + # 批量更新索引 print("\n💾 保存元数据索引...") ms.unified_storage.metadata_index.batch_add_or_update(entries) ms.unified_storage.metadata_index.save() - + # 显示统计信息 stats = ms.unified_storage.metadata_index.get_stats() - print(f"\n📊 重建后的索引统计:") + print("\n📊 重建后的索引统计:") print(f" - 总记忆数: {stats['total_memories']}") print(f" - 主语数量: {stats['subjects_count']}") print(f" - 关键词数量: {stats['keywords_count']}") print(f" - 标签数量: {stats['tags_count']}") - print(f" - 类型分布:") - for mtype, count in stats['types'].items(): + print(" - 类型分布:") + for mtype, count in 
stats["types"].items(): print(f" - {mtype}: {count}") - + print("\n✅ 元数据索引重建完成!") - + except Exception as e: logger.error(f"重建索引失败: {e}", exc_info=True) print(f"❌ 重建索引失败: {e}") -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(rebuild_metadata_index()) diff --git a/scripts/run_multi_stage_smoke.py b/scripts/run_multi_stage_smoke.py index bfb3c417f..000336244 100644 --- a/scripts/run_multi_stage_smoke.py +++ b/scripts/run_multi_stage_smoke.py @@ -3,6 +3,7 @@ """ 轻量烟雾测试:初始化 MemorySystem 并运行一次检索,验证 MemoryMetadata.source 访问不再报错 """ + import asyncio import sys import os @@ -11,6 +12,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.chat.memory_system.memory_system import MemorySystem + async def main(): ms = MemorySystem() await ms.initialize() @@ -19,5 +21,6 @@ async def main(): for i, m in enumerate(results, 1): print(f"{i}. id={m.metadata.memory_id} source={getattr(m.metadata, 'source', None)}") -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index 4df22f152..9d44faa78 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -31,9 +31,11 @@ class AntiInjectionStatistics: try: async with get_db_session() as session: # 获取最新的统计记录,如果没有则创建 - stats = (await session.execute( - select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()) - )).scalars().first() + stats = ( + (await session.execute(select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()))) + .scalars() + .first() + ) if not stats: stats = AntiInjectionStats() session.add(stats) @@ -49,9 +51,11 @@ class AntiInjectionStatistics: """更新统计数据""" try: async with get_db_session() as session: - stats = (await session.execute( - select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()) - )).scalars().first() + stats = ( + (await session.execute(select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()))) + .scalars() + .first() + ) if not stats: stats = AntiInjectionStats() session.add(stats) diff --git a/src/chat/chatter_manager.py b/src/chat/chatter_manager.py index 7aee10562..d22d39440 100644 --- a/src/chat/chatter_manager.py +++ b/src/chat/chatter_manager.py @@ -2,13 +2,13 @@ from typing import Dict, List, Optional, Any import time from src.plugin_system.base.base_chatter import BaseChatter from src.common.data_models.message_manager_data_model import StreamContext -from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner from src.chat.planner_actions.action_manager import ChatterActionManager -from src.plugin_system.base.component_types import ChatType, ComponentType +from src.plugin_system.base.component_types import ChatType from src.common.logger import get_logger logger = get_logger("chatter_manager") + class ChatterManager: def __init__(self, action_manager: ChatterActionManager): self.action_manager = action_manager @@ -27,6 +27,7 @@ class ChatterManager: """从组件注册表自动注册已注册的chatter组件""" try: from src.plugin_system.core.component_registry import component_registry + # 获取所有CHATTER类型的组件 chatter_components = component_registry.get_enabled_chatter_registry() for chatter_name, chatter_class in chatter_components.items(): @@ -70,7 +71,7 @@ class ChatterManager: inactive_streams = [] for stream_id, instance in self.instances.items(): - if hasattr(instance, 'get_activity_time'): + if 
hasattr(instance, "get_activity_time"): activity_time = instance.get_activity_time() if (current_time - activity_time) > max_inactive_seconds: inactive_streams.append(stream_id) @@ -91,6 +92,7 @@ class ChatterManager: if not chatter_class: # 如果没有找到精确匹配,尝试查找支持ALL类型的chatter from src.plugin_system.base.component_types import ChatType + all_chatter_class = self.get_chatter_class(ChatType.ALL) if all_chatter_class: chatter_class = all_chatter_class @@ -110,6 +112,7 @@ class ChatterManager: # 从 mood_manager 获取最新的 chat_stream 并同步回 StreamContext try: from src.mood.mood_manager import mood_manager + mood = mood_manager.get_mood_by_chat_id(stream_id) if mood and mood.chat_stream: context.chat_stream = mood.chat_stream @@ -125,6 +128,7 @@ class ChatterManager: # 在处理完成后,清除该流的未读消息 try: from src.chat.message_manager.message_manager import message_manager + await message_manager.clear_stream_unread_messages(stream_id) except Exception as clear_e: logger.error(f"清除流 {stream_id} 未读消息时发生错误: {clear_e}") @@ -149,4 +153,4 @@ class ChatterManager: "streams_processed": 0, "successful_executions": 0, "failed_executions": 0, - } \ No newline at end of file + } diff --git a/src/chat/emoji_system/emoji_history.py b/src/chat/emoji_system/emoji_history.py index 804f61e0a..dadd152a1 100644 --- a/src/chat/emoji_system/emoji_history.py +++ b/src/chat/emoji_system/emoji_history.py @@ -3,7 +3,6 @@ 表情包发送历史记录模块 """ -import os from typing import List, Dict from collections import deque diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 9e4829a56..cd472ec0c 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -204,9 +204,7 @@ class MaiEmoji: # 2. 删除数据库记录 try: async with get_db_session() as session: - result = await session.execute( - select(Emoji).where(Emoji.emoji_hash == self.hash) - ) + result = await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)) will_delete_emoji = result.scalar_one_or_none() if will_delete_emoji is None: logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") @@ -946,10 +944,7 @@ class EmojiManager: existing_description = None try: async with get_db_session() as session: - stmt = select(Images).where( - Images.emoji_hash == image_hash, - Images.type == "emoji" - ) + stmt = select(Images).where(Images.emoji_hash == image_hash, Images.type == "emoji") result = await session.execute(stmt) existing_image = result.scalar_one_or_none() if existing_image and existing_image.description: diff --git a/src/chat/energy_system/__init__.py b/src/chat/energy_system/__init__.py index 0addfd070..6cdf96da5 100644 --- a/src/chat/energy_system/__init__.py +++ b/src/chat/energy_system/__init__.py @@ -12,7 +12,7 @@ from .energy_manager import ( ActivityEnergyCalculator, RecencyEnergyCalculator, RelationshipEnergyCalculator, - energy_manager + energy_manager, ) __all__ = [ @@ -24,5 +24,5 @@ __all__ = [ "ActivityEnergyCalculator", "RecencyEnergyCalculator", "RelationshipEnergyCalculator", - "energy_manager" -] \ No newline at end of file + "energy_manager", +] diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py index b07222b4b..4a92349bf 100644 --- a/src/chat/energy_system/energy_manager.py +++ b/src/chat/energy_system/energy_manager.py @@ -17,16 +17,18 @@ logger = get_logger("energy_system") class EnergyLevel(Enum): """能量等级""" + VERY_LOW = 0.1 # 非常低 - LOW = 0.3 # 低 - NORMAL = 0.5 # 正常 - HIGH = 0.7 # 高 - VERY_HIGH = 0.9 # 非常高 + LOW = 0.3 # 低 + NORMAL = 0.5 # 正常 + HIGH = 
0.7 # 高 + VERY_HIGH = 0.9 # 非常高 @dataclass class EnergyComponent: """能量组件""" + name: str value: float weight: float = 1.0 @@ -47,6 +49,7 @@ class EnergyComponent: class EnergyContext(TypedDict): """能量计算上下文""" + stream_id: str messages: List[Any] user_id: Optional[str] @@ -54,6 +57,7 @@ class EnergyContext(TypedDict): class EnergyResult(TypedDict): """能量计算结果""" + energy: float level: EnergyLevel distribution_interval: float @@ -114,12 +118,7 @@ class ActivityEnergyCalculator(EnergyCalculator): """活跃度能量计算器""" def __init__(self): - self.action_weights = { - "reply": 0.4, - "react": 0.3, - "mention": 0.2, - "other": 0.1 - } + self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1} def calculate(self, context: Dict[str, Any]) -> float: """基于活跃度计算能量""" @@ -188,7 +187,7 @@ class RecencyEnergyCalculator(EnergyCalculator): else: recency_score = 0.1 - logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age/3600:.1f}小时)") + logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age / 3600:.1f}小时)") return recency_score def get_weight(self) -> float: @@ -236,11 +235,7 @@ class EnergyManager: self.cache_ttl: int = 60 # 1分钟缓存 # AFC阈值配置 - self.thresholds: Dict[str, float] = { - "high_match": 0.8, - "reply": 0.4, - "non_reply": 0.2 - } + self.thresholds: Dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2} # 统计信息 self.stats: Dict[str, Union[int, float, str]] = { @@ -260,9 +255,13 @@ class EnergyManager: """从配置加载AFC阈值""" try: if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None: - self.thresholds["high_match"] = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8) + self.thresholds["high_match"] = getattr( + global_config.affinity_flow, "high_match_interest_threshold", 0.8 + ) self.thresholds["reply"] = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4) - self.thresholds["non_reply"] = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2) + self.thresholds["non_reply"] = getattr( + global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2 + ) # 确保阈值关系合理 self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1) @@ -306,6 +305,7 @@ class EnergyManager: # 支持同步和异步计算器 if callable(calculator.calculate): import inspect + if inspect.iscoroutinefunction(calculator.calculate): score = await calculator.calculate(context) else: @@ -347,11 +347,12 @@ class EnergyManager: calculation_time = time.time() - start_time total_calculations = self.stats["total_calculations"] self.stats["average_calculation_time"] = ( - (self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time) - / total_calculations - ) + self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time + ) / total_calculations - logger.debug(f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)") + logger.debug( + f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)" + ) return final_energy def _apply_threshold_adjustment(self, energy: float) -> float: @@ -405,6 +406,7 @@ class EnergyManager: # 添加随机扰动避免同步 import random + jitter = random.uniform(0.8, 1.2) final_interval = base_interval * jitter @@ -424,7 +426,8 @@ class EnergyManager: """清理过期缓存""" current_time = time.time() expired_keys = [ - stream_id for stream_id, (_, timestamp) in self.energy_cache.items() + stream_id + for stream_id, (_, timestamp) in self.energy_cache.items() if 
current_time - timestamp > self.cache_ttl ] @@ -479,4 +482,4 @@ class EnergyManager: # 全局能量管理器实例 -energy_manager = EnergyManager() \ No newline at end of file +energy_manager = EnergyManager() diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index bb663a1ad..596322ebd 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -382,7 +382,7 @@ class ExpressionLearner: create_date=current_time, # 手动设置创建日期 ) session.add(new_expression) - + # 限制最大数量 exprs_result = await session.execute( select(Expression) @@ -492,11 +492,10 @@ class ExpressionLearnerManager: self._ensure_expression_directories() - async def get_expression_learner(self, chat_id: str) -> ExpressionLearner: await self._auto_migrate_json_to_db() await self._migrate_old_data_create_date() - + if chat_id not in self.expression_learners: self.expression_learners[chat_id] = ExpressionLearner(chat_id) return self.expression_learners[chat_id] @@ -644,7 +643,9 @@ class ExpressionLearnerManager: try: async with get_db_session() as session: # 查找所有create_date为空的表达方式 - old_expressions_result = await session.execute(select(Expression).where(Expression.create_date.is_(None))) + old_expressions_result = await session.execute( + select(Expression).where(Expression.create_date.is_(None)) + ) old_expressions = old_expressions_result.scalars().all() updated_count = 0 diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index 37a315197..8fee48d1c 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -131,7 +131,9 @@ class BotInterestManager: self.current_interests = generated_interests active_count = len(generated_interests.get_active_tags()) logger.info(f"成功生成 {active_count} 个新兴趣标签。") - tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in generated_interests.get_active_tags()] + tags_info = [ + f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in generated_interests.get_active_tags() + ] tags_str = "\n".join(tags_info) logger.info(f"当前兴趣标签:\n{tags_str}") @@ -639,11 +641,19 @@ class BotInterestManager: async with get_db_session() as session: # 查询最新的兴趣标签配置 - db_interests = (await session.execute( - select(DBBotPersonalityInterests) - .where(DBBotPersonalityInterests.personality_id == personality_id) - .order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc()) - )).scalars().first() + db_interests = ( + ( + await session.execute( + select(DBBotPersonalityInterests) + .where(DBBotPersonalityInterests.personality_id == personality_id) + .order_by( + DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc() + ) + ) + ) + .scalars() + .first() + ) if db_interests: logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}") @@ -728,10 +738,17 @@ class BotInterestManager: async with get_db_session() as session: # 检查是否已存在相同personality_id的记录 - existing_record = (await session.execute( - select(DBBotPersonalityInterests) - .where(DBBotPersonalityInterests.personality_id == interests.personality_id) - )).scalars().first() + existing_record = ( + ( + await session.execute( + select(DBBotPersonalityInterests).where( + DBBotPersonalityInterests.personality_id == interests.personality_id + ) + ) + ) + .scalars() + .first() + ) if existing_record: # 更新现有记录 @@ -763,10 +780,17 @@ class BotInterestManager: # 验证保存是否成功 async with get_db_session() as session: - saved_record = 
(await session.execute( - select(DBBotPersonalityInterests) - .where(DBBotPersonalityInterests.personality_id == interests.personality_id) - )).scalars().first() + saved_record = ( + ( + await session.execute( + select(DBBotPersonalityInterests).where( + DBBotPersonalityInterests.personality_id == interests.personality_id + ) + ) + ) + .scalars() + .first() + ) if saved_record: logger.info(f"✅ 验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录") logger.info(f" 版本: {saved_record.version}") diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 162a00b7f..f6fae8d6c 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -161,7 +161,7 @@ class EmbeddingStore: @staticmethod def _get_embeddings_batch_threaded( - strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None + strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None ) -> List[Tuple[str, List[float]]]: """使用多线程批量获取嵌入向量 diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index f539659fb..c340fc30e 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -101,7 +101,7 @@ class QAManager: async def get_knowledge(self, question: str) -> Optional[Dict[str, Any]]: """ 获取知识,返回结构化字典 - + Args: question: 用户提出的问题 @@ -114,30 +114,27 @@ class QAManager: return None query_res = processed_result[0] - + knowledge_items = [] for res_hash, relevance, *_ in query_res: if store_item := self.embed_manager.paragraphs_embedding_store.store.get(res_hash): - knowledge_items.append({ - "content": store_item.str, - "source": "内部知识库", - "relevance": f"{relevance:.4f}" - }) + knowledge_items.append( + {"content": store_item.str, "source": "内部知识库", "relevance": f"{relevance:.4f}"} + ) if not knowledge_items: return None - + # 使用LLM生成总结 - knowledge_text_for_summary = "\n\n".join([item['content'] for item in knowledge_items[:5]]) # 最多总结前5条 - summary_prompt = f"根据以下信息,为问题 '{question}' 生成一个简洁的、不超过50字的摘要:\n\n{knowledge_text_for_summary}" - + knowledge_text_for_summary = "\n\n".join([item["content"] for item in knowledge_items[:5]]) # 最多总结前5条 + summary_prompt = ( + f"根据以下信息,为问题 '{question}' 生成一个简洁的、不超过50字的摘要:\n\n{knowledge_text_for_summary}" + ) + try: summary, (_, _, _) = await self.qa_model.generate_response_async(summary_prompt) except Exception as e: logger.error(f"生成知识摘要失败: {e}") summary = "无法生成摘要。" - return { - "knowledge_items": knowledge_items, - "summary": summary.strip() if summary else "没有可用的摘要。" - } + return {"knowledge_items": knowledge_items, "summary": summary.strip() if summary else "没有可用的摘要。"} diff --git a/src/chat/memory_system/__init__.py b/src/chat/memory_system/__init__.py index 75daf0fb2..a1c176a10 100644 --- a/src/chat/memory_system/__init__.py +++ b/src/chat/memory_system/__init__.py @@ -12,44 +12,23 @@ from .memory_chunk import ( MemoryType, ImportanceLevel, ConfidenceLevel, - create_memory_chunk + create_memory_chunk, ) # 遗忘引擎 -from .memory_forgetting_engine import ( - MemoryForgettingEngine, - ForgettingConfig, - get_memory_forgetting_engine -) +from .memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig, get_memory_forgetting_engine # Vector DB存储系统 -from .vector_memory_storage_v2 import ( - VectorMemoryStorage, - VectorStorageConfig, - get_vector_memory_storage -) +from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage # 记忆核心系统 -from .memory_system import ( - 
MemorySystem, - MemorySystemConfig, - get_memory_system, - initialize_memory_system -) +from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system # 记忆管理器 -from .memory_manager import ( - MemoryManager, - MemoryResult, - memory_manager -) +from .memory_manager import MemoryManager, MemoryResult, memory_manager # 激活器 -from .enhanced_memory_activator import ( - MemoryActivator, - memory_activator, - enhanced_memory_activator -) +from .enhanced_memory_activator import MemoryActivator, memory_activator, enhanced_memory_activator # 兼容性别名 from .memory_chunk import MemoryChunk as Memory @@ -64,28 +43,23 @@ __all__ = [ "ImportanceLevel", "ConfidenceLevel", "create_memory_chunk", - # 遗忘引擎 "MemoryForgettingEngine", "ForgettingConfig", "get_memory_forgetting_engine", - # Vector DB存储 "VectorMemoryStorage", - "VectorStorageConfig", + "VectorStorageConfig", "get_vector_memory_storage", - # 记忆系统 "MemorySystem", "MemorySystemConfig", "get_memory_system", "initialize_memory_system", - # 记忆管理器 "MemoryManager", - "MemoryResult", + "MemoryResult", "memory_manager", - # 激活器 "MemoryActivator", "memory_activator", @@ -95,4 +69,4 @@ __all__ = [ # 版本信息 __version__ = "3.0.0" __author__ = "MoFox Team" -__description__ = "简化记忆系统 - 统一记忆架构与智能遗忘机制" \ No newline at end of file +__description__ = "简化记忆系统 - 统一记忆架构与智能遗忘机制" diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py index 7d34df8d9..aae09c08b 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py @@ -4,15 +4,14 @@ 将增强记忆系统集成到现有MoFox Bot架构中 """ -import asyncio import time -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict, List, Optional, Any from dataclasses import dataclass from src.common.logger import get_logger from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType -from src.chat.memory_system.memory_formatter import MemoryFormatter, FormatterConfig, format_memories_for_llm +from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) @@ -36,6 +35,7 @@ MEMORY_TYPE_LABELS = { @dataclass class AdapterConfig: """适配器配置""" + enable_enhanced_memory: bool = True integration_mode: str = "enhanced_only" # replace, enhanced_only auto_migration: bool = True @@ -61,7 +61,7 @@ class EnhancedMemoryAdapter: "hybrid_used": 0, "memories_created": 0, "memories_retrieved": 0, - "average_processing_time": 0.0 + "average_processing_time": 0.0, } async def initialize(self): @@ -79,14 +79,11 @@ class EnhancedMemoryAdapter: memory_value_threshold=self.config.memory_value_threshold, fusion_threshold=self.config.fusion_threshold, max_retrieval_results=self.config.max_retrieval_results, - enable_learning=True # 启用学习功能 + enable_learning=True, # 启用学习功能 ) # 创建集成层 - self.integration_layer = MemoryIntegrationLayer( - llm_model=self.llm_model, - config=integration_config - ) + self.integration_layer = MemoryIntegrationLayer(llm_model=self.llm_model, config=integration_config) # 初始化集成层 await self.integration_layer.initialize() @@ -99,10 +96,7 @@ class EnhancedMemoryAdapter: # 如果初始化失败,禁用增强记忆功能 self.config.enable_enhanced_memory = False - async def process_conversation_memory( - self, - context: 
Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: + async def process_conversation_memory(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """处理对话记忆,以上下文为唯一输入""" if not self._initialized or not self.config.enable_enhanced_memory: return {"success": False, "error": "Enhanced memory not available"} @@ -152,11 +146,7 @@ class EnhancedMemoryAdapter: return {"success": False, "error": str(e)} async def retrieve_relevant_memories( - self, - query: str, - user_id: str, - context: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None + self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: Optional[int] = None ) -> List[MemoryChunk]: """检索相关记忆""" if not self._initialized or not self.config.enable_enhanced_memory: @@ -164,9 +154,7 @@ class EnhancedMemoryAdapter: try: limit = limit or self.config.max_retrieval_results - memories = await self.integration_layer.retrieve_relevant_memories( - query, None, context, limit - ) + memories = await self.integration_layer.retrieve_relevant_memories(query, None, context, limit) self.adapter_stats["memories_retrieved"] += len(memories) logger.debug(f"检索到 {len(memories)} 条相关记忆") @@ -178,11 +166,7 @@ class EnhancedMemoryAdapter: return [] async def get_memory_context_for_prompt( - self, - query: str, - user_id: str, - context: Optional[Dict[str, Any]] = None, - max_memories: int = 5 + self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5 ) -> str: """获取用于提示词的记忆上下文""" memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories) @@ -197,15 +181,11 @@ class EnhancedMemoryAdapter: include_confidence=False, use_emoji_icons=True, group_by_type=False, - max_display_length=150 - ) - - return format_memories_for_llm( - memories=memories, - query_context=query, - config=formatter_config + max_display_length=150, ) + return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config) + async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]: """获取增强记忆系统摘要""" if not self._initialized or not self.config.enable_enhanced_memory: @@ -227,7 +207,7 @@ class EnhancedMemoryAdapter: "adapter_stats": adapter_stats, "integration_stats": integration_stats, "total_memories_created": adapter_stats["memories_created"], - "total_memories_retrieved": adapter_stats["memories_retrieved"] + "total_memories_retrieved": adapter_stats["memories_retrieved"], } except Exception as e: @@ -285,12 +265,12 @@ async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAd from src.config.config import global_config adapter_config = AdapterConfig( - enable_enhanced_memory=getattr(global_config.memory, 'enable_enhanced_memory', True), - integration_mode=getattr(global_config.memory, 'enhanced_memory_mode', 'enhanced_only'), - auto_migration=getattr(global_config.memory, 'enable_memory_migration', True), - memory_value_threshold=getattr(global_config.memory, 'memory_value_threshold', 0.6), - fusion_threshold=getattr(global_config.memory, 'fusion_threshold', 0.85), - max_retrieval_results=getattr(global_config.memory, 'max_retrieval_results', 10) + enable_enhanced_memory=getattr(global_config.memory, "enable_enhanced_memory", True), + integration_mode=getattr(global_config.memory, "enhanced_memory_mode", "enhanced_only"), + auto_migration=getattr(global_config.memory, "enable_memory_migration", True), + memory_value_threshold=getattr(global_config.memory, "memory_value_threshold", 0.6), + 
fusion_threshold=getattr(global_config.memory, "fusion_threshold", 0.85), + max_retrieval_results=getattr(global_config.memory, "max_retrieval_results", 10), ) _enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config) @@ -312,13 +292,13 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest): async def process_conversation_with_enhanced_memory( - context: Dict[str, Any], - llm_model: Optional[LLMRequest] = None + context: Dict[str, Any], llm_model: Optional[LLMRequest] = None ) -> Dict[str, Any]: """使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息""" if not llm_model: # 获取默认的LLM模型 from src.llm_models.utils_model import get_global_llm_model + llm_model = get_global_llm_model() try: @@ -345,12 +325,13 @@ async def retrieve_memories_with_enhanced_system( user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 10, - llm_model: Optional[LLMRequest] = None + llm_model: Optional[LLMRequest] = None, ) -> List[MemoryChunk]: """使用增强记忆系统检索记忆""" if not llm_model: # 获取默认的LLM模型 from src.llm_models.utils_model import get_global_llm_model + llm_model = get_global_llm_model() try: @@ -366,12 +347,13 @@ async def get_memory_context_for_prompt( user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5, - llm_model: Optional[LLMRequest] = None + llm_model: Optional[LLMRequest] = None, ) -> str: """获取用于提示词的记忆上下文""" if not llm_model: # 获取默认的LLM模型 from src.llm_models.utils_model import get_global_llm_model + llm_model = get_global_llm_model() try: @@ -379,4 +361,4 @@ async def get_memory_context_for_prompt( return await adapter.get_memory_context_for_prompt(query, user_id, context, max_memories) except Exception as e: logger.error(f"获取记忆上下文失败: {e}", exc_info=True) - return "" \ No newline at end of file + return "" diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py index 37188a08e..a1b374510 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py @@ -4,7 +4,6 @@ 用于在消息处理过程中自动构建和检索记忆 """ -import asyncio from typing import Dict, List, Any, Optional from datetime import datetime @@ -19,8 +18,7 @@ class EnhancedMemoryHooks: """增强记忆系统钩子 - 自动处理消息的记忆构建和检索""" def __init__(self): - self.enabled = (global_config.memory.enable_memory and - global_config.memory.enable_enhanced_memory) + self.enabled = global_config.memory.enable_memory and global_config.memory.enable_enhanced_memory self.processed_messages = set() # 避免重复处理 async def process_message_for_memory( @@ -29,7 +27,7 @@ class EnhancedMemoryHooks: user_id: str, chat_id: str, message_id: str, - context: Optional[Dict[str, Any]] = None + context: Optional[Dict[str, Any]] = None, ) -> bool: """ 处理消息并构建记忆 @@ -76,7 +74,7 @@ class EnhancedMemoryHooks: "timestamp": datetime.now().timestamp(), "message_type": "user_message", **bot_context, - **(context or {}) + **(context or {}), } # 处理对话并构建记忆 @@ -84,7 +82,7 @@ class EnhancedMemoryHooks: conversation_text=message_content, context=memory_context, user_id=user_id, - timestamp=memory_context["timestamp"] + timestamp=memory_context["timestamp"], ) # 标记消息已处理 @@ -108,7 +106,7 @@ class EnhancedMemoryHooks: user_id: str, chat_id: str, limit: int = 5, - extra_context: Optional[Dict[str, Any]] = None + extra_context: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """ 为回复获取相关记忆 @@ -134,9 +132,7 @@ class EnhancedMemoryHooks: context = { "chat_id": chat_id, "query_intent": 
"response_generation", - "expected_memory_types": [ - "personal_fact", "event", "preference", "opinion" - ] + "expected_memory_types": ["personal_fact", "event", "preference", "opinion"], } if extra_context: @@ -144,10 +140,7 @@ class EnhancedMemoryHooks: # 获取相关记忆 enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context( - query_text=query_text, - user_id=user_id, - context=context, - limit=limit + query_text=query_text, user_id=user_id, context=context, limit=limit ) # 转换为字典格式 @@ -199,4 +192,4 @@ class EnhancedMemoryHooks: # 创建全局实例 -enhanced_memory_hooks = EnhancedMemoryHooks() \ No newline at end of file +enhanced_memory_hooks = EnhancedMemoryHooks() diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py index 2d5e8f44e..913c2aed0 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py @@ -4,7 +4,6 @@ 用于在现有系统中无缝集成增强记忆功能 """ -import asyncio from typing import Dict, Any, Optional from src.common.logger import get_logger @@ -14,11 +13,7 @@ logger = get_logger(__name__) async def process_user_message_memory( - message_content: str, - user_id: str, - chat_id: str, - message_id: str, - context: Optional[Dict[str, Any]] = None + message_content: str, user_id: str, chat_id: str, message_id: str, context: Optional[Dict[str, Any]] = None ) -> bool: """ 处理用户消息并构建记忆 @@ -35,11 +30,7 @@ async def process_user_message_memory( """ try: success = await enhanced_memory_hooks.process_message_for_memory( - message_content=message_content, - user_id=user_id, - chat_id=chat_id, - message_id=message_id, - context=context + message_content=message_content, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context ) if success: @@ -53,11 +44,7 @@ async def process_user_message_memory( async def get_relevant_memories_for_response( - query_text: str, - user_id: str, - chat_id: str, - limit: int = 5, - extra_context: Optional[Dict[str, Any]] = None + query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """ 为回复获取相关记忆 @@ -74,29 +61,17 @@ async def get_relevant_memories_for_response( """ try: memories = await enhanced_memory_hooks.get_memory_for_response( - query_text=query_text, - user_id=user_id, - chat_id=chat_id, - limit=limit, - extra_context=extra_context + query_text=query_text, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=extra_context ) - result = { - "has_memories": len(memories) > 0, - "memories": memories, - "memory_count": len(memories) - } + result = {"has_memories": len(memories) > 0, "memories": memories, "memory_count": len(memories)} logger.debug(f"为回复获取到 {len(memories)} 条相关记忆") return result except Exception as e: logger.error(f"获取回复记忆失败: {e}") - return { - "has_memories": False, - "memories": [], - "memory_count": 0 - } + return {"has_memories": False, "memories": [], "memory_count": 0} def format_memories_for_prompt(memories: Dict[str, Any]) -> str: @@ -152,16 +127,13 @@ def get_memory_system_status() -> Dict[str, Any]: "enabled": enhanced_memory_hooks.enabled, "enhanced_system_initialized": enhanced_memory_manager.is_initialized, "processed_messages_count": len(enhanced_memory_hooks.processed_messages), - "system_type": "enhanced_memory_system" + "system_type": "enhanced_memory_system", } # 便捷函数 async def remember_message( - message: str, - user_id: str = 
"default_user", - chat_id: str = "default_chat", - context: Optional[Dict[str, Any]] = None + message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: Optional[Dict[str, Any]] = None ) -> bool: """ 便捷的记忆构建函数 @@ -175,13 +147,10 @@ async def remember_message( bool: 是否成功 """ import uuid + message_id = str(uuid.uuid4()) return await process_user_message_memory( - message_content=message, - user_id=user_id, - chat_id=chat_id, - message_id=message_id, - context=context + message_content=message, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context ) @@ -190,7 +159,7 @@ async def recall_memories( user_id: str = "default_user", chat_id: str = "default_chat", limit: int = 5, - context: Optional[Dict[str, Any]] = None + context: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """ 便捷的记忆检索函数 @@ -205,9 +174,5 @@ async def recall_memories( Dict: 记忆信息 """ return await get_relevant_memories_for_response( - query_text=query, - user_id=user_id, - chat_id=chat_id, - limit=limit, - extra_context=context - ) \ No newline at end of file + query_text=query, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=context + ) diff --git a/src/chat/memory_system/deprecated_backup/enhanced_reranker.py b/src/chat/memory_system/deprecated_backup/enhanced_reranker.py index a6dfafb01..e5b368460 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_reranker.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_reranker.py @@ -18,32 +18,34 @@ logger = get_logger(__name__) class IntentType(Enum): """对话意图类型""" - FACT_QUERY = "fact_query" # 事实查询 - EVENT_RECALL = "event_recall" # 事件回忆 + + FACT_QUERY = "fact_query" # 事实查询 + EVENT_RECALL = "event_recall" # 事件回忆 PREFERENCE_CHECK = "preference_check" # 偏好检查 - GENERAL_CHAT = "general_chat" # 一般对话 - UNKNOWN = "unknown" # 未知意图 + GENERAL_CHAT = "general_chat" # 一般对话 + UNKNOWN = "unknown" # 未知意图 @dataclass class ReRankingConfig: """重排序配置""" + # 权重配置 (w1 + w2 + w3 + w4 = 1.0) - semantic_weight: float = 0.5 # 语义相似度权重 - recency_weight: float = 0.2 # 时效性权重 - usage_freq_weight: float = 0.2 # 使用频率权重 - type_match_weight: float = 0.1 # 类型匹配权重 - + semantic_weight: float = 0.5 # 语义相似度权重 + recency_weight: float = 0.2 # 时效性权重 + usage_freq_weight: float = 0.2 # 使用频率权重 + type_match_weight: float = 0.1 # 类型匹配权重 + # 时效性衰减参数 - recency_decay_rate: float = 0.1 # 时效性衰减率 (天) - + recency_decay_rate: float = 0.1 # 时效性衰减率 (天) + # 使用频率计算参数 - freq_log_base: float = 2.0 # 对数底数 - freq_max_score: float = 5.0 # 最大频率得分 - + freq_log_base: float = 2.0 # 对数底数 + freq_max_score: float = 5.0 # 最大频率得分 + # 类型匹配权重映射 type_match_weights: Dict[str, Dict[str, float]] = None - + def __post_init__(self): """初始化类型匹配权重""" if self.type_match_weights is None: @@ -53,102 +55,150 @@ class ReRankingConfig: MemoryType.KNOWLEDGE.value: 0.8, MemoryType.PREFERENCE.value: 0.5, MemoryType.EVENT.value: 0.3, - "default": 0.3 + "default": 0.3, }, IntentType.EVENT_RECALL.value: { MemoryType.EVENT.value: 1.0, MemoryType.EXPERIENCE.value: 0.8, MemoryType.EMOTION.value: 0.6, MemoryType.PERSONAL_FACT.value: 0.5, - "default": 0.5 + "default": 0.5, }, IntentType.PREFERENCE_CHECK.value: { MemoryType.PREFERENCE.value: 1.0, MemoryType.OPINION.value: 0.8, MemoryType.GOAL.value: 0.6, MemoryType.PERSONAL_FACT.value: 0.4, - "default": 0.4 + "default": 0.4, }, - IntentType.GENERAL_CHAT.value: { - "default": 0.8 - }, - IntentType.UNKNOWN.value: { - "default": 0.8 - } + IntentType.GENERAL_CHAT.value: {"default": 0.8}, + IntentType.UNKNOWN.value: {"default": 0.8}, } class IntentClassifier: 
"""轻量级意图识别器""" - + def __init__(self): # 关键词模式匹配规则 self.patterns = { IntentType.FACT_QUERY: [ # 中文模式 - "我是", "我的", "我叫", "我在", "我住在", "我的职业", "我的工作", - "什么时候", "在哪里", "是什么", "多少", "几岁", "年龄", + "我是", + "我的", + "我叫", + "我在", + "我住在", + "我的职业", + "我的工作", + "什么时候", + "在哪里", + "是什么", + "多少", + "几岁", + "年龄", # 英文模式 - "what is", "where is", "when is", "how old", "my name", "i am", "i live" + "what is", + "where is", + "when is", + "how old", + "my name", + "i am", + "i live", ], IntentType.EVENT_RECALL: [ # 中文模式 - "记得", "想起", "还记得", "那次", "上次", "之前", "以前", "曾经", - "发生过", "经历", "做过", "去过", "见过", + "记得", + "想起", + "还记得", + "那次", + "上次", + "之前", + "以前", + "曾经", + "发生过", + "经历", + "做过", + "去过", + "见过", # 英文模式 - "remember", "recall", "last time", "before", "previously", "happened", "experience" + "remember", + "recall", + "last time", + "before", + "previously", + "happened", + "experience", ], IntentType.PREFERENCE_CHECK: [ # 中文模式 - "喜欢", "不喜欢", "偏好", "爱好", "兴趣", "讨厌", "最爱", "最喜欢", - "习惯", "通常", "一般", "倾向于", "更喜欢", + "喜欢", + "不喜欢", + "偏好", + "爱好", + "兴趣", + "讨厌", + "最爱", + "最喜欢", + "习惯", + "通常", + "一般", + "倾向于", + "更喜欢", # 英文模式 - "like", "love", "hate", "prefer", "favorite", "usually", "tend to", "interest" - ] + "like", + "love", + "hate", + "prefer", + "favorite", + "usually", + "tend to", + "interest", + ], } - + def classify_intent(self, query: str, context: Dict[str, Any]) -> IntentType: """识别对话意图""" if not query: return IntentType.UNKNOWN - + query_lower = query.lower() - + # 统计各意图的匹配分数 intent_scores = {intent: 0 for intent in IntentType} - + for intent, patterns in self.patterns.items(): for pattern in patterns: if pattern in query_lower: intent_scores[intent] += 1 - + # 返回得分最高的意图 max_score = max(intent_scores.values()) if max_score == 0: return IntentType.GENERAL_CHAT - + for intent, score in intent_scores.items(): if score == max_score: return intent - + return IntentType.GENERAL_CHAT class EnhancedReRanker: """增强重排序器 - 实现文档设计的多维度评分模型""" - + def __init__(self, config: Optional[ReRankingConfig] = None): self.config = config or ReRankingConfig() self.intent_classifier = IntentClassifier() - + # 验证权重和为1.0 total_weight = ( - self.config.semantic_weight + - self.config.recency_weight + - self.config.usage_freq_weight + - self.config.type_match_weight + self.config.semantic_weight + + self.config.recency_weight + + self.config.usage_freq_weight + + self.config.type_match_weight ) - + if abs(total_weight - 1.0) > 0.01: logger.warning(f"重排序权重和不为1.0: {total_weight}, 将进行归一化") # 归一化权重 @@ -156,94 +206,94 @@ class EnhancedReRanker: self.config.recency_weight /= total_weight self.config.usage_freq_weight /= total_weight self.config.type_match_weight /= total_weight - + def rerank_memories( self, query: str, candidate_memories: List[Tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity) context: Dict[str, Any], - limit: int = 10 + limit: int = 10, ) -> List[Tuple[str, MemoryChunk, float]]: """ 对候选记忆进行重排序 - + Args: query: 查询文本 candidate_memories: 候选记忆列表 [(memory_id, memory, vector_similarity)] context: 上下文信息 limit: 返回数量限制 - + Returns: 重排序后的记忆列表 [(memory_id, memory, final_score)] """ if not candidate_memories: return [] - + # 识别查询意图 intent = self.intent_classifier.classify_intent(query, context) logger.debug(f"识别到查询意图: {intent.value}") - + # 计算每个候选记忆的最终得分 scored_memories = [] current_time = time.time() - + for memory_id, memory, vector_sim in candidate_memories: try: # 1. 语义相似度得分 (已归一化到[0,1]) semantic_score = self._normalize_similarity(vector_sim) - + # 2. 
时效性得分 recency_score = self._calculate_recency_score(memory, current_time) - + # 3. 使用频率得分 usage_freq_score = self._calculate_usage_frequency_score(memory) - + # 4. 类型匹配得分 type_match_score = self._calculate_type_match_score(memory, intent) - + # 计算最终得分 final_score = ( - self.config.semantic_weight * semantic_score + - self.config.recency_weight * recency_score + - self.config.usage_freq_weight * usage_freq_score + - self.config.type_match_weight * type_match_score + self.config.semantic_weight * semantic_score + + self.config.recency_weight * recency_score + + self.config.usage_freq_weight * usage_freq_score + + self.config.type_match_weight * type_match_score ) - + scored_memories.append((memory_id, memory, final_score)) - + # 记录调试信息 logger.debug( f"记忆评分 {memory_id[:8]}: semantic={semantic_score:.3f}, " f"recency={recency_score:.3f}, freq={usage_freq_score:.3f}, " f"type={type_match_score:.3f}, final={final_score:.3f}" ) - + except Exception as e: logger.error(f"计算记忆 {memory_id} 得分时出错: {e}") # 使用向量相似度作为后备得分 scored_memories.append((memory_id, memory, vector_sim)) - + # 按最终得分降序排序 scored_memories.sort(key=lambda x: x[2], reverse=True) - + # 返回前N个结果 result = scored_memories[:limit] - + highest_score = result[0][2] if result else 0.0 logger.info( f"重排序完成: 候选={len(candidate_memories)}, 返回={len(result)}, " f"意图={intent.value}, 最高分={highest_score:.3f}" ) - + return result - + def _normalize_similarity(self, raw_similarity: float) -> float: """归一化相似度到[0,1]区间""" # 假设原始相似度已经在[-1,1]或[0,1]区间 if raw_similarity < 0: return (raw_similarity + 1) / 2 # 从[-1,1]映射到[0,1] return min(1.0, max(0.0, raw_similarity)) # 确保在[0,1]区间 - + def _calculate_recency_score(self, memory: MemoryChunk, current_time: float) -> float: """ 计算时效性得分 @@ -251,13 +301,13 @@ class EnhancedReRanker: """ last_accessed = memory.metadata.last_accessed or memory.metadata.created_at days_old = (current_time - last_accessed) / (24 * 3600) # 转换为天数 - + if days_old < 0: days_old = 0 # 处理时间异常 - + score = 1 / (1 + self.config.recency_decay_rate * days_old) return min(1.0, max(0.0, score)) - + def _calculate_usage_frequency_score(self, memory: MemoryChunk) -> float: """ 计算使用频率得分 @@ -266,22 +316,22 @@ class EnhancedReRanker: access_count = memory.metadata.access_count if access_count <= 0: return 0.0 - + log_count = math.log2(access_count + 1) score = log_count / self.config.freq_max_score return min(1.0, max(0.0, score)) - + def _calculate_type_match_score(self, memory: MemoryChunk, intent: IntentType) -> float: """计算类型匹配得分""" memory_type = memory.memory_type.value intent_value = intent.value - + # 获取对应意图的类型权重映射 type_weights = self.config.type_match_weights.get(intent_value, {}) - + # 查找具体类型的权重,如果没有则使用默认权重 score = type_weights.get(memory_type, type_weights.get("default", 0.8)) - + return min(1.0, max(0.0, score)) @@ -294,7 +344,7 @@ def rerank_candidate_memories( candidate_memories: List[Tuple[str, MemoryChunk, float]], context: Dict[str, Any], limit: int = 10, - config: Optional[ReRankingConfig] = None + config: Optional[ReRankingConfig] = None, ) -> List[Tuple[str, MemoryChunk, float]]: """ 便捷函数:对候选记忆进行重排序 @@ -303,5 +353,5 @@ def rerank_candidate_memories( reranker = EnhancedReRanker(config) else: reranker = default_reranker - - return reranker.rerank_memories(query, candidate_memories, context, limit) \ No newline at end of file + + return reranker.rerank_memories(query, candidate_memories, context, limit) diff --git a/src/chat/memory_system/deprecated_backup/integration_layer.py b/src/chat/memory_system/deprecated_backup/integration_layer.py 
index 3f8b7f1ce..5b9282a84 100644 --- a/src/chat/memory_system/deprecated_backup/integration_layer.py +++ b/src/chat/memory_system/deprecated_backup/integration_layer.py @@ -12,7 +12,7 @@ from enum import Enum from src.common.logger import get_logger from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel +from src.chat.memory_system.memory_chunk import MemoryChunk from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) @@ -20,13 +20,15 @@ logger = get_logger(__name__) class IntegrationMode(Enum): """集成模式""" - REPLACE = "replace" # 完全替换现有记忆系统 + + REPLACE = "replace" # 完全替换现有记忆系统 ENHANCED_ONLY = "enhanced_only" # 仅使用增强记忆系统 @dataclass class IntegrationConfig: """集成配置""" + mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY enable_enhanced_memory: bool = True memory_value_threshold: float = 0.6 @@ -51,7 +53,7 @@ class MemoryIntegrationLayer: "enhanced_queries": 0, "memory_creations": 0, "average_response_time": 0.0, - "success_rate": 0.0 + "success_rate": 0.0, } # 初始化锁 @@ -88,6 +90,7 @@ class MemoryIntegrationLayer: # 创建增强记忆系统配置 from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig + memory_config = MemorySystemConfig.from_global_config() # 使用集成配置覆盖部分值 @@ -96,9 +99,7 @@ class MemoryIntegrationLayer: memory_config.final_recall_limit = self.config.max_retrieval_results # 创建增强记忆系统 - self.enhanced_memory = EnhancedMemorySystem( - config=memory_config - ) + self.enhanced_memory = EnhancedMemorySystem(config=memory_config) # 如果外部提供了LLM模型,注入到系统中 if self.llm_model is not None: @@ -112,10 +113,7 @@ class MemoryIntegrationLayer: logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True) raise - async def process_conversation( - self, - context: Dict[str, Any] - ) -> Dict[str, Any]: + async def process_conversation(self, context: Dict[str, Any]) -> Dict[str, Any]: """处理对话记忆,仅使用上下文信息""" if not self._initialized or not self.enhanced_memory: return {"success": False, "error": "Memory system not available"} @@ -154,7 +152,7 @@ class MemoryIntegrationLayer: query: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None + limit: Optional[int] = None, ) -> List[MemoryChunk]: """检索相关记忆""" if not self._initialized or not self.enhanced_memory: @@ -163,10 +161,7 @@ class MemoryIntegrationLayer: try: limit = limit or self.config.max_retrieval_results memories = await self.enhanced_memory.retrieve_relevant_memories( - query=query, - user_id=None, - context=context or {}, - limit=limit + query=query, user_id=None, context=context or {}, limit=limit ) memory_count = len(memories) @@ -191,7 +186,7 @@ class MemoryIntegrationLayer: "status": "initialized", "mode": self.config.mode.value, "enhanced_memory": enhanced_status, - "integration_stats": self.integration_stats.copy() + "integration_stats": self.integration_stats.copy(), } except Exception as e: @@ -248,4 +243,4 @@ class MemoryIntegrationLayer: logger.info("✅ 记忆系统集成层已关闭") except Exception as e: - logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True) \ No newline at end of file + logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True) diff --git a/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py b/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py index 2dab63b7a..4659389cb 100644 --- a/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py +++ b/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py 
@@ -4,17 +4,15 @@ 提供与现有MoFox Bot系统的无缝集成点 """ -import asyncio import time -from typing import Dict, List, Optional, Any, Callable +from typing import Dict, Optional, Any from dataclasses import dataclass from src.common.logger import get_logger from src.chat.memory_system.enhanced_memory_adapter import ( - get_enhanced_memory_adapter, process_conversation_with_enhanced_memory, retrieve_memories_with_enhanced_system, - get_memory_context_for_prompt + get_memory_context_for_prompt, ) logger = get_logger(__name__) @@ -23,6 +21,7 @@ logger = get_logger(__name__) @dataclass class HookResult: """钩子执行结果""" + success: bool data: Any = None error: Optional[str] = None @@ -39,7 +38,7 @@ class MemoryIntegrationHooks: "memory_retrieval_hooks": 0, "prompt_enhancement_hooks": 0, "total_hook_executions": 0, - "average_hook_time": 0.0 + "average_hook_time": 0.0, } async def register_hooks(self): @@ -130,10 +129,7 @@ class MemoryIntegrationHooks: from src.plugin_system.base.component_types import EventType # 注册消息后处理事件 - event_manager.subscribe( - EventType.MESSAGE_PROCESSED, - self._on_message_processed_handler - ) + event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler) logger.debug("已注册到事件系统的消息处理钩子") except ImportError: @@ -144,10 +140,8 @@ class MemoryIntegrationHooks: from src.chat.message_manager import message_manager # 如果消息管理器支持钩子注册 - if hasattr(message_manager, 'register_post_process_hook'): - message_manager.register_post_process_hook( - self._on_message_processed_hook - ) + if hasattr(message_manager, "register_post_process_hook"): + message_manager.register_post_process_hook(self._on_message_processed_hook) logger.debug("已注册到消息管理器的处理钩子") except ImportError: @@ -164,10 +158,8 @@ class MemoryIntegrationHooks: from src.chat.message_receive.chat_stream import get_chat_manager chat_manager = get_chat_manager() - if hasattr(chat_manager, 'register_save_hook'): - chat_manager.register_save_hook( - self._on_chat_stream_save_hook - ) + if hasattr(chat_manager, "register_save_hook"): + chat_manager.register_save_hook(self._on_chat_stream_save_hook) logger.debug("已注册到聊天流管理器的保存钩子") except ImportError: @@ -183,10 +175,8 @@ class MemoryIntegrationHooks: try: from src.chat.replyer.default_generator import default_generator - if hasattr(default_generator, 'register_pre_generation_hook'): - default_generator.register_pre_generation_hook( - self._on_pre_response_hook - ) + if hasattr(default_generator, "register_pre_generation_hook"): + default_generator.register_pre_generation_hook(self._on_pre_response_hook) logger.debug("已注册到回复生成器的前置钩子") except ImportError: @@ -202,10 +192,8 @@ class MemoryIntegrationHooks: try: from src.chat.knowledge.knowledge_lib import knowledge_manager - if hasattr(knowledge_manager, 'register_query_enhancer'): - knowledge_manager.register_query_enhancer( - self._on_knowledge_query_hook - ) + if hasattr(knowledge_manager, "register_query_enhancer"): + knowledge_manager.register_query_enhancer(self._on_knowledge_query_hook) logger.debug("已注册到知识库的查询增强钩子") except ImportError: @@ -221,10 +209,8 @@ class MemoryIntegrationHooks: try: from src.chat.utils.prompt import prompt_manager - if hasattr(prompt_manager, 'register_enhancer'): - prompt_manager.register_enhancer( - self._on_prompt_building_hook - ) + if hasattr(prompt_manager, "register_enhancer"): + prompt_manager.register_enhancer(self._on_prompt_building_hook) logger.debug("已注册到提示词管理器的增强钩子") except ImportError: @@ -278,7 +264,7 @@ class MemoryIntegrationHooks: "platform": message_info.get("platform", 
"unknown"), "interest_value": message_data.get("interest_value", 0.0), "keywords": message_data.get("key_words", []), - "timestamp": message_data.get("time", time.time()) + "timestamp": message_data.get("time", time.time()), } # 使用增强记忆系统处理对话 @@ -296,7 +282,7 @@ class MemoryIntegrationHooks: return HookResult(success=True, data=result, processing_time=processing_time) else: logger.warning(f"消息处理钩子执行失败: {result.get('error')}") - return HookResult(success=False, error=result.get('error'), processing_time=processing_time) + return HookResult(success=False, error=result.get("error"), processing_time=processing_time) except Exception as e: processing_time = time.time() - start_time @@ -334,7 +320,7 @@ class MemoryIntegrationHooks: "stream_id": chat_stream_data.get("stream_id"), "platform": chat_stream_data.get("platform", "unknown"), "message_count": len(messages), - "timestamp": time.time() + "timestamp": time.time(), } # 使用增强记忆系统处理对话 @@ -352,7 +338,7 @@ class MemoryIntegrationHooks: return HookResult(success=True, data=result, processing_time=processing_time) else: logger.warning(f"聊天流保存钩子执行失败: {result.get('error')}") - return HookResult(success=False, error=result.get('error'), processing_time=processing_time) + return HookResult(success=False, error=result.get("error"), processing_time=processing_time) except Exception as e: processing_time = time.time() - start_time @@ -375,9 +361,7 @@ class MemoryIntegrationHooks: return HookResult(success=True, data="No query provided") # 检索相关记忆 - memories = await retrieve_memories_with_enhanced_system( - query, user_id, context, limit=5 - ) + memories = await retrieve_memories_with_enhanced_system(query, user_id, context, limit=5) processing_time = time.time() - start_time self._update_hook_stats(processing_time) @@ -411,9 +395,7 @@ class MemoryIntegrationHooks: return HookResult(success=True, data="No query provided") # 获取记忆上下文并增强查询 - memory_context = await get_memory_context_for_prompt( - query, user_id, context, max_memories=3 - ) + memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=3) processing_time = time.time() - start_time self._update_hook_stats(processing_time) @@ -445,9 +427,7 @@ class MemoryIntegrationHooks: return HookResult(success=True, data="No query provided") # 获取记忆上下文 - memory_context = await get_memory_context_for_prompt( - query, user_id, context, max_memories=5 - ) + memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=5) processing_time = time.time() - start_time self._update_hook_stats(processing_time) @@ -499,6 +479,7 @@ class MemoryMaintenanceTask: # 获取适配器实例 try: from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter + if _enhanced_memory_adapter: await _enhanced_memory_adapter.maintenance() logger.info("✅ 增强记忆系统维护任务完成") @@ -543,4 +524,4 @@ async def initialize_memory_integration_hooks(): return hooks except Exception as e: logger.error(f"❌ 记忆集成钩子初始化失败: {e}", exc_info=True) - return None \ No newline at end of file + return None diff --git a/src/chat/memory_system/deprecated_backup/metadata_index.py b/src/chat/memory_system/deprecated_backup/metadata_index.py index 10f8ff266..f7ab8ecda 100644 --- a/src/chat/memory_system/deprecated_backup/metadata_index.py +++ b/src/chat/memory_system/deprecated_backup/metadata_index.py @@ -4,12 +4,10 @@ 为记忆系统提供多维度的精准过滤和查询能力 """ -import os import time import orjson from typing import Dict, List, Optional, Tuple, Set, Any, Union -from datetime import datetime, timedelta -from dataclasses 
import dataclass, field +from dataclasses import dataclass from enum import Enum import threading from collections import defaultdict @@ -23,23 +21,25 @@ logger = get_logger(__name__) class IndexType(Enum): """索引类型""" - MEMORY_TYPE = "memory_type" # 记忆类型索引 - USER_ID = "user_id" # 用户ID索引 - SUBJECT = "subject" # 主体索引 - KEYWORD = "keyword" # 关键词索引 - TAG = "tag" # 标签索引 - CATEGORY = "category" # 分类索引 - TIMESTAMP = "timestamp" # 时间索引 - CONFIDENCE = "confidence" # 置信度索引 - IMPORTANCE = "importance" # 重要性索引 + + MEMORY_TYPE = "memory_type" # 记忆类型索引 + USER_ID = "user_id" # 用户ID索引 + SUBJECT = "subject" # 主体索引 + KEYWORD = "keyword" # 关键词索引 + TAG = "tag" # 标签索引 + CATEGORY = "category" # 分类索引 + TIMESTAMP = "timestamp" # 时间索引 + CONFIDENCE = "confidence" # 置信度索引 + IMPORTANCE = "importance" # 重要性索引 RELATIONSHIP_SCORE = "relationship_score" # 关系分索引 ACCESS_FREQUENCY = "access_frequency" # 访问频率索引 - SEMANTIC_HASH = "semantic_hash" # 语义哈希索引 + SEMANTIC_HASH = "semantic_hash" # 语义哈希索引 @dataclass class IndexQuery: """索引查询条件""" + user_ids: Optional[List[str]] = None memory_types: Optional[List[MemoryType]] = None subjects: Optional[List[str]] = None @@ -61,6 +61,7 @@ class IndexQuery: @dataclass class IndexResult: """索引结果""" + memory_ids: List[str] total_count: int query_time: float @@ -102,7 +103,7 @@ class MetadataIndexManager: "average_query_time": 0.0, "total_queries": 0, "cache_hit_rate": 0.0, - "cache_hits": 0 + "cache_hits": 0, } # 线程锁 @@ -171,9 +172,8 @@ class MetadataIndexManager: index_time = time.time() - start_time self.index_stats["index_build_time"] = ( - (self.index_stats["index_build_time"] * (len(memories) - 1) + index_time) / - len(memories) - ) + self.index_stats["index_build_time"] * (len(memories) - 1) + index_time + ) / len(memories) logger.debug(f"元数据索引完成,{len(memories)} 条记忆,耗时 {index_time:.3f}秒") @@ -258,7 +258,7 @@ class MetadataIndexManager: "relationship_score": memory.metadata.relationship_score, "relevance_score": memory.metadata.relevance_score, "semantic_hash": memory.semantic_hash, - "subjects": memory.subjects + "subjects": memory.subjects, } # 记忆类型索引 @@ -355,21 +355,20 @@ class MetadataIndexManager: # 限制数量 if query.limit and len(filtered_ids) > query.limit: - filtered_ids = filtered_ids[:query.limit] + filtered_ids = filtered_ids[: query.limit] # 记录查询统计 query_time = time.time() - start_time self.index_stats["total_queries"] += 1 self.index_stats["average_query_time"] = ( - (self.index_stats["average_query_time"] * (self.index_stats["total_queries"] - 1) + query_time) / - self.index_stats["total_queries"] - ) + self.index_stats["average_query_time"] * (self.index_stats["total_queries"] - 1) + query_time + ) / self.index_stats["total_queries"] return IndexResult( memory_ids=filtered_ids, total_count=len(filtered_ids), query_time=query_time, - filtered_by=self._get_applied_filters(query) + filtered_by=self._get_applied_filters(query), ) except Exception as e: @@ -486,15 +485,15 @@ class MetadataIndexManager: if query.time_range: start_time, end_time = query.time_range filtered_ids = [ - memory_id for memory_id in filtered_ids - if self._is_in_time_range(memory_id, start_time, end_time) + memory_id for memory_id in filtered_ids if self._is_in_time_range(memory_id, start_time, end_time) ] # 置信度过滤 if query.confidence_levels: confidence_set = set(query.confidence_levels) filtered_ids = [ - memory_id for memory_id in filtered_ids + memory_id + for memory_id in filtered_ids if self.memory_metadata_cache[memory_id]["confidence"] in confidence_set ] @@ -502,27 +501,31 @@ class 
MetadataIndexManager: if query.importance_levels: importance_set = set(query.importance_levels) filtered_ids = [ - memory_id for memory_id in filtered_ids + memory_id + for memory_id in filtered_ids if self.memory_metadata_cache[memory_id]["importance"] in importance_set ] # 关系分范围过滤 if query.min_relationship_score is not None: filtered_ids = [ - memory_id for memory_id in filtered_ids + memory_id + for memory_id in filtered_ids if self.memory_metadata_cache[memory_id]["relationship_score"] >= query.min_relationship_score ] if query.max_relationship_score is not None: filtered_ids = [ - memory_id for memory_id in filtered_ids + memory_id + for memory_id in filtered_ids if self.memory_metadata_cache[memory_id]["relationship_score"] <= query.max_relationship_score ] # 最小访问次数过滤 if query.min_access_count is not None: filtered_ids = [ - memory_id for memory_id in filtered_ids + memory_id + for memory_id in filtered_ids if self.memory_metadata_cache[memory_id]["access_count"] >= query.min_access_count ] @@ -530,7 +533,8 @@ class MetadataIndexManager: if query.semantic_hashes: hash_set = set(query.semantic_hashes) filtered_ids = [ - memory_id for memory_id in filtered_ids + memory_id + for memory_id in filtered_ids if self.memory_metadata_cache[memory_id]["semantic_hash"] in hash_set ] @@ -560,8 +564,7 @@ class MetadataIndexManager: elif sort_by == "relevance_score": # 按相关度排序 memory_ids.sort( - key=lambda mid: self.memory_metadata_cache[mid]["relevance_score"], - reverse=(sort_order == "desc") + key=lambda mid: self.memory_metadata_cache[mid]["relevance_score"], reverse=(sort_order == "desc") ) elif sort_by == "relationship_score": @@ -574,8 +577,7 @@ class MetadataIndexManager: elif sort_by == "last_accessed": # 按最后访问时间排序 memory_ids.sort( - key=lambda mid: self.memory_metadata_cache[mid]["last_accessed"], - reverse=(sort_order == "desc") + key=lambda mid: self.memory_metadata_cache[mid]["last_accessed"], reverse=(sort_order == "desc") ) return memory_ids @@ -665,7 +667,9 @@ class MetadataIndexManager: self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid != memory_id] # 从访问频率索引中移除 - self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid != memory_id] + self.access_frequency_index = [ + (count, mid) for count, mid in self.access_frequency_index if mid != memory_id + ] # 注意:关键词、标签、分类索引需要从原始记忆中获取,这里简化处理 # 实际实现中可能需要重新加载记忆或维护反向索引 @@ -704,7 +708,7 @@ class MetadataIndexManager: "average_importance": 0.0, "average_relationship_score": 0.0, "top_keywords": [], - "top_tags": [] + "top_tags": [], } if user_id: @@ -789,23 +793,23 @@ class MetadataIndexManager: indices_data[index_type.value] = serialized_index indices_file = self.index_path / "indices.json" - with open(indices_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(indices_data, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(indices_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(indices_data, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存时间索引 time_index_file = self.index_path / "time_index.json" - with open(time_index_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(self.time_index, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(time_index_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(self.time_index, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存关系分索引 relationship_index_file = self.index_path / "relationship_index.json" - with open(relationship_index_file, 'w', encoding='utf-8') as f: - 
f.write(orjson.dumps(self.relationship_index, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(relationship_index_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(self.relationship_index, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存访问频率索引 access_frequency_index_file = self.index_path / "access_frequency_index.json" - with open(access_frequency_index_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(self.access_frequency_index, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(access_frequency_index_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(self.access_frequency_index, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存元数据缓存 metadata_cache_file = self.index_path / "metadata_cache.json" @@ -813,13 +817,13 @@ class MetadataIndexManager: memory_id: self._serialize_metadata_entry(metadata) for memory_id, metadata in self.memory_metadata_cache.items() } - with open(metadata_cache_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(metadata_cache_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存统计信息 stats_file = self.index_path / "index_stats.json" - with open(stats_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(self.index_stats, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(stats_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(self.index_stats, option=orjson.OPT_INDENT_2).decode("utf-8")) self._dirty = False logger.info("✅ 元数据索引保存完成") @@ -835,7 +839,7 @@ class MetadataIndexManager: # 加载各类索引 indices_file = self.index_path / "indices.json" if indices_file.exists(): - with open(indices_file, 'r', encoding='utf-8') as f: + with open(indices_file, "r", encoding="utf-8") as f: indices_data = orjson.loads(f.read()) for index_type_value, index_data in indices_data.items(): @@ -849,25 +853,25 @@ class MetadataIndexManager: # 加载时间索引 time_index_file = self.index_path / "time_index.json" if time_index_file.exists(): - with open(time_index_file, 'r', encoding='utf-8') as f: + with open(time_index_file, "r", encoding="utf-8") as f: self.time_index = orjson.loads(f.read()) # 加载关系分索引 relationship_index_file = self.index_path / "relationship_index.json" if relationship_index_file.exists(): - with open(relationship_index_file, 'r', encoding='utf-8') as f: + with open(relationship_index_file, "r", encoding="utf-8") as f: self.relationship_index = orjson.loads(f.read()) # 加载访问频率索引 access_frequency_index_file = self.index_path / "access_frequency_index.json" if access_frequency_index_file.exists(): - with open(access_frequency_index_file, 'r', encoding='utf-8') as f: + with open(access_frequency_index_file, "r", encoding="utf-8") as f: self.access_frequency_index = orjson.loads(f.read()) # 加载元数据缓存 metadata_cache_file = self.index_path / "metadata_cache.json" if metadata_cache_file.exists(): - with open(metadata_cache_file, 'r', encoding='utf-8') as f: + with open(metadata_cache_file, "r", encoding="utf-8") as f: cache_data = orjson.loads(f.read()) # 转换置信度和重要性为枚举类型 @@ -910,7 +914,7 @@ class MetadataIndexManager: # 加载统计信息 stats_file = self.index_path / "index_stats.json" if stats_file.exists(): - with open(stats_file, 'r', encoding='utf-8') as f: + with open(stats_file, "r", encoding="utf-8") as f: self.index_stats = orjson.loads(f.read()) # 更新记忆计数 @@ -937,9 +941,7 @@ class MetadataIndexManager: # 更新统计信息 if self.index_stats["total_queries"] > 0: - 
self.index_stats["cache_hit_rate"] = ( - self.index_stats["cache_hits"] / self.index_stats["total_queries"] - ) + self.index_stats["cache_hit_rate"] = self.index_stats["cache_hits"] / self.index_stats["total_queries"] logger.info("✅ 元数据索引优化完成") @@ -967,7 +969,9 @@ class MetadataIndexManager: self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid in valid_memory_ids] # 清理访问频率索引中的无效引用 - self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid in valid_memory_ids] + self.access_frequency_index = [ + (count, mid) for count, mid in self.access_frequency_index if mid in valid_memory_ids + ] # 更新总记忆数 self.index_stats["total_memories"] = len(valid_memory_ids) @@ -1017,7 +1021,7 @@ class MetadataIndexManager: "categories": len(self.indices[IndexType.CATEGORY]), "confidence_levels": len(self.indices[IndexType.CONFIDENCE]), "importance_levels": len(self.indices[IndexType.IMPORTANCE]), - "semantic_hashes": len(self.indices[IndexType.SEMANTIC_HASH]) + "semantic_hashes": len(self.indices[IndexType.SEMANTIC_HASH]), } - return stats \ No newline at end of file + return stats diff --git a/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py b/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py index 21b99f7f4..bc0a1a0f4 100644 --- a/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py +++ b/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py @@ -5,15 +5,13 @@ """ import time -import asyncio -from typing import Dict, List, Optional, Tuple, Set, Any +from typing import Dict, List, Optional, Set, Any from dataclasses import dataclass, field from enum import Enum -import numpy as np import orjson from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType from src.chat.memory_system.enhanced_reranker import EnhancedReRanker, ReRankingConfig logger = get_logger(__name__) @@ -21,30 +19,32 @@ logger = get_logger(__name__) class RetrievalStage(Enum): """检索阶段""" - METADATA_FILTERING = "metadata_filtering" # 元数据过滤阶段 - VECTOR_SEARCH = "vector_search" # 向量搜索阶段 - SEMANTIC_RERANKING = "semantic_reranking" # 语义重排序阶段 - CONTEXTUAL_FILTERING = "contextual_filtering" # 上下文过滤阶段 + + METADATA_FILTERING = "metadata_filtering" # 元数据过滤阶段 + VECTOR_SEARCH = "vector_search" # 向量搜索阶段 + SEMANTIC_RERANKING = "semantic_reranking" # 语义重排序阶段 + CONTEXTUAL_FILTERING = "contextual_filtering" # 上下文过滤阶段 @dataclass class RetrievalConfig: """检索配置""" + # 各阶段配置 - 优化召回率 - metadata_filter_limit: int = 150 # 元数据过滤阶段返回数量(增加) - vector_search_limit: int = 80 # 向量搜索阶段返回数量(增加) - semantic_rerank_limit: int = 30 # 语义重排序阶段返回数量(增加) - final_result_limit: int = 10 # 最终结果数量 + metadata_filter_limit: int = 150 # 元数据过滤阶段返回数量(增加) + vector_search_limit: int = 80 # 向量搜索阶段返回数量(增加) + semantic_rerank_limit: int = 30 # 语义重排序阶段返回数量(增加) + final_result_limit: int = 10 # 最终结果数量 # 相似度阈值 - 优化召回率 - vector_similarity_threshold: float = 0.5 # 向量相似度阈值(降低以提升召回率) + vector_similarity_threshold: float = 0.5 # 向量相似度阈值(降低以提升召回率) semantic_similarity_threshold: float = 0.05 # 语义相似度阈值(保持较低以获得更多相关记忆) # 权重配置 - vector_weight: float = 0.4 # 向量相似度权重 - semantic_weight: float = 0.3 # 语义相似度权重 - context_weight: float = 0.2 # 上下文权重 - recency_weight: float = 0.1 # 时效性权重 + vector_weight: float = 0.4 # 向量相似度权重 + semantic_weight: float = 0.3 # 语义相似度权重 + context_weight: float = 0.2 # 上下文权重 + recency_weight: float = 0.1 # 
时效性权重 @classmethod def from_global_config(cls): @@ -53,26 +53,25 @@ class RetrievalConfig: return cls( # 各阶段配置 - 优化召回率 - metadata_filter_limit=max(150, global_config.memory.metadata_filter_limit), # 增加候选池 - vector_search_limit=max(80, global_config.memory.vector_search_limit), # 增加向量搜索结果 - semantic_rerank_limit=max(30, global_config.memory.semantic_rerank_limit), # 增加重排序候选 + metadata_filter_limit=max(150, global_config.memory.metadata_filter_limit), # 增加候选池 + vector_search_limit=max(80, global_config.memory.vector_search_limit), # 增加向量搜索结果 + semantic_rerank_limit=max(30, global_config.memory.semantic_rerank_limit), # 增加重排序候选 final_result_limit=global_config.memory.final_result_limit, - # 相似度阈值 - 优化召回率 vector_similarity_threshold=max(0.5, global_config.memory.vector_similarity_threshold), # 确保不低于0.5 semantic_similarity_threshold=0.05, # 进一步降低以提升召回率 - # 权重配置 vector_weight=global_config.memory.vector_weight, semantic_weight=global_config.memory.semantic_weight, context_weight=global_config.memory.context_weight, - recency_weight=global_config.memory.recency_weight + recency_weight=global_config.memory.recency_weight, ) @dataclass class StageResult: """阶段结果""" + stage: RetrievalStage memory_ids: List[str] processing_time: float @@ -84,6 +83,7 @@ class StageResult: @dataclass class RetrievalResult: """检索结果""" + query: str user_id: str final_memories: List[MemoryChunk] @@ -98,16 +98,16 @@ class MultiStageRetrieval: def __init__(self, config: Optional[RetrievalConfig] = None): self.config = config or RetrievalConfig.from_global_config() - + # 初始化增强重排序器 reranker_config = ReRankingConfig( semantic_weight=self.config.vector_weight, recency_weight=self.config.recency_weight, usage_freq_weight=0.2, # 新增的使用频率权重 - type_match_weight=0.1 # 新增的类型匹配权重 + type_match_weight=0.1, # 新增的类型匹配权重 ) self.reranker = EnhancedReRanker(reranker_config) - + self.retrieval_stats = { "total_queries": 0, "average_retrieval_time": 0.0, @@ -116,8 +116,8 @@ class MultiStageRetrieval: "vector_search": {"calls": 0, "avg_time": 0.0}, "semantic_reranking": {"calls": 0, "avg_time": 0.0}, "contextual_filtering": {"calls": 0, "avg_time": 0.0}, - "enhanced_reranking": {"calls": 0, "avg_time": 0.0} # 新增统计 - } + "enhanced_reranking": {"calls": 0, "avg_time": 0.0}, # 新增统计 + }, } async def retrieve_memories( @@ -128,7 +128,7 @@ class MultiStageRetrieval: metadata_index, vector_storage, all_memories_cache: Dict[str, MemoryChunk], - limit: Optional[int] = None + limit: Optional[int] = None, ) -> RetrievalResult: """多阶段记忆检索""" start_time = time.time() @@ -143,31 +143,39 @@ class MultiStageRetrieval: # 阶段1:元数据过滤 stage1_result = await self._metadata_filtering_stage( - query, user_id, context, metadata_index, all_memories_cache, - debug_log=memory_debug_info + query, user_id, context, metadata_index, all_memories_cache, debug_log=memory_debug_info ) stage_results.append(stage1_result) current_memory_ids.update(stage1_result.memory_ids) # 阶段2:向量搜索 stage2_result = await self._vector_search_stage( - query, user_id, context, vector_storage, current_memory_ids, all_memories_cache, - debug_log=memory_debug_info + query, + user_id, + context, + vector_storage, + current_memory_ids, + all_memories_cache, + debug_log=memory_debug_info, ) stage_results.append(stage2_result) current_memory_ids.update(stage2_result.memory_ids) # 阶段3:语义重排序 stage3_result = await self._semantic_reranking_stage( - query, user_id, context, current_memory_ids, all_memories_cache, - debug_log=memory_debug_info + query, user_id, context, current_memory_ids, all_memories_cache, 
debug_log=memory_debug_info ) stage_results.append(stage3_result) # 阶段4:上下文过滤 stage4_result = await self._contextual_filtering_stage( - query, user_id, context, stage3_result.memory_ids, all_memories_cache, limit, - debug_log=memory_debug_info + query, + user_id, + context, + stage3_result.memory_ids, + all_memories_cache, + limit, + debug_log=memory_debug_info, ) stage_results.append(stage4_result) @@ -176,18 +184,27 @@ class MultiStageRetrieval: logger.debug(f"上下文过滤结果过少({len(stage4_result.memory_ids)}),启用回退机制") # 回退到更宽松的检索策略 fallback_result = await self._fallback_retrieval_stage( - query, user_id, context, all_memories_cache, limit, + query, + user_id, + context, + all_memories_cache, + limit, excluded_ids=set(stage4_result.memory_ids), - debug_log=memory_debug_info + debug_log=memory_debug_info, ) if fallback_result.memory_ids: - stage4_result.memory_ids.extend(fallback_result.memory_ids[:limit - len(stage4_result.memory_ids)]) + stage4_result.memory_ids.extend(fallback_result.memory_ids[: limit - len(stage4_result.memory_ids)]) logger.debug(f"回退机制补充了 {len(fallback_result.memory_ids)} 条记忆") # 阶段5:增强重排序 (新增) stage5_result = await self._enhanced_reranking_stage( - query, user_id, context, stage4_result.memory_ids, all_memories_cache, limit, - debug_log=memory_debug_info + query, + user_id, + context, + stage4_result.memory_ids, + all_memories_cache, + limit, + debug_log=memory_debug_info, ) stage_results.append(stage5_result) @@ -226,13 +243,21 @@ class MultiStageRetrieval: "semantic_score": trace.get("semantic_stage", {}).get("score"), "context_score": trace.get("context_stage", {}).get("context_score"), "final_score": trace.get("context_stage", {}).get("final_score"), - "status": trace.get("context_stage", {}).get("status") or trace.get("vector_stage", {}).get("status") or trace.get("semantic_stage", {}).get("status"), + "status": trace.get("context_stage", {}).get("status") + or trace.get("vector_stage", {}).get("status") + or trace.get("semantic_stage", {}).get("status"), "is_final": memory_id in final_ids_set, } debug_entries.append(entry) # 限制日志输出数量 - debug_entries.sort(key=lambda item: (item.get("is_final", False), item.get("final_score") or item.get("vector_similarity") or 0.0), reverse=True) + debug_entries.sort( + key=lambda item: ( + item.get("is_final", False), + item.get("final_score") or item.get("vector_similarity") or 0.0, + ), + reverse=True, + ) debug_payload = { "query": query, "semantic_query": context.get("resolved_query_text", query), @@ -266,7 +291,7 @@ class MultiStageRetrieval: stage_results=stage_results, total_processing_time=total_time, total_filtered=total_filtered, - retrieval_stats=self.retrieval_stats.copy() + retrieval_stats=self.retrieval_stats.copy(), ) except Exception as e: @@ -279,7 +304,7 @@ class MultiStageRetrieval: stage_results=stage_results, total_processing_time=time.time() - start_time, total_filtered=0, - retrieval_stats=self.retrieval_stats.copy() + retrieval_stats=self.retrieval_stats.copy(), ) async def _metadata_filtering_stage( @@ -290,7 +315,7 @@ class MultiStageRetrieval: metadata_index, all_memories_cache: Dict[str, MemoryChunk], *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None + debug_log: Optional[Dict[str, Dict[str, Any]]] = None, ) -> StageResult: """阶段1:元数据过滤""" start_time = time.time() @@ -302,7 +327,9 @@ class MultiStageRetrieval: memory_types = self._extract_memory_types_from_context(context) keywords = self._extract_keywords_from_query(query, query_plan) - subjects = query_plan.subject_includes if query_plan and 
getattr(query_plan, "subject_includes", None) else None + subjects = ( + query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None + ) index_query = IndexQuery( user_ids=None, @@ -311,7 +338,7 @@ class MultiStageRetrieval: keywords=keywords, limit=self.config.metadata_filter_limit, sort_by="last_accessed", - sort_order="desc" + sort_order="desc", ) # 执行查询 @@ -328,19 +355,16 @@ class MultiStageRetrieval: reverse=True, ) if memory_types: - type_filtered = [ - mid for mid in sorted_ids - if all_memories_cache[mid].memory_type in memory_types - ] + type_filtered = [mid for mid in sorted_ids if all_memories_cache[mid].memory_type in memory_types] sorted_ids = type_filtered or sorted_ids if subjects: subject_candidates = [s.lower() for s in subjects if isinstance(s, str) and s.strip()] if subject_candidates: subject_filtered = [ - mid for mid in sorted_ids + mid + for mid in sorted_ids if any( - subj.strip().lower() in subject_candidates - for subj in all_memories_cache[mid].subjects + subj.strip().lower() in subject_candidates for subj in all_memories_cache[mid].subjects ) ] sorted_ids = subject_filtered or sorted_ids @@ -367,12 +391,14 @@ class MultiStageRetrieval: bool(subjects), bool(keywords), ) - details.append({ - "note": "fallback_recent", - "requested_types": [mt.value for mt in memory_types] if memory_types else [], - "subjects": subjects or [], - "keywords": keywords or [], - }) + details.append( + { + "note": "fallback_recent", + "requested_types": [mt.value for mt in memory_types] if memory_types else [], + "subjects": subjects or [], + "keywords": keywords or [], + } + ) logger.debug( "元数据过滤:候选=%d, 返回=%d", @@ -419,7 +445,7 @@ class MultiStageRetrieval: candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None + debug_log: Optional[Dict[str, Dict[str, Any]]] = None, ) -> StageResult: """阶段2:向量搜索""" start_time = time.time() @@ -441,8 +467,7 @@ class MultiStageRetrieval: # 执行向量搜索 search_result = await vector_storage.search_similar_memories( - query_vector=query_embedding, - limit=self.config.vector_search_limit + query_vector=query_embedding, limit=self.config.vector_search_limit ) if not search_result: @@ -464,16 +489,18 @@ class MultiStageRetrieval: if in_metadata_candidates and above_threshold: filtered_memories.append((memory_id, similarity)) - raw_details.append({ - "memory_id": memory_id, - "similarity": similarity, - "in_metadata": in_metadata_candidates, - "above_threshold": above_threshold, - }) + raw_details.append( + { + "memory_id": memory_id, + "similarity": similarity, + "in_metadata": in_metadata_candidates, + "above_threshold": above_threshold, + } + ) # 按相似度排序 filtered_memories.sort(key=lambda x: x[1], reverse=True) - result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]] + result_ids = [memory_id for memory_id, _ in filtered_memories[: self.config.vector_search_limit]] kept_ids = set(result_ids) for entry in raw_details: @@ -534,11 +561,7 @@ class MultiStageRetrieval: ) def _create_text_search_fallback( - self, - candidate_ids: Set[str], - all_memories_cache: Dict[str, MemoryChunk], - query_text: str, - start_time: float + self, candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], query_text: str, start_time: float ) -> StageResult: """当向量搜索失败时,使用文本搜索作为回退策略""" try: @@ -561,15 +584,13 @@ class MultiStageRetrieval: # 按匹配度排序 text_matches.sort(key=lambda x: x[1], reverse=True) - result_ids = [memory_id 
for memory_id, _ in text_matches[:self.config.vector_search_limit]] + result_ids = [memory_id for memory_id, _ in text_matches[: self.config.vector_search_limit]] details = [] - for memory_id, score in text_matches[:self.config.vector_search_limit]: - details.append({ - "memory_id": memory_id, - "text_match_score": round(score, 4), - "status": "text_match_fallback" - }) + for memory_id, score in text_matches[: self.config.vector_search_limit]: + details.append( + {"memory_id": memory_id, "text_match_score": round(score, 4), "status": "text_match_fallback"} + ) logger.debug(f"向量搜索回退到文本匹配:找到 {len(result_ids)} 条匹配记忆") @@ -579,18 +600,18 @@ class MultiStageRetrieval: processing_time=time.time() - start_time, filtered_count=len(candidate_ids) - len(result_ids), score_threshold=0.0, # 文本匹配无严格阈值 - details=details + details=details, ) except Exception as e: logger.error(f"文本搜索回退失败: {e}") return StageResult( stage=RetrievalStage.VECTOR_SEARCH, - memory_ids=list(candidate_ids)[:self.config.vector_search_limit], + memory_ids=list(candidate_ids)[: self.config.vector_search_limit], processing_time=time.time() - start_time, filtered_count=0, score_threshold=0.0, - details=[{"error": str(e), "note": "text_fallback_failed"}] + details=[{"error": str(e), "note": "text_fallback_failed"}], ) async def _semantic_reranking_stage( @@ -601,7 +622,7 @@ class MultiStageRetrieval: candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None + debug_log: Optional[Dict[str, Dict[str, Any]]] = None, ) -> StageResult: """阶段3:语义重排序""" start_time = time.time() @@ -643,7 +664,7 @@ class MultiStageRetrieval: # 按语义相似度排序 reranked_memories.sort(key=lambda x: x[1], reverse=True) - result_ids = [memory_id for memory_id, _ in reranked_memories[:self.config.semantic_rerank_limit]] + result_ids = [memory_id for memory_id, _ in reranked_memories[: self.config.semantic_rerank_limit]] kept_ids = set(result_ids) filtered_count = len(candidate_ids) - len(result_ids) @@ -688,7 +709,7 @@ class MultiStageRetrieval: all_memories_cache: Dict[str, MemoryChunk], limit: int, *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None + debug_log: Optional[Dict[str, Dict[str, Any]]] = None, ) -> StageResult: """阶段4:上下文过滤""" start_time = time.time() @@ -777,7 +798,7 @@ class MultiStageRetrieval: limit: int, *, excluded_ids: Optional[Set[str]] = None, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None + debug_log: Optional[Dict[str, Dict[str, Any]]] = None, ) -> StageResult: """回退检索阶段 - 当主检索失败时使用更宽松的策略""" start_time = time.time() @@ -806,13 +827,15 @@ class MultiStageRetrieval: if not fallback_candidates: logger.debug("关键词匹配无结果,使用时序最近策略") recent_memories = sorted( - [(mid, mem.metadata.last_accessed or mem.metadata.created_at) - for mid, mem in all_memories_cache.items() - if mid not in excluded_ids], + [ + (mid, mem.metadata.last_accessed or mem.metadata.created_at) + for mid, mem in all_memories_cache.items() + if mid not in excluded_ids + ], key=lambda x: x[1], - reverse=True + reverse=True, ) - fallback_candidates = [(mid, 0.5) for mid, _ in recent_memories[:limit*2]] + fallback_candidates = [(mid, 0.5) for mid, _ in recent_memories[: limit * 2]] # 按分数排序 fallback_candidates.sort(key=lambda x: x[1], reverse=True) @@ -857,7 +880,9 @@ class MultiStageRetrieval: details=[{"error": str(e)}], ) - async def _generate_query_embedding(self, query: str, context: Dict[str, Any], vector_storage) -> Optional[List[float]]: + async def _generate_query_embedding( + self, query: str, 
context: Dict[str, Any], vector_storage + ) -> Optional[List[float]]: """生成查询向量""" try: query_plan = context.get("query_plan") @@ -875,15 +900,15 @@ class MultiStageRetrieval: logger.debug(f"正在生成查询向量,文本: '{query_text[:100]}'") embedding = await vector_storage.generate_query_embedding(query_text) - + if embedding is None: logger.warning("向量存储返回空的查询向量") return None - + if len(embedding) == 0: logger.warning("向量存储返回空列表作为查询向量") return None - + logger.debug(f"查询向量生成成功,维度: {len(embedding)}") return embedding @@ -926,8 +951,8 @@ class MultiStageRetrieval: import re # 分词处理 - query_words = list(jieba.cut(query_text)) + re.findall(r'[a-zA-Z]+', query_text) - memory_words = list(jieba.cut(memory_text)) + re.findall(r'[a-zA-Z]+', memory_text) + query_words = list(jieba.cut(query_text)) + re.findall(r"[a-zA-Z]+", query_text) + memory_words = list(jieba.cut(memory_text)) + re.findall(r"[a-zA-Z]+", memory_text) # 清理和标准化 query_words = [w.strip().lower() for w in query_words if w.strip() and len(w.strip()) > 1] @@ -953,8 +978,9 @@ class MultiStageRetrieval: except ImportError: # 如果jieba不可用,使用简单分词 import re - query_words = re.findall(r'[\w\u4e00-\u9fa5]+', query_lower) - memory_words = re.findall(r'[\w\u4e00-\u9fa5]+', memory_lower) + + query_words = re.findall(r"[\w\u4e00-\u9fa5]+", query_lower) + memory_words = re.findall(r"[\w\u4e00-\u9fa5]+", memory_lower) if query_words and memory_words: query_set = set(w for w in query_words if len(w) > 1) @@ -971,13 +997,19 @@ class MultiStageRetrieval: "天气": ["天气", "阳光", "雨", "晴", "阴", "温度", "weather", "sunny", "rain"], "编程": ["编程", "代码", "程序", "开发", "语言", "programming", "code", "develop", "python"], "时间": ["今天", "昨天", "明天", "现在", "时间", "today", "yesterday", "tomorrow", "time"], - "情感": ["好", "坏", "开心", "难过", "有趣", "good", "bad", "happy", "sad", "fun"] + "情感": ["好", "坏", "开心", "难过", "有趣", "good", "bad", "happy", "sad", "fun"], } - query_concepts = {concept for concept, keywords in concept_groups.items() - if any(keyword in query_lower for keyword in keywords)} - memory_concepts = {concept for concept, keywords in concept_groups.items() - if any(keyword in memory_lower for keyword in keywords)} + query_concepts = { + concept + for concept, keywords in concept_groups.items() + if any(keyword in query_lower for keyword in keywords) + } + memory_concepts = { + concept + for concept, keywords in concept_groups.items() + if any(keyword in memory_lower for keyword in keywords) + } if query_concepts and memory_concepts: concept_overlap = query_concepts & memory_concepts @@ -987,19 +1019,19 @@ class MultiStageRetrieval: plan_bonus = 0.0 if query_plan: # 主体匹配 - if hasattr(query_plan, 'subjects') and query_plan.subjects: + if hasattr(query_plan, "subjects") and query_plan.subjects: for subject in query_plan.subjects: if subject.lower() in memory_lower: plan_bonus += 0.15 # 对象匹配 - if hasattr(query_plan, 'objects') and query_plan.objects: + if hasattr(query_plan, "objects") and query_plan.objects: for obj in query_plan.objects: if obj.lower() in memory_lower: plan_bonus += 0.1 # 记忆类型匹配 - if hasattr(query_plan, 'memory_types') and query_plan.memory_types: + if hasattr(query_plan, "memory_types") and query_plan.memory_types: if memory.memory_type in query_plan.memory_types: plan_bonus += 0.1 @@ -1059,14 +1091,22 @@ class MultiStageRetrieval: object_keywords = getattr(query_plan, "object_includes", []) or [] if object_keywords: display_text = (memory.display or memory.text_content or "").lower() - hits = sum(1 for kw in object_keywords if isinstance(kw, str) and kw.strip() and 
kw.strip().lower() in display_text) + hits = sum( + 1 + for kw in object_keywords + if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text + ) if hits: score += min(0.3, hits * 0.1) optional_keywords = getattr(query_plan, "optional_keywords", []) or [] if optional_keywords: display_text = (memory.display or memory.text_content or "").lower() - hits = sum(1 for kw in optional_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text) + hits = sum( + 1 + for kw in optional_keywords + if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text + ) if hits: score += min(0.2, hits * 0.05) @@ -1091,7 +1131,9 @@ class MultiStageRetrieval: logger.warning(f"计算上下文相关度失败: {e}") return 0.0 - async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float: + async def _calculate_final_score( + self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float + ) -> float: """计算最终评分""" try: query_plan = context.get("query_plan") @@ -1126,10 +1168,10 @@ class MultiStageRetrieval: context_weight += 0.05 final_score = ( - semantic_score * semantic_weight + - vector_score * vector_weight + - context_score * context_weight + - recency_score * recency_weight + semantic_score * semantic_weight + + vector_score * vector_weight + + context_score * context_weight + + recency_score * recency_weight ) # 加入记忆重要性权重 @@ -1259,7 +1301,9 @@ class MultiStageRetrieval: stage_stat["calls"] += 1 current_stage_avg = stage_stat["avg_time"] - new_stage_avg = (current_stage_avg * (stage_stat["calls"] - 1) + result.processing_time) / stage_stat["calls"] + new_stage_avg = (current_stage_avg * (stage_stat["calls"] - 1) + result.processing_time) / stage_stat[ + "calls" + ] stage_stat["avg_time"] = new_stage_avg def get_retrieval_stats(self) -> Dict[str, Any]: @@ -1276,8 +1320,8 @@ class MultiStageRetrieval: "vector_search": {"calls": 0, "avg_time": 0.0}, "semantic_reranking": {"calls": 0, "avg_time": 0.0}, "contextual_filtering": {"calls": 0, "avg_time": 0.0}, - "enhanced_reranking": {"calls": 0, "avg_time": 0.0} - } + "enhanced_reranking": {"calls": 0, "avg_time": 0.0}, + }, } async def _enhanced_reranking_stage( @@ -1289,7 +1333,7 @@ class MultiStageRetrieval: all_memories_cache: Dict[str, MemoryChunk], limit: int, *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None + debug_log: Optional[Dict[str, Dict[str, Any]]] = None, ) -> StageResult: """阶段5:增强重排序 - 使用多维度评分模型""" start_time = time.time() @@ -1326,15 +1370,12 @@ class MultiStageRetrieval: # 使用增强重排序器 reranked_memories = self.reranker.rerank_memories( - query=query, - candidate_memories=candidate_memories, - context=context, - limit=limit + query=query, candidate_memories=candidate_memories, context=context, limit=limit ) # 提取重排序后的记忆ID result_ids = [memory_id for memory_id, _, _ in reranked_memories] - + # 生成调试详情 details = [] for memory_id, memory, final_score in reranked_memories: @@ -1346,7 +1387,7 @@ class MultiStageRetrieval: "access_count": memory.metadata.access_count, } details.append(detail_entry) - + if debug_log is not None: stage_entry = debug_log.setdefault(memory_id, {}).setdefault("enhanced_rerank_stage", {}) stage_entry["final_score"] = round(final_score, 4) @@ -1357,13 +1398,9 @@ class MultiStageRetrieval: kept_ids = set(result_ids) for memory_id in candidate_ids: if memory_id not in kept_ids: - detail_entry = { - "memory_id": memory_id, - "status": "filtered_out", - "reason": "ranked_below_limit" - } + 
detail_entry = {"memory_id": memory_id, "status": "filtered_out", "reason": "ranked_below_limit"} details.append(detail_entry) - + if debug_log is not None: stage_entry = debug_log.setdefault(memory_id, {}).setdefault("enhanced_rerank_stage", {}) stage_entry["status"] = "filtered_out" @@ -1371,10 +1408,7 @@ class MultiStageRetrieval: filtered_count = len(candidate_ids) - len(result_ids) - logger.debug( - f"增强重排序完成:候选={len(candidate_ids)}, 返回={len(result_ids)}, " - f"过滤={filtered_count}" - ) + logger.debug(f"增强重排序完成:候选={len(candidate_ids)}, 返回={len(result_ids)}, 过滤={filtered_count}") return StageResult( stage=RetrievalStage.CONTEXTUAL_FILTERING, # 保持与原有枚举兼容 @@ -1394,4 +1428,4 @@ class MultiStageRetrieval: filtered_count=0, score_threshold=0.0, details=[{"error": str(e)}], - ) \ No newline at end of file + ) diff --git a/src/chat/memory_system/deprecated_backup/vector_storage.py b/src/chat/memory_system/deprecated_backup/vector_storage.py index 73ddbb6a6..5d2e4fb91 100644 --- a/src/chat/memory_system/deprecated_backup/vector_storage.py +++ b/src/chat/memory_system/deprecated_backup/vector_storage.py @@ -4,22 +4,19 @@ 为记忆系统提供高效的向量存储和语义搜索能力 """ -import os import time import orjson import asyncio -from typing import Dict, List, Optional, Tuple, Set, Any +from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass -from datetime import datetime import threading import numpy as np -import pandas as pd from pathlib import Path from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config, global_config +from src.config.config import model_config from src.common.config_helpers import resolve_embedding_dimension from src.chat.memory_system.memory_chunk import MemoryChunk @@ -28,6 +25,7 @@ logger = get_logger(__name__) # 尝试导入FAISS,如果不可用则使用简单替代 try: import faiss + FAISS_AVAILABLE = True except ImportError: FAISS_AVAILABLE = False @@ -37,6 +35,7 @@ except ImportError: @dataclass class VectorStorageConfig: """向量存储配置""" + dimension: int = 1024 similarity_threshold: float = 0.8 index_type: str = "flat" # flat, ivf, hnsw @@ -79,7 +78,7 @@ class VectorStorageManager: "average_search_time": 0.0, "cache_hit_rate": 0.0, "total_searches": 0, - "cache_hits": 0 + "cache_hits": 0, } # 线程锁 @@ -122,8 +121,7 @@ class VectorStorageManager: """初始化嵌入模型""" if self.embedding_model is None: self.embedding_model = LLMRequest( - model_set=model_config.model_task_config.embedding, - request_type="memory.embedding" + model_set=model_config.model_task_config.embedding, request_type="memory.embedding" ) logger.info("✅ 嵌入模型初始化完成") @@ -137,20 +135,16 @@ class VectorStorageManager: await self.initialize_embedding_model() logger.debug(f"开始生成查询向量,文本: '{query_text[:50]}{'...' 
if len(query_text) > 50 else ''}'") - + embedding, _ = await self.embedding_model.get_embedding(query_text) if not embedding: logger.warning("嵌入模型返回空向量") return None logger.debug(f"生成的向量维度: {len(embedding)}, 期望维度: {self.config.dimension}") - + if len(embedding) != self.config.dimension: - logger.error( - "查询向量维度不匹配: 期望 %d, 实际 %d", - self.config.dimension, - len(embedding) - ) + logger.error("查询向量维度不匹配: 期望 %d, 实际 %d", self.config.dimension, len(embedding)) return None normalized_vector = self._normalize_vector(embedding) @@ -287,7 +281,7 @@ class VectorStorageManager: logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc) results[memory_id] = [] - tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts)] + tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)] await asyncio.gather(*tasks, return_exceptions=True) except Exception as e: @@ -313,12 +307,12 @@ class VectorStorageManager: memory.set_embedding(embedding) # 添加到向量索引 - if hasattr(self.vector_index, 'add'): + if hasattr(self.vector_index, "add"): # FAISS索引 if isinstance(embedding, np.ndarray): - vector_array = embedding.reshape(1, -1).astype('float32') + vector_array = embedding.reshape(1, -1).astype("float32") else: - vector_array = np.array([embedding], dtype='float32') + vector_array = np.array([embedding], dtype="float32") # 特殊处理IVF索引 if self.config.index_type == "ivf" and self.vector_index.ntotal == 0: @@ -367,14 +361,14 @@ class VectorStorageManager: *, query_text: Optional[str] = None, limit: int = 10, - scope_id: Optional[str] = None + scope_id: Optional[str] = None, ) -> List[Tuple[str, float]]: """搜索相似记忆""" start_time = time.time() try: logger.debug(f"开始向量搜索: query_text='{query_text[:30] if query_text else 'None'}', limit={limit}") - + if query_vector is None: if not query_text: logger.warning("查询向量和查询文本都为空") @@ -395,34 +389,34 @@ class VectorStorageManager: # 规范化查询向量 query_vector = self._normalize_vector(query_vector) - + logger.debug(f"查询向量维度: {len(query_vector)}, 存储总向量数: {self.storage_stats['total_vectors']}") # 检查向量索引状态 if not self.vector_index: logger.error("向量索引未初始化") return [] - + total_vectors = 0 - if hasattr(self.vector_index, 'ntotal'): + if hasattr(self.vector_index, "ntotal"): total_vectors = self.vector_index.ntotal - elif hasattr(self.vector_index, 'vectors'): + elif hasattr(self.vector_index, "vectors"): total_vectors = len(self.vector_index.vectors) - + logger.debug(f"向量索引中实际向量数: {total_vectors}") - + if total_vectors == 0: logger.warning("向量索引为空,无法执行搜索") return [] # 执行向量搜索 with self._lock: - if hasattr(self.vector_index, 'search'): + if hasattr(self.vector_index, "search"): # FAISS索引 if isinstance(query_vector, np.ndarray): - query_array = query_vector.reshape(1, -1).astype('float32') + query_array = query_vector.reshape(1, -1).astype("float32") else: - query_array = np.array([query_vector], dtype='float32') + query_array = np.array([query_vector], dtype="float32") if self.config.index_type == "ivf" and self.vector_index.ntotal > 0: # 设置IVF搜索参数 @@ -432,11 +426,11 @@ class VectorStorageManager: search_limit = min(limit, total_vectors) logger.debug(f"执行FAISS搜索,搜索限制: {search_limit}") - + distances, indices = self.vector_index.search(query_array, search_limit) distances = distances.flatten().tolist() indices = indices.flatten().tolist() - + logger.debug(f"FAISS搜索结果: {len(distances)} 个距离值, {len(indices)} 个索引") else: # 简单索引 @@ -451,8 +445,8 @@ class VectorStorageManager: valid_results = 0 invalid_indices = 0 
filtered_by_scope = 0 - - for distance, index in zip(distances, indices): + + for distance, index in zip(distances, indices, strict=False): if index == -1: # FAISS的无效索引标记 invalid_indices += 1 continue @@ -462,7 +456,7 @@ class VectorStorageManager: logger.debug(f"索引 {index} 没有对应的记忆ID") invalid_indices += 1 continue - + if scope_filter: memory = self.memory_cache.get(memory_id) if memory and str(memory.user_id) != scope_filter: @@ -482,16 +476,15 @@ class VectorStorageManager: search_time = time.time() - start_time self.storage_stats["total_searches"] += 1 self.storage_stats["average_search_time"] = ( - (self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time) / - self.storage_stats["total_searches"] - ) + self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time + ) / self.storage_stats["total_searches"] final_results = results[:limit] logger.info( f"向量搜索完成: 查询='{query_text[:20] if query_text else 'vector'}' " f"耗时={search_time:.3f}s, 返回={len(final_results)}个结果" ) - + return final_results except Exception as e: @@ -520,7 +513,7 @@ class VectorStorageManager: old_index = self.memory_id_to_index[memory_id] # 删除旧向量(如果支持) - if hasattr(self.vector_index, 'remove_ids'): + if hasattr(self.vector_index, "remove_ids"): try: self.vector_index.remove_ids(np.array([old_index])) except: @@ -530,11 +523,11 @@ class VectorStorageManager: new_embedding = self._normalize_vector(new_embedding) # 添加新向量 - if hasattr(self.vector_index, 'add'): + if hasattr(self.vector_index, "add"): if isinstance(new_embedding, np.ndarray): - vector_array = new_embedding.reshape(1, -1).astype('float32') + vector_array = new_embedding.reshape(1, -1).astype("float32") else: - vector_array = np.array([new_embedding], dtype='float32') + vector_array = np.array([new_embedding], dtype="float32") self.vector_index.add(vector_array) new_index = self.vector_index.ntotal - 1 @@ -569,7 +562,7 @@ class VectorStorageManager: index = self.memory_id_to_index[memory_id] # 从向量索引中删除(如果支持) - if hasattr(self.vector_index, 'remove_ids'): + if hasattr(self.vector_index, "remove_ids"): try: self.vector_index.remove_ids(np.array([index])) except: @@ -598,44 +591,37 @@ class VectorStorageManager: logger.info("正在保存向量存储...") # 保存记忆缓存 - cache_data = { - memory_id: memory.to_dict() - for memory_id, memory in self.memory_cache.items() - } + cache_data = {memory_id: memory.to_dict() for memory_id, memory in self.memory_cache.items()} cache_file = self.storage_path / "memory_cache.json" - with open(cache_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(cache_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存向量缓存 vector_cache_file = self.storage_path / "vector_cache.json" - with open(vector_cache_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(vector_cache_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存映射关系 mapping_file = self.storage_path / "id_mapping.json" mapping_data = { "memory_id_to_index": { - str(memory_id): int(index) - for memory_id, index in self.memory_id_to_index.items() + str(memory_id): int(index) for memory_id, index in self.memory_id_to_index.items() }, - "index_to_memory_id": { - str(index): memory_id - for index, memory_id in 
self.index_to_memory_id.items() - } + "index_to_memory_id": {str(index): memory_id for index, memory_id in self.index_to_memory_id.items()}, } - with open(mapping_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(mapping_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode("utf-8")) # 保存FAISS索引(如果可用) - if FAISS_AVAILABLE and hasattr(self.vector_index, 'save'): + if FAISS_AVAILABLE and hasattr(self.vector_index, "save"): index_file = self.storage_path / "vector_index.faiss" faiss.write_index(self.vector_index, str(index_file)) # 保存统计信息 stats_file = self.storage_path / "storage_stats.json" - with open(stats_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode('utf-8')) + with open(stats_file, "w", encoding="utf-8") as f: + f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode("utf-8")) logger.info("✅ 向量存储保存完成") @@ -650,36 +636,31 @@ class VectorStorageManager: # 加载记忆缓存 cache_file = self.storage_path / "memory_cache.json" if cache_file.exists(): - with open(cache_file, 'r', encoding='utf-8') as f: + with open(cache_file, "r", encoding="utf-8") as f: cache_data = orjson.loads(f.read()) self.memory_cache = { - memory_id: MemoryChunk.from_dict(memory_data) - for memory_id, memory_data in cache_data.items() + memory_id: MemoryChunk.from_dict(memory_data) for memory_id, memory_data in cache_data.items() } # 加载向量缓存 vector_cache_file = self.storage_path / "vector_cache.json" if vector_cache_file.exists(): - with open(vector_cache_file, 'r', encoding='utf-8') as f: + with open(vector_cache_file, "r", encoding="utf-8") as f: self.vector_cache = orjson.loads(f.read()) # 加载映射关系 mapping_file = self.storage_path / "id_mapping.json" if mapping_file.exists(): - with open(mapping_file, 'r', encoding='utf-8') as f: + with open(mapping_file, "r", encoding="utf-8") as f: mapping_data = orjson.loads(f.read()) raw_memory_to_index = mapping_data.get("memory_id_to_index", {}) self.memory_id_to_index = { - str(memory_id): int(index) - for memory_id, index in raw_memory_to_index.items() + str(memory_id): int(index) for memory_id, index in raw_memory_to_index.items() } raw_index_to_memory = mapping_data.get("index_to_memory_id", {}) - self.index_to_memory_id = { - int(index): memory_id - for index, memory_id in raw_index_to_memory.items() - } + self.index_to_memory_id = {int(index): memory_id for index, memory_id in raw_index_to_memory.items()} # 加载FAISS索引(如果可用) index_loaded = False @@ -699,7 +680,7 @@ class VectorStorageManager: logger.warning(f"加载FAISS索引失败: {e},重新构建") else: logger.info("FAISS索引文件不存在,将重新构建") - + # 如果索引没有成功加载且有向量数据,则重建索引 if not index_loaded and self.vector_cache: logger.info(f"检测到 {len(self.vector_cache)} 个向量缓存,重建索引") @@ -708,7 +689,7 @@ class VectorStorageManager: # 加载统计信息 stats_file = self.storage_path / "storage_stats.json" if stats_file.exists(): - with open(stats_file, 'r', encoding='utf-8') as f: + with open(stats_file, "r", encoding="utf-8") as f: self.storage_stats = orjson.loads(f.read()) # 更新向量计数 @@ -738,7 +719,7 @@ class VectorStorageManager: # 准备向量数据 memory_ids = [] vectors = [] - + for memory_id, embedding in self.vector_cache.items(): if embedding and len(embedding) == self.config.dimension: memory_ids.append(memory_id) @@ -753,18 +734,18 @@ class VectorStorageManager: logger.info(f"准备重建 {len(vectors)} 个向量到索引") # 批量添加向量到FAISS索引 - if hasattr(self.vector_index, 'add'): + if 
hasattr(self.vector_index, "add"): # FAISS索引 - vector_array = np.array(vectors, dtype='float32') - + vector_array = np.array(vectors, dtype="float32") + # 特殊处理IVF索引 - if self.config.index_type == "ivf" and hasattr(self.vector_index, 'train'): + if self.config.index_type == "ivf" and hasattr(self.vector_index, "train"): logger.info("训练IVF索引...") self.vector_index.train(vector_array) # 添加向量 self.vector_index.add(vector_array) - + # 重建映射关系 for i, memory_id in enumerate(memory_ids): self.memory_id_to_index[memory_id] = i @@ -772,15 +753,15 @@ class VectorStorageManager: else: # 简单索引 - for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors)): + for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors, strict=False)): index_id = self.vector_index.add_vector(vector) self.memory_id_to_index[memory_id] = index_id self.index_to_memory_id[index_id] = memory_id # 更新统计 self.storage_stats["total_vectors"] = len(self.memory_id_to_index) - - final_count = getattr(self.vector_index, 'ntotal', len(self.memory_id_to_index)) + + final_count = getattr(self.vector_index, "ntotal", len(self.memory_id_to_index)) logger.info(f"✅ 向量索引重建完成,索引中向量数: {final_count}") except Exception as e: @@ -875,7 +856,7 @@ class SimpleVectorIndex: def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float: """计算余弦相似度""" try: - dot_product = sum(x * y for x, y in zip(v1, v2)) + dot_product = sum(x * y for x, y in zip(v1, v2, strict=False)) norm1 = sum(x * x for x in v1) ** 0.5 norm2 = sum(x * x for x in v2) ** 0.5 @@ -890,4 +871,4 @@ class SimpleVectorIndex: @property def ntotal(self) -> int: """向量总数""" - return len(self.vectors) \ No newline at end of file + return len(self.vectors) diff --git a/src/chat/memory_system/enhanced_memory_activator.py b/src/chat/memory_system/enhanced_memory_activator.py index a9c5fc9cb..7570715ee 100644 --- a/src/chat/memory_system/enhanced_memory_activator.py +++ b/src/chat/memory_system/enhanced_memory_activator.py @@ -6,7 +6,6 @@ import difflib import orjson -import time from typing import List, Dict, Optional from datetime import datetime @@ -15,7 +14,7 @@ from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.common.logger import get_logger from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.memory_system.memory_manager import memory_manager, MemoryResult +from src.chat.memory_system.memory_manager import MemoryResult logger = get_logger("memory_activator") @@ -127,8 +126,8 @@ class MemoryActivator: for result in memory_results: # 检查是否已存在相似内容的记忆 exists = any( - m["content"] == result.content or - difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7 + m["content"] == result.content + or difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7 for m in self.running_memory ) if not exists: @@ -140,7 +139,7 @@ class MemoryActivator: "confidence": result.confidence, "importance": result.importance, "source": result.source, - "relevance_score": result.relevance_score # 添加相关度评分 + "relevance_score": result.relevance_score, # 添加相关度评分 } self.running_memory.append(memory_entry) logger.debug(f"添加新记忆: {result.memory_type} - {result.content}") @@ -168,17 +167,14 @@ class MemoryActivator: return [] # 构建查询上下文 - context = { - "keywords": keywords, - "query_intent": "conversation_response" - } + context = {"keywords": keywords, "query_intent": "conversation_response"} # 查询记忆 memories = await memory_system.retrieve_relevant_memories( 
query_text=query_text, user_id="global", # 使用全局作用域 context=context, - limit=5 + limit=5, ) # 转换为 MemoryResult 格式 @@ -191,7 +187,7 @@ class MemoryActivator: importance=memory.metadata.importance.value, timestamp=memory.metadata.created_at, source="unified_memory", - relevance_score=memory.metadata.relevance_score + relevance_score=memory.metadata.relevance_score, ) memory_results.append(result) @@ -214,16 +210,10 @@ class MemoryActivator: if not memory_system or memory_system.status.value != "ready": return None - context = { - "query_intent": "instant_response", - "chat_id": chat_id - } + context = {"query_intent": "instant_response", "chat_id": chat_id} memories = await memory_system.retrieve_relevant_memories( - query_text=target_message, - user_id="global", - context=context, - limit=1 + query_text=target_message, user_id="global", context=context, limit=1 ) if memories: @@ -248,4 +238,4 @@ memory_activator = MemoryActivator() # 兼容性别名 enhanced_memory_activator = memory_activator -init_prompt() \ No newline at end of file +init_prompt() diff --git a/src/chat/memory_system/memory_activator_new.py b/src/chat/memory_system/memory_activator_new.py index e11e66e25..491034de4 100644 --- a/src/chat/memory_system/memory_activator_new.py +++ b/src/chat/memory_system/memory_activator_new.py @@ -6,7 +6,6 @@ import difflib import orjson -import time from typing import List, Dict, Optional from datetime import datetime @@ -15,7 +14,7 @@ from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.common.logger import get_logger from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.memory_system.memory_manager import memory_manager, MemoryResult +from src.chat.memory_system.memory_manager import MemoryResult logger = get_logger("memory_activator") @@ -127,8 +126,8 @@ class MemoryActivator: for result in memory_results: # 检查是否已存在相似内容的记忆 exists = any( - m["content"] == result.content or - difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7 + m["content"] == result.content + or difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7 for m in self.running_memory ) if not exists: @@ -140,7 +139,7 @@ class MemoryActivator: "confidence": result.confidence, "importance": result.importance, "source": result.source, - "relevance_score": result.relevance_score # 添加相关度评分 + "relevance_score": result.relevance_score, # 添加相关度评分 } self.running_memory.append(memory_entry) logger.debug(f"添加新记忆: {result.memory_type} - {result.content}") @@ -168,17 +167,14 @@ class MemoryActivator: return [] # 构建查询上下文 - context = { - "keywords": keywords, - "query_intent": "conversation_response" - } + context = {"keywords": keywords, "query_intent": "conversation_response"} # 查询记忆 memories = await memory_system.retrieve_relevant_memories( query_text=query_text, user_id="global", # 使用全局作用域 context=context, - limit=5 + limit=5, ) # 转换为 MemoryResult 格式 @@ -191,7 +187,7 @@ class MemoryActivator: importance=memory.metadata.importance.value, timestamp=memory.metadata.created_at, source="unified_memory", - relevance_score=memory.metadata.relevance_score + relevance_score=memory.metadata.relevance_score, ) memory_results.append(result) @@ -214,16 +210,10 @@ class MemoryActivator: if not memory_system or memory_system.status.value != "ready": return None - context = { - "query_intent": "instant_response", - "chat_id": chat_id - } + context = {"query_intent": "instant_response", "chat_id": chat_id} memories = await 
memory_system.retrieve_relevant_memories( - query_text=target_message, - user_id="global", - context=context, - limit=1 + query_text=target_message, user_id="global", context=context, limit=1 ) if memories: @@ -246,4 +236,4 @@ class MemoryActivator: memory_activator = MemoryActivator() -init_prompt() \ No newline at end of file +init_prompt() diff --git a/src/chat/memory_system/memory_builder.py b/src/chat/memory_system/memory_builder.py index 5937b491d..0c3f47043 100644 --- a/src/chat/memory_system/memory_builder.py +++ b/src/chat/memory_system/memory_builder.py @@ -33,7 +33,7 @@ import time from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Union, Type +from typing import Any, Dict, List, Optional, Union, Type import orjson @@ -53,14 +53,15 @@ logger = get_logger(__name__) class ExtractionStrategy(Enum): """提取策略""" - LLM_BASED = "llm_based" # 基于LLM的智能提取 - RULE_BASED = "rule_based" # 基于规则的提取 - HYBRID = "hybrid" # 混合策略 + LLM_BASED = "llm_based" # 基于LLM的智能提取 + RULE_BASED = "rule_based" # 基于规则的提取 + HYBRID = "hybrid" # 混合策略 @dataclass class ExtractionResult: """提取结果""" + memories: List[MemoryChunk] confidence_scores: List[float] extraction_time: float @@ -80,15 +81,11 @@ class MemoryBuilder: "total_extractions": 0, "successful_extractions": 0, "failed_extractions": 0, - "average_confidence": 0.0 + "average_confidence": 0.0, } async def build_memories( - self, - conversation_text: str, - context: Dict[str, Any], - user_id: str, - timestamp: float + self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: float ) -> List[MemoryChunk]: """从对话中构建记忆""" start_time = time.time() @@ -119,19 +116,13 @@ class MemoryBuilder: raise async def _extract_with_llm( - self, - text: str, - context: Dict[str, Any], - user_id: str, - timestamp: float + self, text: str, context: Dict[str, Any], user_id: str, timestamp: float ) -> List[MemoryChunk]: """使用LLM提取记忆""" try: prompt = self._build_llm_extraction_prompt(text, context) - response, _ = await self.llm_model.generate_response_async( - prompt, temperature=0.3 - ) + response, _ = await self.llm_model.generate_response_async(prompt, temperature=0.3) # 解析LLM响应 memories = self._parse_llm_response(response, user_id, timestamp, context) @@ -342,16 +333,12 @@ class MemoryBuilder: start = stripped.find("{") end = stripped.rfind("}") if start != -1 and end != -1 and end > start: - return stripped[start:end + 1].strip() + return stripped[start : end + 1].strip() return stripped if stripped.startswith("{") and stripped.endswith("}") else None def _parse_llm_response( - self, - response: str, - user_id: str, - timestamp: float, - context: Dict[str, Any] + self, response: str, user_id: str, timestamp: float, context: Dict[str, Any] ) -> List[MemoryChunk]: """解析LLM响应""" if not response: @@ -366,9 +353,7 @@ class MemoryBuilder: data = orjson.loads(json_payload) except Exception as e: preview = json_payload[:200] - raise MemoryExtractionError( - f"LLM响应JSON解析失败: {e}, 片段: {preview}" - ) from e + raise MemoryExtractionError(f"LLM响应JSON解析失败: {e}, 片段: {preview}") from e memory_list = data.get("memories", []) @@ -406,17 +391,15 @@ class MemoryBuilder: try: # 检查是否包含模糊代称 display_text = mem_data.get("display", "") - if any(ambiguous_term in display_text for ambiguous_term in ["用户", "user", "the user", "对方", "对手"]): + if any( + ambiguous_term in display_text for ambiguous_term in ["用户", "user", "the user", "对方", "对手"] + ): logger.debug(f"拒绝构建包含模糊代称的记忆,display字段: 
{display_text}") continue subject_value = mem_data.get("subject") normalized_subject = self._normalize_subjects( - subject_value, - bot_identifiers, - system_identifiers, - default_subjects, - bot_display + subject_value, bot_identifiers, system_identifiers, default_subjects, bot_display ) if not normalized_subject: @@ -425,17 +408,11 @@ class MemoryBuilder: # 创建记忆块 importance_level = self._parse_enum_value( - ImportanceLevel, - mem_data.get("importance"), - ImportanceLevel.NORMAL, - "importance" + ImportanceLevel, mem_data.get("importance"), ImportanceLevel.NORMAL, "importance" ) confidence_level = self._parse_enum_value( - ConfidenceLevel, - mem_data.get("confidence"), - ConfidenceLevel.MEDIUM, - "confidence" + ConfidenceLevel, mem_data.get("confidence"), ConfidenceLevel.MEDIUM, "confidence" ) predicate_value = mem_data.get("predicate", "") @@ -457,7 +434,7 @@ class MemoryBuilder: source_context=mem_data.get("reasoning", ""), importance=importance_level, confidence=confidence_level, - display=display_text + display=display_text, ) if used_fallback_display: @@ -483,13 +460,7 @@ class MemoryBuilder: return memories - def _parse_enum_value( - self, - enum_cls: Type[Enum], - raw_value: Any, - default: Enum, - field_name: str - ) -> Enum: + def _parse_enum_value(self, enum_cls: Type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum: """解析枚举值,兼容数字/字符串表示""" if isinstance(raw_value, enum_cls): return raw_value @@ -533,12 +504,14 @@ class MemoryBuilder: try: return enum_cls(raw_value) except Exception: - logger.debug("%s=%s 类型 %s 无法解析为 %s,使用默认值 %s", - field_name, - raw_value, - type(raw_value).__name__, - enum_cls.__name__, - default.name) + logger.debug( + "%s=%s 类型 %s 无法解析为 %s,使用默认值 %s", + field_name, + raw_value, + type(raw_value).__name__, + enum_cls.__name__, + default.name, + ) return default def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]: @@ -606,7 +579,7 @@ class MemoryBuilder: "members", "member_names", "mention_users", - "audiences" + "audiences", ] for key in candidate_keys: @@ -727,7 +700,7 @@ class MemoryBuilder: bot_identifiers: set[str], system_identifiers: set[str], default_subjects: List[str], - bot_display: Optional[str] = None + bot_display: Optional[str] = None, ) -> List[str]: defaults = default_subjects or ["对话参与者"] @@ -800,7 +773,9 @@ class MemoryBuilder: return obj.strip() or None return None - def _compose_display_text(self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]) -> str: + def _compose_display_text( + self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]] + ) -> str: subject_phrase = "、".join(subjects) if subjects else "对话参与者" predicate = (predicate or "").strip() @@ -866,11 +841,7 @@ class MemoryBuilder: return f"{subject_phrase}{predicate}".strip() return subject_phrase - def _validate_and_enhance_memories( - self, - memories: List[MemoryChunk], - context: Dict[str, Any] - ) -> List[MemoryChunk]: + def _validate_and_enhance_memories(self, memories: List[MemoryChunk], context: Dict[str, Any]) -> List[MemoryChunk]: """验证和增强记忆""" validated_memories = [] @@ -905,11 +876,7 @@ class MemoryBuilder: return True - def _enhance_memory( - self, - memory: MemoryChunk, - context: Dict[str, Any] - ) -> MemoryChunk: + def _enhance_memory(self, memory: MemoryChunk, context: Dict[str, Any]) -> MemoryChunk: """增强记忆块""" # 时间规范化处理 self._normalize_time_in_memory(memory) @@ -919,7 +886,7 @@ class MemoryBuilder: memory.temporal_context = { "timestamp": 
memory.metadata.created_at, "timezone": context.get("timezone", "UTC"), - "day_of_week": datetime.fromtimestamp(memory.metadata.created_at).strftime("%A") + "day_of_week": datetime.fromtimestamp(memory.metadata.created_at).strftime("%A"), } # 添加情感上下文(如果有) @@ -941,22 +908,22 @@ class MemoryBuilder: # 定义相对时间映射 relative_time_patterns = { - r'今天|今日': current_time.strftime('%Y-%m-%d'), - r'昨天|昨日': (current_time - timedelta(days=1)).strftime('%Y-%m-%d'), - r'明天|明日': (current_time + timedelta(days=1)).strftime('%Y-%m-%d'), - r'后天': (current_time + timedelta(days=2)).strftime('%Y-%m-%d'), - r'大后天': (current_time + timedelta(days=3)).strftime('%Y-%m-%d'), - r'前天': (current_time - timedelta(days=2)).strftime('%Y-%m-%d'), - r'大前天': (current_time - timedelta(days=3)).strftime('%Y-%m-%d'), - r'本周|这周|这星期': current_time.strftime('%Y-%m-%d'), - r'上周|上星期': (current_time - timedelta(weeks=1)).strftime('%Y-%m-%d'), - r'下周|下星期': (current_time + timedelta(weeks=1)).strftime('%Y-%m-%d'), - r'本月|这个月': current_time.strftime('%Y-%m-01'), - r'上月|上个月': (current_time.replace(day=1) - timedelta(days=1)).strftime('%Y-%m-01'), - r'下月|下个月': (current_time.replace(day=1) + timedelta(days=32)).replace(day=1).strftime('%Y-%m-01'), - r'今年|今年': current_time.strftime('%Y'), - r'去年|上一年': str(current_time.year - 1), - r'明年|下一年': str(current_time.year + 1), + r"今天|今日": current_time.strftime("%Y-%m-%d"), + r"昨天|昨日": (current_time - timedelta(days=1)).strftime("%Y-%m-%d"), + r"明天|明日": (current_time + timedelta(days=1)).strftime("%Y-%m-%d"), + r"后天": (current_time + timedelta(days=2)).strftime("%Y-%m-%d"), + r"大后天": (current_time + timedelta(days=3)).strftime("%Y-%m-%d"), + r"前天": (current_time - timedelta(days=2)).strftime("%Y-%m-%d"), + r"大前天": (current_time - timedelta(days=3)).strftime("%Y-%m-%d"), + r"本周|这周|这星期": current_time.strftime("%Y-%m-%d"), + r"上周|上星期": (current_time - timedelta(weeks=1)).strftime("%Y-%m-%d"), + r"下周|下星期": (current_time + timedelta(weeks=1)).strftime("%Y-%m-%d"), + r"本月|这个月": current_time.strftime("%Y-%m-01"), + r"上月|上个月": (current_time.replace(day=1) - timedelta(days=1)).strftime("%Y-%m-01"), + r"下月|下个月": (current_time.replace(day=1) + timedelta(days=32)).replace(day=1).strftime("%Y-%m-01"), + r"今年|今年": current_time.strftime("%Y"), + r"去年|上一年": str(current_time.year - 1), + r"明年|下一年": str(current_time.year + 1), } def _normalize_value(value): @@ -1009,10 +976,14 @@ class MemoryBuilder: # 更新平均置信度 if self.extraction_stats["successful_extractions"] > 0: - total_confidence = self.extraction_stats["average_confidence"] * (self.extraction_stats["successful_extractions"] - success_count) + total_confidence = self.extraction_stats["average_confidence"] * ( + self.extraction_stats["successful_extractions"] - success_count + ) # 假设新记忆的平均置信度为0.8 total_confidence += 0.8 * success_count - self.extraction_stats["average_confidence"] = total_confidence / self.extraction_stats["successful_extractions"] + self.extraction_stats["average_confidence"] = ( + total_confidence / self.extraction_stats["successful_extractions"] + ) def get_extraction_stats(self) -> Dict[str, Any]: """获取提取统计信息""" @@ -1024,5 +995,5 @@ class MemoryBuilder: "total_extractions": 0, "successful_extractions": 0, "failed_extractions": 0, - "average_confidence": 0.0 - } \ No newline at end of file + "average_confidence": 0.0, + } diff --git a/src/chat/memory_system/memory_chunk.py b/src/chat/memory_system/memory_chunk.py index 0ddf6bd0f..b5b609af6 100644 --- a/src/chat/memory_system/memory_chunk.py +++ b/src/chat/memory_system/memory_chunk.py @@ -8,8 
+8,7 @@ import time import uuid import orjson from typing import Dict, List, Optional, Any, Union, Iterable -from dataclasses import dataclass, field, asdict -from datetime import datetime +from dataclasses import dataclass, field from enum import Enum import hashlib @@ -21,33 +20,36 @@ logger = get_logger(__name__) class MemoryType(Enum): """记忆类型分类""" - PERSONAL_FACT = "personal_fact" # 个人事实(姓名、职业、住址等) - EVENT = "event" # 事件(重要经历、约会等) - PREFERENCE = "preference" # 偏好(喜好、习惯等) - OPINION = "opinion" # 观点(对事物的看法) - RELATIONSHIP = "relationship" # 关系(与他人的关系) - EMOTION = "emotion" # 情感状态 - KNOWLEDGE = "knowledge" # 知识信息 - SKILL = "skill" # 技能能力 - GOAL = "goal" # 目标计划 - EXPERIENCE = "experience" # 经验教训 - CONTEXTUAL = "contextual" # 上下文信息 + + PERSONAL_FACT = "personal_fact" # 个人事实(姓名、职业、住址等) + EVENT = "event" # 事件(重要经历、约会等) + PREFERENCE = "preference" # 偏好(喜好、习惯等) + OPINION = "opinion" # 观点(对事物的看法) + RELATIONSHIP = "relationship" # 关系(与他人的关系) + EMOTION = "emotion" # 情感状态 + KNOWLEDGE = "knowledge" # 知识信息 + SKILL = "skill" # 技能能力 + GOAL = "goal" # 目标计划 + EXPERIENCE = "experience" # 经验教训 + CONTEXTUAL = "contextual" # 上下文信息 class ConfidenceLevel(Enum): """置信度等级""" - LOW = 1 # 低置信度,可能不准确 - MEDIUM = 2 # 中等置信度,有一定依据 - HIGH = 3 # 高置信度,有明确来源 - VERIFIED = 4 # 已验证,非常可靠 + + LOW = 1 # 低置信度,可能不准确 + MEDIUM = 2 # 中等置信度,有一定依据 + HIGH = 3 # 高置信度,有明确来源 + VERIFIED = 4 # 已验证,非常可靠 class ImportanceLevel(Enum): """重要性等级""" - LOW = 1 # 低重要性,普通信息 - NORMAL = 2 # 一般重要性,日常信息 - HIGH = 3 # 高重要性,重要信息 - CRITICAL = 4 # 关键重要性,核心信息 + + LOW = 1 # 低重要性,普通信息 + NORMAL = 2 # 一般重要性,日常信息 + HIGH = 3 # 高重要性,重要信息 + CRITICAL = 4 # 关键重要性,核心信息 @dataclass @@ -61,12 +63,7 @@ class ContentStructure: def to_dict(self) -> Dict[str, Any]: """转换为字典格式""" - return { - "subject": self.subject, - "predicate": self.predicate, - "object": self.object, - "display": self.display - } + return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display} @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure": @@ -75,7 +72,7 @@ class ContentStructure: subject=data.get("subject", ""), predicate=data.get("predicate", ""), object=data.get("object", ""), - display=data.get("display", "") + display=data.get("display", ""), ) def to_subject_list(self) -> List[str]: @@ -98,24 +95,25 @@ class ContentStructure: @dataclass class MemoryMetadata: """记忆元数据 - 简化版本""" + # 基础信息 - memory_id: str # 唯一标识符 - user_id: str # 用户ID - chat_id: Optional[str] = None # 聊天ID(群聊或私聊) + memory_id: str # 唯一标识符 + user_id: str # 用户ID + chat_id: Optional[str] = None # 聊天ID(群聊或私聊) # 时间信息 - created_at: float = 0.0 # 创建时间戳 - last_accessed: float = 0.0 # 最后访问时间 - last_modified: float = 0.0 # 最后修改时间 + created_at: float = 0.0 # 创建时间戳 + last_accessed: float = 0.0 # 最后访问时间 + last_modified: float = 0.0 # 最后修改时间 # 激活频率管理 - last_activation_time: float = 0.0 # 最后激活时间 - activation_frequency: int = 0 # 激活频率(单位时间内的激活次数) - total_activations: int = 0 # 总激活次数 + last_activation_time: float = 0.0 # 最后激活时间 + activation_frequency: int = 0 # 激活频率(单位时间内的激活次数) + total_activations: int = 0 # 总激活次数 # 统计信息 - access_count: int = 0 # 访问次数 - relevance_score: float = 0.0 # 相关度评分 + access_count: int = 0 # 访问次数 + relevance_score: float = 0.0 # 相关度评分 # 信心和重要性(核心字段) confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM @@ -123,10 +121,10 @@ class MemoryMetadata: # 遗忘机制相关 forgetting_threshold: float = 0.0 # 遗忘阈值(动态计算) - last_forgetting_check: float = 0.0 # 上次遗忘检查时间 + last_forgetting_check: float = 0.0 # 上次遗忘检查时间 # 来源信息 - source_context: Optional[str] = None # 来源上下文片段 + 
source_context: Optional[str] = None # 来源上下文片段 # 兼容旧字段: 一些代码或旧版本可能直接访问 metadata.source source: Optional[str] = None @@ -153,13 +151,13 @@ class MemoryMetadata: self.last_forgetting_check = current_time # 兼容性:如果旧字段 source 被使用,保证 source 与 source_context 同步 - if not getattr(self, 'source', None) and getattr(self, 'source_context', None): + if not getattr(self, "source", None) and getattr(self, "source_context", None): try: self.source = str(self.source_context) except Exception: self.source = None # 如果有 source 字段但 source_context 为空,也同步回去 - if not getattr(self, 'source_context', None) and getattr(self, 'source', None): + if not getattr(self, "source_context", None) and getattr(self, "source", None): try: self.source_context = str(self.source) except Exception: @@ -177,7 +175,6 @@ class MemoryMetadata: def _update_activation_frequency(self, current_time: float): """更新激活频率(24小时内的激活次数)""" - from datetime import datetime, timedelta # 如果超过24小时,重置激活频率 if current_time - self.last_activation_time > 86400: # 24小时 = 86400秒 @@ -251,7 +248,7 @@ class MemoryMetadata: "importance": self.importance.value, "forgetting_threshold": self.forgetting_threshold, "last_forgetting_check": self.last_forgetting_check, - "source_context": self.source_context + "source_context": self.source_context, } @classmethod @@ -273,7 +270,7 @@ class MemoryMetadata: importance=ImportanceLevel(data.get("importance", ImportanceLevel.NORMAL.value)), forgetting_threshold=data.get("forgetting_threshold", 0.0), last_forgetting_check=data.get("last_forgetting_check", 0), - source_context=data.get("source_context") + source_context=data.get("source_context"), ) @@ -285,21 +282,21 @@ class MemoryChunk: metadata: MemoryMetadata # 内容结构 - content: ContentStructure # 主谓宾结构 - memory_type: MemoryType # 记忆类型 + content: ContentStructure # 主谓宾结构 + memory_type: MemoryType # 记忆类型 # 扩展信息 - keywords: List[str] = field(default_factory=list) # 关键词列表 - tags: List[str] = field(default_factory=list) # 标签列表 - categories: List[str] = field(default_factory=list) # 分类列表 + keywords: List[str] = field(default_factory=list) # 关键词列表 + tags: List[str] = field(default_factory=list) # 标签列表 + categories: List[str] = field(default_factory=list) # 分类列表 # 语义信息 - embedding: Optional[List[float]] = None # 语义向量 - semantic_hash: Optional[str] = None # 语义哈希值 + embedding: Optional[List[float]] = None # 语义向量 + semantic_hash: Optional[str] = None # 语义哈希值 # 关联信息 related_memories: List[str] = field(default_factory=list) # 关联记忆ID列表 - temporal_context: Optional[Dict[str, Any]] = None # 时间上下文 + temporal_context: Optional[Dict[str, Any]] = None # 时间上下文 def __post_init__(self): """后初始化处理""" @@ -317,7 +314,7 @@ class MemoryChunk: embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding])) hash_input = f"{content_str}|{embedding_str}" - hash_object = hashlib.sha256(hash_input.encode('utf-8')) + hash_object = hashlib.sha256(hash_input.encode("utf-8")) self.semantic_hash = hash_object.hexdigest()[:16] except Exception as e: @@ -430,7 +427,7 @@ class MemoryChunk: "embedding": self.embedding, "semantic_hash": self.semantic_hash, "related_memories": self.related_memories, - "temporal_context": self.temporal_context + "temporal_context": self.temporal_context, } @classmethod @@ -449,14 +446,14 @@ class MemoryChunk: embedding=data.get("embedding"), semantic_hash=data.get("semantic_hash"), related_memories=data.get("related_memories", []), - temporal_context=data.get("temporal_context") + temporal_context=data.get("temporal_context"), ) return chunk def to_json(self) -> str: 
"""转换为JSON字符串""" - return orjson.dumps(self.to_dict(), ensure_ascii=False).decode('utf-8') + return orjson.dumps(self.to_dict(), ensure_ascii=False).decode("utf-8") @classmethod def from_json(cls, json_str: str) -> "MemoryChunk": @@ -530,7 +527,7 @@ class MemoryChunk: MemoryType.SKILL: "🛠️", MemoryType.GOAL: "🎯", MemoryType.EXPERIENCE: "💡", - MemoryType.CONTEXTUAL: "📝" + MemoryType.CONTEXTUAL: "📝", } emoji = type_emoji.get(self.memory_type, "📝") @@ -581,7 +578,7 @@ def create_memory_chunk( importance: ImportanceLevel = ImportanceLevel.NORMAL, confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM, display: Optional[str] = None, - **kwargs + **kwargs, ) -> MemoryChunk: """便捷的内存块创建函数""" metadata = MemoryMetadata( @@ -593,7 +590,7 @@ def create_memory_chunk( last_modified=0, confidence=confidence, importance=importance, - source_context=source_context + source_context=source_context, ) subjects: List[str] @@ -607,18 +604,8 @@ def create_memory_chunk( display_text = display or _build_display_text(subjects, predicate, obj) - content = ContentStructure( - subject=subject_payload, - predicate=predicate, - object=obj, - display=display_text - ) + content = ContentStructure(subject=subject_payload, predicate=predicate, object=obj, display=display_text) - chunk = MemoryChunk( - metadata=metadata, - content=content, - memory_type=memory_type, - **kwargs - ) + chunk = MemoryChunk(metadata=metadata, content=content, memory_type=memory_type, **kwargs) - return chunk \ No newline at end of file + return chunk diff --git a/src/chat/memory_system/memory_forgetting_engine.py b/src/chat/memory_system/memory_forgetting_engine.py index a52580a9f..3e243e433 100644 --- a/src/chat/memory_system/memory_forgetting_engine.py +++ b/src/chat/memory_system/memory_forgetting_engine.py @@ -6,8 +6,8 @@ import time import asyncio -from typing import List, Dict, Optional, Set, Tuple -from datetime import datetime, timedelta +from typing import List, Dict, Optional, Tuple +from datetime import datetime from dataclasses import dataclass from src.common.logger import get_logger @@ -19,6 +19,7 @@ logger = get_logger(__name__) @dataclass class ForgettingStats: """遗忘统计信息""" + total_checked: int = 0 marked_for_forgetting: int = 0 actually_forgotten: int = 0 @@ -30,34 +31,35 @@ class ForgettingStats: @dataclass class ForgettingConfig: """遗忘引擎配置""" + # 检查频率配置 - check_interval_hours: int = 24 # 定期检查间隔(小时) - batch_size: int = 100 # 批处理大小 + check_interval_hours: int = 24 # 定期检查间隔(小时) + batch_size: int = 100 # 批处理大小 # 遗忘阈值配置 - base_forgetting_days: float = 30.0 # 基础遗忘天数 - min_forgetting_days: float = 7.0 # 最小遗忘天数 - max_forgetting_days: float = 365.0 # 最大遗忘天数 + base_forgetting_days: float = 30.0 # 基础遗忘天数 + min_forgetting_days: float = 7.0 # 最小遗忘天数 + max_forgetting_days: float = 365.0 # 最大遗忘天数 # 重要程度权重 critical_importance_bonus: float = 45.0 # 关键重要性额外天数 - high_importance_bonus: float = 30.0 # 高重要性额外天数 - normal_importance_bonus: float = 15.0 # 一般重要性额外天数 - low_importance_bonus: float = 0.0 # 低重要性额外天数 + high_importance_bonus: float = 30.0 # 高重要性额外天数 + normal_importance_bonus: float = 15.0 # 一般重要性额外天数 + low_importance_bonus: float = 0.0 # 低重要性额外天数 # 置信度权重 verified_confidence_bonus: float = 30.0 # 已验证置信度额外天数 - high_confidence_bonus: float = 20.0 # 高置信度额外天数 - medium_confidence_bonus: float = 10.0 # 中等置信度额外天数 - low_confidence_bonus: float = 0.0 # 低置信度额外天数 + high_confidence_bonus: float = 20.0 # 高置信度额外天数 + medium_confidence_bonus: float = 10.0 # 中等置信度额外天数 + low_confidence_bonus: float = 0.0 # 低置信度额外天数 # 激活频率权重 activation_frequency_weight: float = 
0.5 # 每次激活增加的天数权重 - max_frequency_bonus: float = 10.0 # 最大激活频率奖励天数 + max_frequency_bonus: float = 10.0 # 最大激活频率奖励天数 # 休眠配置 - dormant_threshold_days: int = 90 # 休眠状态判定天数 - force_forget_dormant_days: int = 180 # 强制遗忘休眠记忆的天数 + dormant_threshold_days: int = 90 # 休眠状态判定天数 + force_forget_dormant_days: int = 180 # 强制遗忘休眠记忆的天数 class MemoryForgettingEngine: @@ -107,13 +109,12 @@ class MemoryForgettingEngine: # 激活频率权重 frequency_bonus = min( memory.metadata.activation_frequency * self.config.activation_frequency_weight, - self.config.max_frequency_bonus + self.config.max_frequency_bonus, ) threshold += frequency_bonus # 确保在合理范围内 - return max(self.config.min_forgetting_days, - min(threshold, self.config.max_forgetting_days)) + return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days)) def should_forget_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool: """ @@ -265,8 +266,8 @@ class MemoryForgettingEngine: "actually_forgotten": self.stats.actually_forgotten, "dormant_memories": self.stats.dormant_memories, "check_duration": self.stats.check_duration, - "last_check_time": self.stats.last_check_time - } + "last_check_time": self.stats.last_check_time, + }, } def is_forgetting_check_needed(self) -> bool: @@ -302,7 +303,9 @@ class MemoryForgettingEngine: # 如果启用自动清理,执行实际的遗忘操作 if enable_auto_cleanup and (result["normal_forgetting"] or result["force_forgetting"]): - logger.info(f"检测到 {len(result['normal_forgetting'])} 条普通遗忘和 {len(result['force_forgetting'])} 条强制遗忘记忆") + logger.info( + f"检测到 {len(result['normal_forgetting'])} 条普通遗忘和 {len(result['force_forgetting'])} 条强制遗忘记忆" + ) # 这里可以调用实际的删除逻辑 # await self.cleanup_forgotten_memories(result["normal_forgetting"] + result["force_forgetting"]) @@ -318,14 +321,16 @@ class MemoryForgettingEngine: "marked_for_forgetting": self.stats.marked_for_forgetting, "actually_forgotten": self.stats.actually_forgotten, "dormant_memories": self.stats.dormant_memories, - "last_check_time": datetime.fromtimestamp(self.stats.last_check_time).isoformat() if self.stats.last_check_time else None, + "last_check_time": datetime.fromtimestamp(self.stats.last_check_time).isoformat() + if self.stats.last_check_time + else None, "last_check_duration": self.stats.check_duration, "config": { "check_interval_hours": self.config.check_interval_hours, "base_forgetting_days": self.config.base_forgetting_days, "min_forgetting_days": self.config.min_forgetting_days, - "max_forgetting_days": self.config.max_forgetting_days - } + "max_forgetting_days": self.config.max_forgetting_days, + }, } def reset_stats(self): @@ -349,4 +354,4 @@ memory_forgetting_engine = MemoryForgettingEngine() def get_memory_forgetting_engine() -> MemoryForgettingEngine: """获取全局遗忘引擎实例""" - return memory_forgetting_engine \ No newline at end of file + return memory_forgetting_engine diff --git a/src/chat/memory_system/memory_fusion.py b/src/chat/memory_system/memory_fusion.py index 54c77d5bc..3ecc4cd71 100644 --- a/src/chat/memory_system/memory_fusion.py +++ b/src/chat/memory_system/memory_fusion.py @@ -5,14 +5,12 @@ """ import time -from typing import Dict, List, Optional, Tuple, Set, Any +from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import ( - MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel -) +from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel logger = get_logger(__name__) @@ -20,6 +18,7 
@@ logger = get_logger(__name__) @dataclass class FusionResult: """融合结果""" + original_count: int fused_count: int removed_duplicates: int @@ -31,6 +30,7 @@ class FusionResult: @dataclass class DuplicateGroup: """重复记忆组""" + group_id: str memories: List[MemoryChunk] similarity_matrix: List[List[float]] @@ -46,22 +46,20 @@ class MemoryFusionEngine: "total_fusions": 0, "memories_fused": 0, "duplicates_removed": 0, - "average_similarity": 0.0 + "average_similarity": 0.0, } # 融合策略配置 self.fusion_strategies = { - "semantic_similarity": True, # 语义相似性融合 - "temporal_proximity": True, # 时间接近性融合 - "logical_consistency": True, # 逻辑一致性融合 - "confidence_boosting": True, # 置信度提升 - "importance_preservation": True # 重要性保持 + "semantic_similarity": True, # 语义相似性融合 + "temporal_proximity": True, # 时间接近性融合 + "logical_consistency": True, # 逻辑一致性融合 + "confidence_boosting": True, # 置信度提升 + "importance_preservation": True, # 重要性保持 } async def fuse_memories( - self, - new_memories: List[MemoryChunk], - existing_memories: Optional[List[MemoryChunk]] = None + self, new_memories: List[MemoryChunk], existing_memories: Optional[List[MemoryChunk]] = None ) -> List[MemoryChunk]: """融合记忆列表""" start_time = time.time() @@ -73,9 +71,7 @@ class MemoryFusionEngine: logger.info(f"开始记忆融合,新记忆: {len(new_memories)},现有记忆: {len(existing_memories or [])}") # 1. 检测重复记忆组 - duplicate_groups = await self._detect_duplicate_groups( - new_memories, existing_memories or [] - ) + duplicate_groups = await self._detect_duplicate_groups(new_memories, existing_memories or []) if not duplicate_groups: fusion_time = time.time() - start_time @@ -110,9 +106,7 @@ class MemoryFusionEngine: return new_memories # 失败时返回原始记忆 async def _detect_duplicate_groups( - self, - new_memories: List[MemoryChunk], - existing_memories: List[MemoryChunk] + self, new_memories: List[MemoryChunk], existing_memories: List[MemoryChunk] ) -> List[DuplicateGroup]: """检测重复记忆组""" all_memories = new_memories + existing_memories @@ -125,16 +119,12 @@ class MemoryFusionEngine: continue # 创建新的重复组 - group = DuplicateGroup( - group_id=f"group_{len(groups)}", - memories=[memory1], - similarity_matrix=[[1.0]] - ) + group = DuplicateGroup(group_id=f"group_{len(groups)}", memories=[memory1], similarity_matrix=[[1.0]]) processed_ids.add(memory1.memory_id) # 寻找相似记忆 - for j, memory2 in enumerate(all_memories[i+1:], i+1): + for j, memory2 in enumerate(all_memories[i + 1 :], i + 1): if memory2.memory_id in processed_ids: continue @@ -182,9 +172,7 @@ class MemoryFusionEngine: # 5. 时间接近性 if self.fusion_strategies["temporal_proximity"]: - temporal_sim = self._calculate_temporal_similarity( - mem1.metadata.created_at, mem2.metadata.created_at - ) + temporal_sim = self._calculate_temporal_similarity(mem1.metadata.created_at, mem2.metadata.created_at) similarity_scores.append(("temporal", temporal_sim)) # 6. 
逻辑一致性 @@ -193,14 +181,7 @@ class MemoryFusionEngine: similarity_scores.append(("logical", logical_sim)) # 计算加权平均相似度 - weights = { - "semantic": 0.35, - "text": 0.25, - "keyword": 0.15, - "type": 0.10, - "temporal": 0.10, - "logical": 0.05 - } + weights = {"semantic": 0.35, "text": 0.25, "keyword": 0.15, "type": 0.10, "temporal": 0.10, "logical": 0.05} weighted_sum = 0.0 total_weight = 0.0 @@ -276,9 +257,7 @@ class MemoryFusionEngine: # 宾语相似性 if isinstance(mem1.content.object, str) and isinstance(mem2.content.object, str): - object_sim = self._calculate_text_similarity( - str(mem1.content.object), str(mem2.content.object) - ) + object_sim = self._calculate_text_similarity(str(mem1.content.object), str(mem2.content.object)) consistency_score += object_sim * 0.3 return consistency_score @@ -349,11 +328,7 @@ class MemoryFusionEngine: # 返回置信度最高的记忆 return max(group.memories, key=lambda m: m.metadata.confidence.value) - async def _merge_memory_attributes( - self, - base_memory: MemoryChunk, - memories: List[MemoryChunk] - ) -> MemoryChunk: + async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: List[MemoryChunk]) -> MemoryChunk: """合并记忆属性""" # 创建基础记忆的深拷贝 fused_memory = MemoryChunk.from_dict(base_memory.to_dict()) @@ -436,7 +411,7 @@ class MemoryFusionEngine: "earliest_timestamp": earliest_time, "latest_timestamp": latest_time, "time_span_hours": (latest_time - earliest_time) / 3600, - "source_memories": len(memories) + "source_memories": len(memories), } # 合并其他上下文信息 @@ -451,9 +426,7 @@ class MemoryFusionEngine: return merged_context async def incremental_fusion( - self, - new_memory: MemoryChunk, - existing_memories: List[MemoryChunk] + self, new_memory: MemoryChunk, existing_memories: List[MemoryChunk] ) -> Tuple[MemoryChunk, List[MemoryChunk]]: """增量融合(单个新记忆与现有记忆融合)""" # 寻找相似记忆 @@ -478,7 +451,7 @@ class MemoryFusionEngine: group = DuplicateGroup( group_id=f"incremental_{int(time.time())}", memories=[new_memory, best_match], - similarity_matrix=[[1.0, similarity], [similarity, 1.0]] + similarity_matrix=[[1.0, similarity], [similarity, 1.0]], ) # 执行融合 @@ -530,5 +503,5 @@ class MemoryFusionEngine: "total_fusions": 0, "memories_fused": 0, "duplicates_removed": 0, - "average_similarity": 0.0 - } \ No newline at end of file + "average_similarity": 0.0, + } diff --git a/src/chat/memory_system/memory_manager.py b/src/chat/memory_system/memory_manager.py index 80b9f7dcf..4c6b2696e 100644 --- a/src/chat/memory_system/memory_manager.py +++ b/src/chat/memory_system/memory_manager.py @@ -9,12 +9,9 @@ from typing import Dict, List, Optional, Any, Tuple from dataclasses import dataclass from src.common.logger import get_logger -from src.config.config import global_config from src.chat.memory_system.memory_system import MemorySystem from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType -from src.chat.memory_system.memory_system import ( - initialize_memory_system -) +from src.chat.memory_system.memory_system import initialize_memory_system logger = get_logger(__name__) @@ -22,6 +19,7 @@ logger = get_logger(__name__) @dataclass class MemoryResult: """记忆查询结果""" + content: str memory_type: str confidence: float @@ -67,6 +65,7 @@ class MemoryManager: # 获取LLM模型 from src.llm_models.utils_model import LLMRequest from src.config.config import model_config + llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory") # 初始化记忆系统 @@ -121,7 +120,7 @@ class MemoryManager: max_memory_num: int = 3, max_memory_length: int = 2, time_weight: float = 1.0, 
- keyword_weight: float = 1.0 + keyword_weight: float = 1.0, ) -> List[Tuple[str, str]]: """从文本获取相关记忆 - 兼容原有接口""" if not self.is_initialized or not self.memory_system: @@ -131,14 +130,11 @@ class MemoryManager: # 使用增强记忆系统检索 context = { "chat_id": chat_id, - "expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE] + "expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE], } relevant_memories = await self.memory_system.retrieve_relevant_memories( - query=text, - user_id=user_id, - context=context, - limit=max_memory_num + query=text, user_id=user_id, context=context, limit=max_memory_num ) # 转换为原有格式 (topic, content) @@ -156,11 +152,7 @@ class MemoryManager: return [] async def get_memory_from_topic( - self, - valid_keywords: List[str], - max_memory_num: int = 3, - max_memory_length: int = 2, - max_depth: int = 3 + self, valid_keywords: List[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 ) -> List[Tuple[str, str]]: """从关键词获取记忆 - 兼容原有接口""" if not self.is_initialized or not self.memory_system: @@ -177,15 +169,15 @@ class MemoryManager: MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE, - MemoryType.OPINION - ] + MemoryType.OPINION, + ], } relevant_memories = await self.memory_system.retrieve_relevant_memories( query_text=query_text, user_id="default_user", # 可以根据实际需要传递 context=context, - limit=max_memory_num + limit=max_memory_num, ) # 转换为原有格式 (topic, content) @@ -216,11 +208,7 @@ class MemoryManager: return [] async def process_conversation( - self, - conversation_text: str, - context: Dict[str, Any], - user_id: str, - timestamp: Optional[float] = None + self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: Optional[float] = None ) -> List[MemoryChunk]: """处理对话并构建记忆 - 新增功能""" if not self.is_initialized or not self.memory_system: @@ -247,11 +235,7 @@ class MemoryManager: return [] async def get_enhanced_memory_context( - self, - query_text: str, - user_id: str, - context: Optional[Dict[str, Any]] = None, - limit: int = 5 + self, query_text: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 5 ) -> List[MemoryResult]: """获取增强记忆上下文 - 新增功能""" if not self.is_initialized or not self.memory_system: @@ -259,10 +243,7 @@ class MemoryManager: try: relevant_memories = await self.memory_system.retrieve_relevant_memories( - query=query_text, - user_id=None, - context=context or {}, - limit=limit + query=query_text, user_id=None, context=context or {}, limit=limit ) results = [] @@ -276,7 +257,7 @@ class MemoryManager: timestamp=memory.metadata.created_at, source="enhanced_memory", relevance_score=memory.metadata.relevance_score, - structure=structure + structure=structure, ) results.append(result) @@ -342,7 +323,9 @@ class MemoryManager: return None return f"{subject}的职业是{profession}" if predicate == "lives_in": - location = self._extract_from_object(obj_value, ["location", "city", "place"]) or self._format_object(obj_value) + location = self._extract_from_object(obj_value, ["location", "city", "place"]) or self._format_object( + obj_value + ) location = self._clean_text(location) if not location: return None @@ -385,7 +368,9 @@ class MemoryManager: return None return f"{subject}最喜欢{favorite}" if predicate == "mentioned_event": - event_text = self._extract_from_object(obj_value, ["event_text", "description"]) or self._format_object(obj_value) + event_text = self._extract_from_object(obj_value, ["event_text", "description"]) or 
self._format_object( + obj_value + ) event_text = self._clean_text(self._truncate(event_text)) if not event_text: return None @@ -494,4 +479,4 @@ class MemoryManager: # 全局记忆管理器实例 -memory_manager = MemoryManager() \ No newline at end of file +memory_manager = MemoryManager() diff --git a/src/chat/memory_system/memory_metadata_index.py b/src/chat/memory_system/memory_metadata_index.py index 32104ffab..ad27971a6 100644 --- a/src/chat/memory_system/memory_metadata_index.py +++ b/src/chat/memory_system/memory_metadata_index.py @@ -12,7 +12,6 @@ from dataclasses import dataclass, asdict from datetime import datetime from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import MemoryType, ImportanceLevel, ConfidenceLevel logger = get_logger(__name__) @@ -20,22 +19,23 @@ logger = get_logger(__name__) @dataclass class MemoryMetadataIndexEntry: """元数据索引条目(轻量级,只用于快速过滤)""" + memory_id: str user_id: str - + # 分类信息 memory_type: str # MemoryType.value subjects: List[str] # 主语列表 objects: List[str] # 宾语列表 keywords: List[str] # 关键词列表 tags: List[str] # 标签列表 - + # 数值字段(用于范围过滤) importance: int # ImportanceLevel.value (1-4) confidence: int # ConfidenceLevel.value (1-4) created_at: float # 创建时间戳 access_count: int # 访问次数 - + # 可选字段 chat_id: Optional[str] = None content_preview: Optional[str] = None # 内容预览(前100字符) @@ -43,152 +43,152 @@ class MemoryMetadataIndexEntry: class MemoryMetadataIndex: """记忆元数据索引管理器""" - + def __init__(self, index_file: str = "data/memory_metadata_index.json"): self.index_file = Path(index_file) self.index: Dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry - + # 倒排索引(用于快速查找) self.type_index: Dict[str, Set[str]] = {} # type -> {memory_ids} self.subject_index: Dict[str, Set[str]] = {} # subject -> {memory_ids} self.keyword_index: Dict[str, Set[str]] = {} # keyword -> {memory_ids} self.tag_index: Dict[str, Set[str]] = {} # tag -> {memory_ids} - + self.lock = threading.RLock() - + # 加载已有索引 self._load_index() - + def _load_index(self): """从文件加载索引""" if not self.index_file.exists(): logger.info(f"元数据索引文件不存在,将创建新索引: {self.index_file}") return - + try: - with open(self.index_file, 'rb') as f: + with open(self.index_file, "rb") as f: data = orjson.loads(f.read()) - + # 重建内存索引 - for entry_data in data.get('entries', []): + for entry_data in data.get("entries", []): entry = MemoryMetadataIndexEntry(**entry_data) self.index[entry.memory_id] = entry self._update_inverted_indices(entry) - + logger.info(f"✅ 加载元数据索引: {len(self.index)} 条记录") - + except Exception as e: logger.error(f"加载元数据索引失败: {e}", exc_info=True) - + def _save_index(self): """保存索引到文件""" try: # 确保目录存在 self.index_file.parent.mkdir(parents=True, exist_ok=True) - + # 序列化所有条目 entries = [asdict(entry) for entry in self.index.values()] data = { - 'version': '1.0', - 'count': len(entries), - 'last_updated': datetime.now().isoformat(), - 'entries': entries + "version": "1.0", + "count": len(entries), + "last_updated": datetime.now().isoformat(), + "entries": entries, } - + # 写入文件(使用临时文件 + 原子重命名) - temp_file = self.index_file.with_suffix('.tmp') - with open(temp_file, 'wb') as f: + temp_file = self.index_file.with_suffix(".tmp") + with open(temp_file, "wb") as f: f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2)) - + temp_file.replace(self.index_file) logger.debug(f"元数据索引已保存: {len(entries)} 条记录") - + except Exception as e: logger.error(f"保存元数据索引失败: {e}", exc_info=True) - + def _update_inverted_indices(self, entry: MemoryMetadataIndexEntry): """更新倒排索引""" memory_id = entry.memory_id - + # 类型索引 
self.type_index.setdefault(entry.memory_type, set()).add(memory_id) - + # 主语索引 for subject in entry.subjects: subject_norm = subject.strip().lower() if subject_norm: self.subject_index.setdefault(subject_norm, set()).add(memory_id) - + # 关键词索引 for keyword in entry.keywords: keyword_norm = keyword.strip().lower() if keyword_norm: self.keyword_index.setdefault(keyword_norm, set()).add(memory_id) - + # 标签索引 for tag in entry.tags: tag_norm = tag.strip().lower() if tag_norm: self.tag_index.setdefault(tag_norm, set()).add(memory_id) - + def add_or_update(self, entry: MemoryMetadataIndexEntry): """添加或更新索引条目""" with self.lock: # 如果已存在,先从倒排索引中移除旧记录 if entry.memory_id in self.index: self._remove_from_inverted_indices(entry.memory_id) - + # 添加新记录 self.index[entry.memory_id] = entry self._update_inverted_indices(entry) - + def _remove_from_inverted_indices(self, memory_id: str): """从倒排索引中移除记录""" if memory_id not in self.index: return - + entry = self.index[memory_id] - + # 从类型索引移除 if entry.memory_type in self.type_index: self.type_index[entry.memory_type].discard(memory_id) - + # 从主语索引移除 for subject in entry.subjects: subject_norm = subject.strip().lower() if subject_norm in self.subject_index: self.subject_index[subject_norm].discard(memory_id) - + # 从关键词索引移除 for keyword in entry.keywords: keyword_norm = keyword.strip().lower() if keyword_norm in self.keyword_index: self.keyword_index[keyword_norm].discard(memory_id) - + # 从标签索引移除 for tag in entry.tags: tag_norm = tag.strip().lower() if tag_norm in self.tag_index: self.tag_index[tag_norm].discard(memory_id) - + def remove(self, memory_id: str): """移除索引条目""" with self.lock: if memory_id in self.index: self._remove_from_inverted_indices(memory_id) del self.index[memory_id] - + def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]): """批量添加或更新""" with self.lock: for entry in entries: self.add_or_update(entry) - + def save(self): """保存索引到磁盘""" with self.lock: self._save_index() - + def search( self, memory_types: Optional[List[str]] = None, @@ -201,11 +201,11 @@ class MemoryMetadataIndex: created_before: Optional[float] = None, user_id: Optional[str] = None, limit: Optional[int] = None, - flexible_mode: bool = True # 新增:灵活匹配模式 + flexible_mode: bool = True, # 新增:灵活匹配模式 ) -> List[str]: """ 搜索符合条件的记忆ID列表(支持模糊匹配) - + Returns: List[str]: 符合条件的 memory_id 列表 """ @@ -219,7 +219,7 @@ class MemoryMetadataIndex: created_after=created_after, created_before=created_before, user_id=user_id, - limit=limit + limit=limit, ) else: return self._search_strict( @@ -232,7 +232,7 @@ class MemoryMetadataIndex: created_after=created_after, created_before=created_before, user_id=user_id, - limit=limit + limit=limit, ) def _search_flexible( @@ -243,7 +243,7 @@ class MemoryMetadataIndex: created_before: Optional[float] = None, user_id: Optional[str] = None, limit: Optional[int] = None, - **kwargs # 接受但不使用的参数 + **kwargs, # 接受但不使用的参数 ) -> List[str]: """ 灵活搜索模式:2/4项匹配即可,支持部分匹配 @@ -258,10 +258,7 @@ class MemoryMetadataIndex: """ # 用户过滤(必选) if user_id: - base_candidates = { - mid for mid, entry in self.index.items() - if entry.user_id == user_id - } + base_candidates = {mid for mid, entry in self.index.items() if entry.user_id == user_id} else: base_candidates = set(self.index.keys()) @@ -386,7 +383,7 @@ class MemoryMetadataIndex: created_after: Optional[float] = None, created_before: Optional[float] = None, user_id: Optional[str] = None, - limit: Optional[int] = None + limit: Optional[int] = None, ) -> List[str]: """严格搜索模式(原有逻辑)""" # 初始候选集(所有记忆) @@ -394,10 +391,7 @@ 
class MemoryMetadataIndex: # 用户过滤(必选) if user_id: - candidate_ids = { - mid for mid, entry in self.index.items() - if entry.user_id == user_id - } + candidate_ids = {mid for mid, entry in self.index.items() if entry.user_id == user_id} else: candidate_ids = set(self.index.keys()) @@ -447,7 +441,8 @@ class MemoryMetadataIndex: # 重要性过滤 if importance_min is not None or importance_max is not None: importance_ids = { - mid for mid in candidate_ids + mid + for mid in candidate_ids if (importance_min is None or self.index[mid].importance >= importance_min) and (importance_max is None or self.index[mid].importance <= importance_max) } @@ -456,41 +451,37 @@ class MemoryMetadataIndex: # 时间范围过滤 if created_after is not None or created_before is not None: time_ids = { - mid for mid in candidate_ids + mid + for mid in candidate_ids if (created_after is None or self.index[mid].created_at >= created_after) and (created_before is None or self.index[mid].created_at <= created_before) } candidate_ids &= time_ids # 转换为列表并排序(按创建时间倒序) - result_ids = sorted( - candidate_ids, - key=lambda mid: self.index[mid].created_at, - reverse=True - ) + result_ids = sorted(candidate_ids, key=lambda mid: self.index[mid].created_at, reverse=True) # 限制数量 if limit: result_ids = result_ids[:limit] logger.debug( - f"[严格搜索] types={memory_types}, subjects={subjects}, " - f"keywords={keywords}, 返回={len(result_ids)}条" + f"[严格搜索] types={memory_types}, subjects={subjects}, keywords={keywords}, 返回={len(result_ids)}条" ) return result_ids - + def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]: """获取单个索引条目""" return self.index.get(memory_id) - + def get_stats(self) -> Dict[str, Any]: """获取索引统计信息""" with self.lock: return { - 'total_memories': len(self.index), - 'types': {mtype: len(ids) for mtype, ids in self.type_index.items()}, - 'subjects_count': len(self.subject_index), - 'keywords_count': len(self.keyword_index), - 'tags_count': len(self.tag_index), + "total_memories": len(self.index), + "types": {mtype: len(ids) for mtype, ids in self.type_index.items()}, + "subjects_count": len(self.subject_index), + "keywords_count": len(self.keyword_index), + "tags_count": len(self.tag_index), } diff --git a/src/chat/memory_system/memory_query_planner.py b/src/chat/memory_system/memory_query_planner.py index 690e26627..a8be9d951 100644 --- a/src/chat/memory_system/memory_query_planner.py +++ b/src/chat/memory_system/memory_query_planner.py @@ -80,10 +80,7 @@ class MemoryQueryPlanner: return self._default_plan(query_text) def _default_plan(self, query_text: str) -> MemoryQueryPlan: - return MemoryQueryPlan( - semantic_query=query_text, - limit=self.default_limit - ) + return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit) def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan: semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query @@ -122,7 +119,7 @@ class MemoryQueryPlanner: recency_preference=self._safe_str(data.get("recency")) or "any", limit=self._safe_int(data.get("limit"), self.default_limit), emphasis=self._safe_str(data.get("emphasis")) or "balanced", - raw_plan=data + raw_plan=data, ) return plan @@ -154,18 +151,18 @@ class MemoryQueryPlanner: context_section = f""" -## 📋 未读消息上下文 (共{unread_context.get('total_count', 0)}条未读消息) +## 📋 未读消息上下文 (共{unread_context.get("total_count", 0)}条未读消息) ### 最近消息预览: {chr(10).join(message_previews)} ### 上下文关键词: -{', '.join(unread_keywords[:15]) if unread_keywords else '无'} +{", ".join(unread_keywords[:15]) if 
unread_keywords else "无"} ### 对话参与者: -{', '.join(unread_participants) if unread_participants else '无'} +{", ".join(unread_participants) if unread_participants else "无"} ### 上下文摘要: -{context_summary[:300] if context_summary else '无'} +{context_summary[:300] if context_summary else "无"} """ else: context_section = """ @@ -223,7 +220,7 @@ class MemoryQueryPlanner: start = stripped.find("{") end = stripped.rfind("}") if start != -1 and end != -1 and end > start: - return stripped[start:end + 1] + return stripped[start : end + 1] return stripped if stripped.startswith("{") and stripped.endswith("}") else None @@ -243,4 +240,4 @@ class MemoryQueryPlanner: return default return number except (TypeError, ValueError): - return default \ No newline at end of file + return default diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index 86134e9fc..4a275babd 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -35,6 +35,7 @@ GLOBAL_MEMORY_SCOPE = "global" class MemorySystemStatus(Enum): """记忆系统状态""" + INITIALIZING = "initializing" READY = "ready" BUILDING = "building" @@ -45,6 +46,7 @@ class MemorySystemStatus(Enum): @dataclass class MemorySystemConfig: """记忆系统配置""" + # 记忆构建配置 min_memory_length: int = 10 max_memory_length: int = 500 @@ -97,11 +99,9 @@ class MemorySystemConfig: max_memory_length=global_config.memory.max_memory_length, memory_value_threshold=global_config.memory.memory_value_threshold, min_build_interval_seconds=getattr(global_config.memory, "memory_build_interval", 300.0), - # 向量存储配置 vector_dimension=int(embedding_dimension), similarity_threshold=global_config.memory.vector_similarity_threshold, - # 召回配置 coarse_recall_limit=global_config.memory.metadata_filter_limit, fine_recall_limit=global_config.memory.vector_search_limit, @@ -112,21 +112,16 @@ class MemorySystemConfig: semantic_weight=global_config.memory.semantic_weight, context_weight=global_config.memory.context_weight, recency_weight=global_config.memory.recency_weight, - # 融合配置 fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold, - deduplication_window=timedelta(hours=global_config.memory.deduplication_window_hours) + deduplication_window=timedelta(hours=global_config.memory.deduplication_window_hours), ) class MemorySystem: """精准记忆系统核心类""" - def __init__( - self, - llm_model: Optional[LLMRequest] = None, - config: Optional[MemorySystemConfig] = None - ): + def __init__(self, llm_model: Optional[LLMRequest] = None, config: Optional[MemorySystemConfig] = None): self.config = config or MemorySystemConfig.from_global_config() self.llm_model = llm_model self.status = MemorySystemStatus.INITIALIZING @@ -175,16 +170,16 @@ class MemorySystem: extraction_task_config = value_task_config or fallback_task if value_task_config is None or extraction_task_config is None: - raise RuntimeError("无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。") + raise RuntimeError( + "无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。" + ) self.value_assessment_model = LLMRequest( - model_set=value_task_config, - request_type="memory.value_assessment" + model_set=value_task_config, request_type="memory.value_assessment" ) self.memory_extraction_model = LLMRequest( - model_set=extraction_task_config, - request_type="memory.extraction" + model_set=extraction_task_config, request_type="memory.extraction" ) # 初始化核心组件(简化版) @@ -198,13 +193,13 @@ class MemorySystem: memory_collection="unified_memory_v2", 
metadata_collection="memory_metadata_v2", similarity_threshold=self.config.similarity_threshold, - search_limit=getattr(global_config.memory, 'unified_storage_search_limit', 20), - batch_size=getattr(global_config.memory, 'unified_storage_batch_size', 100), - enable_caching=getattr(global_config.memory, 'unified_storage_enable_caching', True), - cache_size_limit=getattr(global_config.memory, 'unified_storage_cache_limit', 1000), - auto_cleanup_interval=getattr(global_config.memory, 'unified_storage_auto_cleanup_interval', 3600), - enable_forgetting=getattr(global_config.memory, 'enable_memory_forgetting', True), - retention_hours=getattr(global_config.memory, 'memory_retention_hours', 720) # 30天 + search_limit=getattr(global_config.memory, "unified_storage_search_limit", 20), + batch_size=getattr(global_config.memory, "unified_storage_batch_size", 100), + enable_caching=getattr(global_config.memory, "unified_storage_enable_caching", True), + cache_size_limit=getattr(global_config.memory, "unified_storage_cache_limit", 1000), + auto_cleanup_interval=getattr(global_config.memory, "unified_storage_auto_cleanup_interval", 3600), + enable_forgetting=getattr(global_config.memory, "enable_memory_forgetting", True), + retention_hours=getattr(global_config.memory, "memory_retention_hours", 720), # 30天 ) try: @@ -220,32 +215,27 @@ class MemorySystem: # 从全局配置创建遗忘引擎配置 forgetting_config = ForgettingConfig( # 检查频率配置 - check_interval_hours=getattr(global_config.memory, 'forgetting_check_interval_hours', 24), + check_interval_hours=getattr(global_config.memory, "forgetting_check_interval_hours", 24), batch_size=100, # 固定值,暂不配置 - # 遗忘阈值配置 - base_forgetting_days=getattr(global_config.memory, 'base_forgetting_days', 30.0), - min_forgetting_days=getattr(global_config.memory, 'min_forgetting_days', 7.0), - max_forgetting_days=getattr(global_config.memory, 'max_forgetting_days', 365.0), - + base_forgetting_days=getattr(global_config.memory, "base_forgetting_days", 30.0), + min_forgetting_days=getattr(global_config.memory, "min_forgetting_days", 7.0), + max_forgetting_days=getattr(global_config.memory, "max_forgetting_days", 365.0), # 重要程度权重 - critical_importance_bonus=getattr(global_config.memory, 'critical_importance_bonus', 45.0), - high_importance_bonus=getattr(global_config.memory, 'high_importance_bonus', 30.0), - normal_importance_bonus=getattr(global_config.memory, 'normal_importance_bonus', 15.0), - low_importance_bonus=getattr(global_config.memory, 'low_importance_bonus', 0.0), - + critical_importance_bonus=getattr(global_config.memory, "critical_importance_bonus", 45.0), + high_importance_bonus=getattr(global_config.memory, "high_importance_bonus", 30.0), + normal_importance_bonus=getattr(global_config.memory, "normal_importance_bonus", 15.0), + low_importance_bonus=getattr(global_config.memory, "low_importance_bonus", 0.0), # 置信度权重 - verified_confidence_bonus=getattr(global_config.memory, 'verified_confidence_bonus', 30.0), - high_confidence_bonus=getattr(global_config.memory, 'high_confidence_bonus', 20.0), - medium_confidence_bonus=getattr(global_config.memory, 'medium_confidence_bonus', 10.0), - low_confidence_bonus=getattr(global_config.memory, 'low_confidence_bonus', 0.0), - + verified_confidence_bonus=getattr(global_config.memory, "verified_confidence_bonus", 30.0), + high_confidence_bonus=getattr(global_config.memory, "high_confidence_bonus", 20.0), + medium_confidence_bonus=getattr(global_config.memory, "medium_confidence_bonus", 10.0), + low_confidence_bonus=getattr(global_config.memory, 
"low_confidence_bonus", 0.0), # 激活频率权重 - activation_frequency_weight=getattr(global_config.memory, 'activation_frequency_weight', 0.5), - max_frequency_bonus=getattr(global_config.memory, 'max_frequency_bonus', 10.0), - + activation_frequency_weight=getattr(global_config.memory, "activation_frequency_weight", 0.5), + max_frequency_bonus=getattr(global_config.memory, "max_frequency_bonus", 10.0), # 休眠配置 - dormant_threshold_days=getattr(global_config.memory, 'dormant_threshold_days', 90) + dormant_threshold_days=getattr(global_config.memory, "dormant_threshold_days", 90), ) self.forgetting_engine = MemoryForgettingEngine(forgetting_config) @@ -253,17 +243,11 @@ class MemorySystem: planner_task_config = getattr(model_config.model_task_config, "utils_small", None) planner_model: Optional[LLMRequest] = None try: - planner_model = LLMRequest( - model_set=planner_task_config, - request_type="memory.query_planner" - ) + planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner") except Exception as planner_exc: logger.warning("查询规划模型初始化失败,将使用默认规划策略: %s", planner_exc, exc_info=True) - self.query_planner = MemoryQueryPlanner( - planner_model, - default_limit=self.config.final_recall_limit - ) + self.query_planner = MemoryQueryPlanner(planner_model, default_limit=self.config.final_recall_limit) # 统一存储已经自动加载数据,无需额外加载 logger.info("✅ 简化版记忆系统初始化完成") @@ -277,11 +261,7 @@ class MemorySystem: raise async def retrieve_memories_for_building( - self, - query_text: str, - user_id: Optional[str] = None, - context: Optional[Dict[str, Any]] = None, - limit: int = 5 + self, query_text: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5 ) -> List[MemoryChunk]: """在构建记忆时检索相关记忆,使用统一存储系统 @@ -304,9 +284,7 @@ class MemorySystem: try: # 使用统一存储检索相似记忆 search_results = await self.unified_storage.search_similar_memories( - query_text=query_text, - limit=limit, - scope_id=user_id + query_text=query_text, limit=limit, scope_id=user_id ) # 转换为记忆对象 @@ -324,10 +302,7 @@ class MemorySystem: return [] async def build_memory_from_conversation( - self, - conversation_text: str, - context: Dict[str, Any], - timestamp: Optional[float] = None + self, conversation_text: str, context: Dict[str, Any], timestamp: Optional[float] = None ) -> List[MemoryChunk]: """从对话中构建记忆 @@ -383,7 +358,7 @@ class MemorySystem: conversation_text, normalized_context, GLOBAL_MEMORY_SCOPE, # 强制使用 global,不区分用户 - timestamp or time.time() + timestamp or time.time(), ) if not memory_chunks: @@ -393,10 +368,7 @@ class MemorySystem: # 3. 记忆融合与去重(包含与历史记忆的融合) existing_candidates = await self._collect_fusion_candidates(memory_chunks) - fused_chunks = await self.fusion_engine.fuse_memories( - memory_chunks, - existing_candidates - ) + fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks, existing_candidates) # 4. 
存储记忆到统一存储 stored_count = await self._store_memories_unified(fused_chunks) @@ -459,11 +431,7 @@ class MemorySystem: return [] candidate_ids: Set[str] = set() - new_memory_ids = { - memory.memory_id - for memory in new_memories - if memory and getattr(memory, "memory_id", None) - } + new_memory_ids = {memory.memory_id for memory in new_memories if memory and getattr(memory, "memory_id", None)} # 基于指纹的直接匹配 for memory in new_memories: @@ -501,9 +469,7 @@ class MemorySystem: continue search_tasks.append( self.unified_storage.search_similar_memories( - query_text=display_text, - limit=8, - scope_id=GLOBAL_MEMORY_SCOPE + query_text=display_text, limit=8, scope_id=GLOBAL_MEMORY_SCOPE ) ) @@ -545,10 +511,7 @@ class MemorySystem: return existing_candidates - async def process_conversation_memory( - self, - context: Dict[str, Any] - ) -> Dict[str, Any]: + async def process_conversation_memory(self, context: Dict[str, Any]) -> Dict[str, Any]: """对外暴露的对话记忆处理接口,仅依赖上下文信息""" start_time = time.time() @@ -563,7 +526,9 @@ class MemorySystem: or "" ) - conversation_text = conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate) + conversation_text = ( + conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate) + ) timestamp = context.get("timestamp") if timestamp is None: @@ -573,9 +538,7 @@ class MemorySystem: normalized_context.setdefault("conversation_text", conversation_text) memories = await self.build_memory_from_conversation( - conversation_text=conversation_text, - context=normalized_context, - timestamp=timestamp + conversation_text=conversation_text, context=normalized_context, timestamp=timestamp ) processing_time = time.time() - start_time @@ -586,18 +549,13 @@ class MemorySystem: "created_memories": memories, "memory_count": memory_count, "processing_time": processing_time, - "status": self.status.value + "status": self.status.value, } except Exception as e: processing_time = time.time() - start_time logger.error(f"对话记忆处理失败: {e}", exc_info=True) - return { - "success": False, - "error": str(e), - "processing_time": processing_time, - "status": self.status.value - } + return {"success": False, "error": str(e), "processing_time": processing_time, "status": self.status.value} async def retrieve_relevant_memories( self, @@ -605,7 +563,7 @@ class MemorySystem: user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5, - **kwargs + **kwargs, ) -> List[MemoryChunk]: """检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)""" raw_query = query_text or kwargs.get("query") @@ -617,7 +575,7 @@ class MemorySystem: return [] context = context or {} - + # 所有记忆完全共享,统一使用 global 作用域,不区分用户 resolved_user_id = GLOBAL_MEMORY_SCOPE @@ -627,7 +585,7 @@ class MemorySystem: try: normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None) effective_limit = self.config.final_recall_limit - + # === 阶段一:元数据粗筛(软性过滤) === coarse_filters = { "user_id": GLOBAL_MEMORY_SCOPE, # 必选:确保作用域正确 @@ -642,119 +600,126 @@ class MemorySystem: # 构建包含未读消息的增强上下文 enhanced_context = await self._build_enhanced_query_context(raw_query, normalized_context) query_plan = await self.query_planner.plan_query(raw_query, enhanced_context) - + # 使用LLM优化后的查询语句(更精确的语义表达) if getattr(query_plan, "semantic_query", None): optimized_query = query_plan.semantic_query - + # 构建JSON元数据过滤条件(用于阶段一粗筛) # 将查询规划的结果转换为元数据过滤条件 if getattr(query_plan, "memory_types", None): - metadata_filters['memory_types'] = [mt.value for mt in query_plan.memory_types] - + 
metadata_filters["memory_types"] = [mt.value for mt in query_plan.memory_types] + if getattr(query_plan, "subject_includes", None): - metadata_filters['subjects'] = query_plan.subject_includes - + metadata_filters["subjects"] = query_plan.subject_includes + if getattr(query_plan, "required_keywords", None): - metadata_filters['keywords'] = query_plan.required_keywords - + metadata_filters["keywords"] = query_plan.required_keywords + # 时间范围过滤 recency = getattr(query_plan, "recency_preference", "any") current_time = time.time() if recency == "recent": # 最近7天 - metadata_filters['created_after'] = current_time - (7 * 24 * 3600) + metadata_filters["created_after"] = current_time - (7 * 24 * 3600) elif recency == "historical": # 30天以前 - metadata_filters['created_before'] = current_time - (30 * 24 * 3600) - + metadata_filters["created_before"] = current_time - (30 * 24 * 3600) + # 添加用户ID到元数据过滤 - metadata_filters['user_id'] = GLOBAL_MEMORY_SCOPE - + metadata_filters["user_id"] = GLOBAL_MEMORY_SCOPE + logger.debug(f"[阶段一] 查询优化: '{raw_query}' → '{optimized_query}'") logger.debug(f"[阶段一] 元数据过滤条件: {metadata_filters}") - + except Exception as plan_exc: logger.warning("查询规划失败,使用原始查询: %s", plan_exc, exc_info=True) # 即使查询规划失败,也保留基本的user_id过滤 - metadata_filters = {'user_id': GLOBAL_MEMORY_SCOPE} + metadata_filters = {"user_id": GLOBAL_MEMORY_SCOPE} # === 阶段二:向量精筛 === coarse_limit = self.config.coarse_recall_limit # 粗筛阶段返回更多候选 - + logger.debug(f"[阶段二] 开始向量搜索: query='{optimized_query[:60]}...', limit={coarse_limit}") - + search_results = await self.unified_storage.search_similar_memories( query_text=optimized_query, limit=coarse_limit, filters=coarse_filters, # ChromaDB where条件(保留兼容) - metadata_filters=metadata_filters # JSON元数据索引过滤 + metadata_filters=metadata_filters, # JSON元数据索引过滤 ) - + logger.info(f"[阶段二] 向量搜索完成: 返回 {len(search_results)} 条候选") # === 阶段三:综合重排 === scored_memories = [] current_time = time.time() - + for memory, vector_similarity in search_results: # 1. 向量相似度得分(已归一化到 0-1) vector_score = vector_similarity - + # 2. 时效性得分(指数衰减,30天半衰期) age_seconds = current_time - memory.metadata.created_at age_days = age_seconds / (24 * 3600) # 使用 math.exp 而非 np.exp(避免依赖numpy) import math + recency_score = math.exp(-age_days / 30) - + # 3. 重要性得分(枚举值转换为归一化得分 0-1) # ImportanceLevel: LOW=1, NORMAL=2, HIGH=3, CRITICAL=4 importance_enum = memory.metadata.importance - if hasattr(importance_enum, 'value'): + if hasattr(importance_enum, "value"): # 枚举类型,转换为0-1范围:(value - 1) / 3 importance_score = (importance_enum.value - 1) / 3.0 else: # 如果已经是数值,直接使用 importance_score = float(importance_enum) if importance_enum else 0.5 - + # 4. 
访问频率得分(归一化,访问10次以上得满分) access_count = memory.metadata.access_count frequency_score = min(access_count / 10.0, 1.0) - + # 综合得分(加权平均) final_score = ( - self.config.vector_weight * vector_score + - self.config.recency_weight * recency_score + - self.config.context_weight * importance_score + - 0.1 * frequency_score # 访问频率权重(固定10%) + self.config.vector_weight * vector_score + + self.config.recency_weight * recency_score + + self.config.context_weight * importance_score + + 0.1 * frequency_score # 访问频率权重(固定10%) ) - - scored_memories.append((memory, final_score, { - "vector": vector_score, - "recency": recency_score, - "importance": importance_score, - "frequency": frequency_score, - "final": final_score - })) - + + scored_memories.append( + ( + memory, + final_score, + { + "vector": vector_score, + "recency": recency_score, + "importance": importance_score, + "frequency": frequency_score, + "final": final_score, + }, + ) + ) + # 更新访问记录 memory.update_access() # 按综合得分排序 scored_memories.sort(key=lambda x: x[1], reverse=True) - + # 返回 Top-K final_memories = [mem for mem, score, details in scored_memories[:effective_limit]] - + retrieval_time = time.time() - start_time # 详细日志 if scored_memories: - logger.info(f"[阶段三] 综合重排完成: Top 3 得分详情") + logger.info("[阶段三] 综合重排完成: Top 3 得分详情") for i, (mem, score, details) in enumerate(scored_memories[:3], 1): try: - summary = mem.content[:60] if hasattr(mem, 'content') and mem.content else "" + summary = mem.content[:60] if hasattr(mem, "content") and mem.content else "" except: summary = "" logger.info( @@ -803,15 +768,12 @@ class MemorySystem: start = stripped.find("{") end = stripped.rfind("}") if start != -1 and end != -1 and end > start: - return stripped[start:end + 1].strip() + return stripped[start : end + 1].strip() return stripped if stripped.startswith("{") and stripped.endswith("}") else None def _normalize_context( - self, - raw_context: Optional[Dict[str, Any]], - user_id: Optional[str], - timestamp: Optional[float] + self, raw_context: Optional[Dict[str, Any]], user_id: Optional[str], timestamp: Optional[float] ) -> Dict[str, Any]: """标准化上下文,确保必备字段存在且格式正确""" context: Dict[str, Any] = {} @@ -850,9 +812,7 @@ class MemorySystem: # 历史窗口配置 window_candidate = ( - context.get("history_limit") - or context.get("history_window") - or context.get("memory_history_limit") + context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit") ) if window_candidate is not None: try: @@ -888,7 +848,9 @@ class MemorySystem: enhanced_context["unread_messages_context"] = unread_messages_summary enhanced_context["has_unread_context"] = True - logger.debug(f"为查询规划构建了增强上下文,包含 {len(unread_messages_summary.get('messages', []))} 条未读消息") + logger.debug( + f"为查询规划构建了增强上下文,包含 {len(unread_messages_summary.get('messages', []))} 条未读消息" + ) else: enhanced_context["has_unread_context"] = False logger.debug("未找到未读消息,使用基础上下文进行查询规划") @@ -934,26 +896,30 @@ class MemorySystem: for msg in unread_messages[:10]: # 限制处理最近10条未读消息 try: # 提取消息内容 - content = (getattr(msg, "processed_plain_text", None) or - getattr(msg, "display_message", None) or "") + content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None) or "" if not content: continue # 提取发送者信息 sender_name = "未知用户" if hasattr(msg, "user_info") and msg.user_info: - sender_name = (getattr(msg.user_info, "user_nickname", None) or - getattr(msg.user_info, "user_cardname", None) or - getattr(msg.user_info, "user_id", None) or "未知用户") + sender_name = ( + getattr(msg.user_info, 
"user_nickname", None) + or getattr(msg.user_info, "user_cardname", None) + or getattr(msg.user_info, "user_id", None) + or "未知用户" + ) participant_names.add(sender_name) # 添加到消息摘要 - messages_summary.append({ - "sender": sender_name, - "content": content[:200], # 限制长度避免过长 - "timestamp": getattr(msg, "time", None) - }) + messages_summary.append( + { + "sender": sender_name, + "content": content[:200], # 限制长度避免过长 + "timestamp": getattr(msg, "time", None), + } + ) # 提取关键词(简单实现) content_lower = content.lower() @@ -975,10 +941,12 @@ class MemorySystem: "processed_count": len(messages_summary), "keywords": list(all_keywords)[:20], # 最多20个关键词 "participants": list(participant_names), - "context_summary": self._build_unread_context_summary(messages_summary) + "context_summary": self._build_unread_context_summary(messages_summary), } - logger.debug(f"收集到未读消息上下文: {len(messages_summary)}条消息,{len(all_keywords)}个关键词,{len(participant_names)}个参与者") + logger.debug( + f"收集到未读消息上下文: {len(messages_summary)}条消息,{len(all_keywords)}个关键词,{len(participant_names)}个参与者" + ) return unread_context except Exception as e: @@ -1051,10 +1019,7 @@ class MemorySystem: if user_id and fallback_text: try: relevant_memories = await self.retrieve_memories_for_building( - query_text=fallback_text, - user_id=user_id, - context=context, - limit=3 + query_text=fallback_text, user_id=user_id, context=context, limit=3 ) if relevant_memories: @@ -1068,9 +1033,7 @@ class MemorySystem: memory_transcript = f"{memory_transcript}\n[当前消息] {cleaned_fallback}" logger.debug( - "使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s", - len(relevant_memories), - user_id + "使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s", len(relevant_memories), user_id ) return memory_transcript @@ -1087,11 +1050,7 @@ class MemorySystem: def _determine_history_limit(self, context: Dict[str, Any]) -> int: """确定历史消息获取数量,限制在30-50之间""" default_limit = 40 - candidate = ( - context.get("history_limit") - or context.get("history_window") - or context.get("memory_history_limit") - ) + candidate = context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit") if isinstance(candidate, str): try: @@ -1186,9 +1145,9 @@ class MemorySystem: {text} 上下文信息: -- 用户ID: {context.get('user_id', 'unknown')} -- 消息类型: {context.get('message_type', 'unknown')} -- 时间: {datetime.fromtimestamp(context.get('timestamp', time.time()))} +- 用户ID: {context.get("user_id", "unknown")} +- 消息类型: {context.get("message_type", "unknown")} +- 时间: {datetime.fromtimestamp(context.get("timestamp", time.time()))} ## 📋 评估要求: @@ -1214,9 +1173,7 @@ class MemorySystem: }} """ - response, _ = await self.value_assessment_model.generate_response_async( - prompt, temperature=0.3 - ) + response, _ = await self.value_assessment_model.generate_response_async(prompt, temperature=0.3) # 解析响应 try: @@ -1236,7 +1193,7 @@ class MemorySystem: return max(0.0, min(1.0, value_score)) except (orjson.JSONDecodeError, ValueError) as e: - preview = response[:200].replace('\n', ' ') + preview = response[:200].replace("\n", " ") logger.warning(f"解析价值评估响应失败: {e}, 响应片段: {preview}") return 0.5 # 默认中等价值 @@ -1331,13 +1288,15 @@ class MemorySystem: else: obj_part = str(obj).strip() - base = "|".join([ - str(memory.user_id or "unknown"), - memory.memory_type.value, - subject_part, - predicate_part, - obj_part, - ]) + base = "|".join( + [ + str(memory.user_id or "unknown"), + memory.memory_type.value, + subject_part, + predicate_part, + obj_part, + ] + ) return hashlib.sha256(base.encode("utf-8")).hexdigest() @@ -1352,7 +1311,7 @@ class 
MemorySystem: "total_memories": self.total_memories, "last_build_time": self.last_build_time, "last_retrieval_time": self.last_retrieval_time, - "config": asdict(self.config) + "config": asdict(self.config), } def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float: @@ -1369,7 +1328,9 @@ class MemorySystem: keyword_overlap = 0.0 if context_keywords: memory_keywords = set(k.lower() for k in memory.keywords) - keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max(len(context_keywords), 1) + keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max( + len(context_keywords), 1 + ) importance_boost = (memory.metadata.importance.value - 1) / 3 * 0.1 confidence_boost = (memory.metadata.confidence.value - 1) / 3 * 0.05 @@ -1429,7 +1390,7 @@ class MemorySystem: """重建向量存储(如果需要)""" try: # 检查是否有记忆缓存数据 - if not hasattr(self.unified_storage, 'memory_cache') or not self.unified_storage.memory_cache: + if not hasattr(self.unified_storage, "memory_cache") or not self.unified_storage.memory_cache: logger.info("无记忆缓存数据,跳过向量存储重建") return @@ -1443,19 +1404,19 @@ class MemorySystem: memories_to_rebuild.append(memory) elif memory.text_content and memory.text_content.strip(): memories_to_rebuild.append(memory) - + if not memories_to_rebuild: logger.warning("没有找到可重建向量的记忆") return - + logger.info(f"准备为 {len(memories_to_rebuild)} 条记忆重建向量") - + # 批量重建向量 batch_size = 10 rebuild_count = 0 - + for i in range(0, len(memories_to_rebuild), batch_size): - batch = memories_to_rebuild[i:i + batch_size] + batch = memories_to_rebuild[i : i + batch_size] try: await self.unified_storage.store_memories(batch) rebuild_count += len(batch) @@ -1472,7 +1433,7 @@ class MemorySystem: final_count = self.unified_storage.storage_stats.get("total_vectors", 0) logger.info(f"✅ 向量存储重建完成,最终向量数量: {final_count}") - + except Exception as e: logger.error(f"❌ 向量存储重建失败: {e}", exc_info=True) @@ -1495,4 +1456,4 @@ async def initialize_memory_system(llm_model: Optional[LLMRequest] = None): if memory_system is None: memory_system = MemorySystem(llm_model=llm_model) await memory_system.initialize() - return memory_system \ No newline at end of file + return memory_system diff --git a/src/chat/memory_system/vector_memory_storage_v2.py b/src/chat/memory_system/vector_memory_storage_v2.py index 6c590d888..3c924ba30 100644 --- a/src/chat/memory_system/vector_memory_storage_v2.py +++ b/src/chat/memory_system/vector_memory_storage_v2.py @@ -15,12 +15,10 @@ import time import orjson import asyncio import threading -from typing import Dict, List, Optional, Tuple, Set, Any, Union -from dataclasses import dataclass, asdict +from typing import Dict, List, Optional, Tuple, Any +from dataclasses import dataclass from datetime import datetime -from pathlib import Path -import numpy as np from src.common.logger import get_logger from src.common.vector_db import vector_db_service from src.chat.utils.utils import get_embedding @@ -33,6 +31,7 @@ logger = get_logger(__name__) # 全局枚举映射表缓存 _ENUM_MAPPINGS_CACHE = {} + def _build_enum_mapping(enum_class: type) -> Dict[str, Any]: """构建枚举类的完整映射表 @@ -49,10 +48,10 @@ def _build_enum_mapping(enum_class: type) -> Dict[str, Any]: return _ENUM_MAPPINGS_CACHE[cache_key] mapping = { - "name_to_enum": {}, # 枚举名称 -> 枚举实例 (HIGH -> ImportanceLevel.HIGH) - "value_to_enum": {}, # 整数值 -> 枚举实例 (3 -> ImportanceLevel.HIGH) - "value_str_to_enum": {}, # 字符串value -> 枚举实例 ("3" -> ImportanceLevel.HIGH) - "enum_value_to_name": {}, # 枚举实例 -> 
名称映射 (反向) + "name_to_enum": {}, # 枚举名称 -> 枚举实例 (HIGH -> ImportanceLevel.HIGH) + "value_to_enum": {}, # 整数值 -> 枚举实例 (3 -> ImportanceLevel.HIGH) + "value_str_to_enum": {}, # 字符串value -> 枚举实例 ("3" -> ImportanceLevel.HIGH) + "enum_value_to_name": {}, # 枚举实例 -> 名称映射 (反向) "all_possible_strings": set(), # 所有可能的字符串表示 } @@ -77,7 +76,9 @@ def _build_enum_mapping(enum_class: type) -> Dict[str, Any]: # 缓存结果 _ENUM_MAPPINGS_CACHE[cache_key] = mapping - logger.debug(f"构建枚举映射表: {enum_class.__name__} -> {len(mapping['name_to_enum'])} 个名称映射, {len(mapping['value_to_enum'])} 个值映射") + logger.debug( + f"构建枚举映射表: {enum_class.__name__} -> {len(mapping['name_to_enum'])} 个名称映射, {len(mapping['value_to_enum'])} 个值映射" + ) return mapping @@ -85,42 +86,43 @@ def _build_enum_mapping(enum_class: type) -> Dict[str, Any]: @dataclass class VectorStorageConfig: """Vector存储配置""" + # 集合配置 memory_collection: str = "unified_memory_v2" metadata_collection: str = "memory_metadata_v2" - + # 检索配置 similarity_threshold: float = 0.5 # 降低阈值以提高召回率(0.5-0.6 是合理范围) search_limit: int = 20 batch_size: int = 100 - + # 性能配置 enable_caching: bool = True cache_size_limit: int = 1000 auto_cleanup_interval: int = 3600 # 1小时 - + # 遗忘配置 enable_forgetting: bool = True retention_hours: int = 24 * 30 # 30天 - + @classmethod def from_global_config(cls): """从全局配置创建实例""" from src.config.config import global_config - + memory_cfg = global_config.memory - + return cls( - memory_collection=getattr(memory_cfg, 'vector_db_memory_collection', 'unified_memory_v2'), - metadata_collection=getattr(memory_cfg, 'vector_db_metadata_collection', 'memory_metadata_v2'), - similarity_threshold=getattr(memory_cfg, 'vector_db_similarity_threshold', 0.5), - search_limit=getattr(memory_cfg, 'vector_db_search_limit', 20), - batch_size=getattr(memory_cfg, 'vector_db_batch_size', 100), - enable_caching=getattr(memory_cfg, 'vector_db_enable_caching', True), - cache_size_limit=getattr(memory_cfg, 'vector_db_cache_size_limit', 1000), - auto_cleanup_interval=getattr(memory_cfg, 'vector_db_auto_cleanup_interval', 3600), - enable_forgetting=getattr(memory_cfg, 'enable_memory_forgetting', True), - retention_hours=getattr(memory_cfg, 'vector_db_retention_hours', 720), + memory_collection=getattr(memory_cfg, "vector_db_memory_collection", "unified_memory_v2"), + metadata_collection=getattr(memory_cfg, "vector_db_metadata_collection", "memory_metadata_v2"), + similarity_threshold=getattr(memory_cfg, "vector_db_similarity_threshold", 0.5), + search_limit=getattr(memory_cfg, "vector_db_search_limit", 20), + batch_size=getattr(memory_cfg, "vector_db_batch_size", 100), + enable_caching=getattr(memory_cfg, "vector_db_enable_caching", True), + cache_size_limit=getattr(memory_cfg, "vector_db_cache_size_limit", 1000), + auto_cleanup_interval=getattr(memory_cfg, "vector_db_auto_cleanup_interval", 3600), + enable_forgetting=getattr(memory_cfg, "enable_memory_forgetting", True), + retention_hours=getattr(memory_cfg, "vector_db_retention_hours", 720), ) @@ -133,15 +135,16 @@ class VectorMemoryStorage: """ index = {} for memory in self.memory_cache.values(): - for kw in getattr(memory, 'keywords', []): + for kw in getattr(memory, "keywords", []): if not kw: continue kw_norm = kw.strip().lower() if kw_norm: - index.setdefault(kw_norm, []).append(getattr(memory.metadata, 'memory_id', None)) + index.setdefault(kw_norm, []).append(getattr(memory.metadata, "memory_id", None)) return index + """基于Vector DB的记忆存储系统""" - + def __init__(self, config: Optional[VectorStorageConfig] = None): # 默认从全局配置读取,如果没有传入config 
if config is None: @@ -153,25 +156,25 @@ class VectorMemoryStorage: self.config = VectorStorageConfig() else: self.config = config - + # 从配置中获取批处理大小和集合名称 self.batch_size = self.config.batch_size self.collection_name = self.config.memory_collection self.vector_db_service = vector_db_service - + # 内存缓存 self.memory_cache: Dict[str, MemoryChunk] = {} self.cache_timestamps: Dict[str, float] = {} self._cache = self.memory_cache # 别名,兼容旧代码 - + # 元数据索引管理器(JSON文件索引) self.metadata_index = MemoryMetadataIndex() - + # 遗忘引擎 self.forgetting_engine: Optional[MemoryForgettingEngine] = None if self.config.enable_forgetting: self.forgetting_engine = MemoryForgettingEngine() - + # 统计信息 self.stats = { "total_memories": 0, @@ -180,55 +183,48 @@ class VectorMemoryStorage: "total_searches": 0, "total_stores": 0, "last_cleanup_time": 0.0, - "forgetting_stats": {} + "forgetting_stats": {}, } - + # 线程锁 self._lock = threading.RLock() - + # 定时清理任务 self._cleanup_task = None self._stop_cleanup = False - + # 初始化系统 self._initialize_storage() self._start_cleanup_task() - + def _initialize_storage(self): """初始化Vector DB存储""" try: # 创建记忆集合 vector_db_service.get_or_create_collection( name=self.config.memory_collection, - metadata={ - "description": "统一记忆存储V2", - "hnsw:space": "cosine", - "version": "2.0" - } + metadata={"description": "统一记忆存储V2", "hnsw:space": "cosine", "version": "2.0"}, ) - + # 创建元数据集合(用于复杂查询) vector_db_service.get_or_create_collection( name=self.config.metadata_collection, - metadata={ - "description": "记忆元数据索引", - "hnsw:space": "cosine", - "version": "2.0" - } + metadata={"description": "记忆元数据索引", "hnsw:space": "cosine", "version": "2.0"}, ) - + # 获取当前记忆总数 self.stats["total_memories"] = vector_db_service.count(self.config.memory_collection) - + logger.info(f"Vector记忆存储初始化完成,当前记忆数: {self.stats['total_memories']}") - + except Exception as e: logger.error(f"Vector存储系统初始化失败: {e}", exc_info=True) raise - + def _start_cleanup_task(self): """启动定时清理任务""" if self.config.auto_cleanup_interval > 0: + def cleanup_worker(): while not self._stop_cleanup: try: @@ -237,47 +233,50 @@ class VectorMemoryStorage: asyncio.create_task(self._perform_auto_cleanup()) except Exception as e: logger.error(f"定时清理任务出错: {e}") - + self._cleanup_task = threading.Thread(target=cleanup_worker, daemon=True) self._cleanup_task.start() logger.info(f"定时清理任务已启动,间隔: {self.config.auto_cleanup_interval}秒") - + async def _perform_auto_cleanup(self): """执行自动清理""" try: current_time = time.time() - + # 清理过期缓存 if self.config.enable_caching: expired_keys = [ - memory_id for memory_id, timestamp in self.cache_timestamps.items() + memory_id + for memory_id, timestamp in self.cache_timestamps.items() if current_time - timestamp > 3600 # 1小时过期 ] - + for key in expired_keys: self.memory_cache.pop(key, None) self.cache_timestamps.pop(key, None) - + if expired_keys: logger.debug(f"清理了 {len(expired_keys)} 个过期缓存项") - + # 执行遗忘检查 if self.forgetting_engine: await self.perform_forgetting_check() - + self.stats["last_cleanup_time"] = current_time - + except Exception as e: logger.error(f"自动清理失败: {e}") - + def _memory_to_vector_format(self, memory: MemoryChunk) -> Dict[str, Any]: """将MemoryChunk转换为向量存储格式""" try: # 获取memory_id - memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None) - + memory_id = getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", None) + # 生成向量表示的文本 - display_text = getattr(memory, 'display', None) or getattr(memory, 'text_content', None) or str(memory.content) + display_text = ( + 
getattr(memory, "display", None) or getattr(memory, "text_content", None) or str(memory.content) + ) if not display_text.strip(): logger.warning(f"记忆 {memory_id} 缺少有效的显示文本") display_text = f"{memory.memory_type.value}: {', '.join(memory.subjects)}" @@ -296,16 +295,16 @@ class VectorMemoryStorage: "keywords": orjson.dumps(memory.keywords).decode("utf-8"), # 列表转JSON字符串 "tags": orjson.dumps(memory.tags).decode("utf-8"), # 列表转JSON字符串 "categories": orjson.dumps(memory.categories).decode("utf-8"), # 列表转JSON字符串 - "relevance_score": memory.metadata.relevance_score + "relevance_score": memory.metadata.relevance_score, } # 添加可选字段 if memory.metadata.source_context: metadata["source_context"] = str(memory.metadata.source_context) - + if memory.content.predicate: metadata["predicate"] = memory.content.predicate - + if memory.content.object: if isinstance(memory.content.object, (dict, list)): metadata["object"] = orjson.dumps(memory.content.object).decode() @@ -316,14 +315,14 @@ class VectorMemoryStorage: "id": memory_id, "embedding": None, # 将由vector_db_service生成 "metadata": metadata, - "document": display_text + "document": display_text, } except Exception as e: - memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown') + memory_id = getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", "unknown") logger.error(f"转换记忆 {memory_id} 到向量格式失败: {e}", exc_info=True) raise - + def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]: """将Vector DB结果转换为MemoryChunk""" try: @@ -331,10 +330,10 @@ class VectorMemoryStorage: if "memory_data" in metadata: memory_dict = orjson.loads(metadata["memory_data"]) return MemoryChunk.from_dict(memory_dict) - + # 兜底:从基础字段重建(使用新的结构化格式) logger.warning(f"未找到memory_data,使用兜底逻辑重建记忆 (id={metadata.get('memory_id', 'unknown')})") - + # 构建符合MemoryChunk.from_dict期望的结构 memory_dict = { "metadata": { @@ -345,24 +344,30 @@ class VectorMemoryStorage: "last_modified": metadata.get("timestamp", time.time()), "access_count": metadata.get("access_count", 0), "relevance_score": 0.0, - "confidence": self._parse_enum_value(metadata.get("confidence", 2), ConfidenceLevel, ConfidenceLevel.MEDIUM), - "importance": self._parse_enum_value(metadata.get("importance", 2), ImportanceLevel, ImportanceLevel.NORMAL), + "confidence": self._parse_enum_value( + metadata.get("confidence", 2), ConfidenceLevel, ConfidenceLevel.MEDIUM + ), + "importance": self._parse_enum_value( + metadata.get("importance", 2), ImportanceLevel, ImportanceLevel.NORMAL + ), "source_context": None, }, "content": { "subject": "", "predicate": "", "object": "", - "display": document # 使用document作为显示文本 + "display": document, # 使用document作为显示文本 }, "memory_type": metadata.get("memory_type", "contextual"), - "keywords": orjson.loads(metadata.get("keywords", "[]")) if isinstance(metadata.get("keywords"), str) else metadata.get("keywords", []), + "keywords": orjson.loads(metadata.get("keywords", "[]")) + if isinstance(metadata.get("keywords"), str) + else metadata.get("keywords", []), "tags": [], "categories": [], "embedding": None, "semantic_hash": None, "related_memories": [], - "temporal_context": None + "temporal_context": None, } return MemoryChunk.from_dict(memory_dict) @@ -434,40 +439,39 @@ class VectorMemoryStorage: # 其他类型,返回默认值 logger.warning(f"不支持的{enum_class.__name__}值类型: {type(value)},使用默认值") return default - + def _get_from_cache(self, memory_id: str) -> Optional[MemoryChunk]: """从缓存获取记忆""" if not self.config.enable_caching: 
return None - + with self._lock: if memory_id in self.memory_cache: self.cache_timestamps[memory_id] = time.time() self.stats["cache_hits"] += 1 return self.memory_cache[memory_id] - + self.stats["cache_misses"] += 1 return None - + def _add_to_cache(self, memory: MemoryChunk): """添加记忆到缓存""" if not self.config.enable_caching: return - + with self._lock: # 检查缓存大小限制 if len(self.memory_cache) >= self.config.cache_size_limit: # 移除最老的缓存项 - oldest_id = min(self.cache_timestamps.keys(), - key=lambda k: self.cache_timestamps[k]) + oldest_id = min(self.cache_timestamps.keys(), key=lambda k: self.cache_timestamps[k]) self.memory_cache.pop(oldest_id, None) self.cache_timestamps.pop(oldest_id, None) - - memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None) + + memory_id = getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", None) if memory_id: self.memory_cache[memory_id] = memory self.cache_timestamps[memory_id] = time.time() - + async def store_memories(self, memories: List[MemoryChunk]) -> int: """批量存储记忆""" if not memories: @@ -475,7 +479,7 @@ class VectorMemoryStorage: start_time = datetime.now() success_count = 0 - + try: # 转换为向量格式 vector_data_list = [] @@ -484,7 +488,7 @@ class VectorMemoryStorage: vector_data = self._memory_to_vector_format(memory) vector_data_list.append(vector_data) except Exception as e: - memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown') + memory_id = getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", "unknown") logger.error(f"处理记忆 {memory_id} 失败: {e}") continue @@ -494,8 +498,8 @@ class VectorMemoryStorage: # 批量存储到向量数据库 for i in range(0, len(vector_data_list), self.batch_size): - batch = vector_data_list[i:i + self.batch_size] - + batch = vector_data_list[i : i + self.batch_size] + try: # 生成embeddings embeddings = [] @@ -507,29 +511,37 @@ class VectorMemoryStorage: logger.error(f"生成embedding失败: {e}") # 使用零向量作为后备 embeddings.append([0.0] * 768) # 默认维度 - + # vector_db_service.add 需要embeddings参数 self.vector_db_service.add( collection_name=self.collection_name, embeddings=embeddings, ids=[item["id"] for item in batch], documents=[item["document"] for item in batch], - metadatas=[item["metadata"] for item in batch] + metadatas=[item["metadata"] for item in batch], ) success = True - + if success: # 更新缓存和元数据索引 metadata_entries = [] for item in batch: memory_id = item["id"] # 从原始 memories 列表中找到对应的 MemoryChunk - memory = next((m for m in memories if (getattr(m.metadata, 'memory_id', None) or getattr(m, 'memory_id', None)) == memory_id), None) + memory = next( + ( + m + for m in memories + if (getattr(m.metadata, "memory_id", None) or getattr(m, "memory_id", None)) + == memory_id + ), + None, + ) if memory: # 更新缓存 self._cache[memory_id] = memory success_count += 1 - + # 创建元数据索引条目 try: index_entry = MemoryMetadataIndexEntry( @@ -545,12 +557,12 @@ class VectorMemoryStorage: created_at=memory.metadata.created_at, access_count=memory.metadata.access_count, chat_id=memory.metadata.chat_id, - content_preview=str(memory.content)[:100] if memory.content else None + content_preview=str(memory.content)[:100] if memory.content else None, ) metadata_entries.append(index_entry) except Exception as e: logger.warning(f"创建元数据索引条目失败 (memory_id={memory_id}): {e}") - + # 批量更新元数据索引 if metadata_entries: try: @@ -560,14 +572,14 @@ class VectorMemoryStorage: logger.error(f"批量更新元数据索引失败: {e}") else: logger.warning(f"批次存储失败,跳过 {len(batch)} 条记忆") - + except Exception as e: 
logger.error(f"批量存储失败: {e}", exc_info=True) continue duration = (datetime.now() - start_time).total_seconds() logger.info(f"成功存储 {success_count}/{len(memories)} 条记忆,耗时 {duration:.2f}秒") - + # 保存元数据索引到磁盘 if success_count > 0: try: @@ -575,18 +587,18 @@ class VectorMemoryStorage: logger.debug("元数据索引已保存到磁盘") except Exception as e: logger.error(f"保存元数据索引失败: {e}") - + return success_count except Exception as e: logger.error(f"批量存储记忆失败: {e}", exc_info=True) return success_count - + async def store_memory(self, memory: MemoryChunk) -> bool: """存储单条记忆""" result = await self.store_memories([memory]) return result > 0 - + async def search_similar_memories( self, query_text: str, @@ -594,11 +606,11 @@ class VectorMemoryStorage: similarity_threshold: Optional[float] = None, filters: Optional[Dict[str, Any]] = None, # 新增:元数据过滤参数(用于JSON索引粗筛) - metadata_filters: Optional[Dict[str, Any]] = None + metadata_filters: Optional[Dict[str, Any]] = None, ) -> List[Tuple[MemoryChunk, float]]: """ 搜索相似记忆(混合索引模式) - + Args: query_text: 查询文本 limit: 返回数量限制 @@ -617,47 +629,49 @@ class VectorMemoryStorage: """ if not query_text.strip(): return [] - + try: # === 阶段一:JSON元数据粗筛(可选) === candidate_ids: Optional[List[str]] = None if metadata_filters: logger.debug(f"[JSON元数据粗筛] 开始,过滤条件: {metadata_filters}") candidate_ids = self.metadata_index.search( - memory_types=metadata_filters.get('memory_types'), - subjects=metadata_filters.get('subjects'), - keywords=metadata_filters.get('keywords'), - tags=metadata_filters.get('tags'), - importance_min=metadata_filters.get('importance_min'), - importance_max=metadata_filters.get('importance_max'), - created_after=metadata_filters.get('created_after'), - created_before=metadata_filters.get('created_before'), - user_id=metadata_filters.get('user_id'), + memory_types=metadata_filters.get("memory_types"), + subjects=metadata_filters.get("subjects"), + keywords=metadata_filters.get("keywords"), + tags=metadata_filters.get("tags"), + importance_min=metadata_filters.get("importance_min"), + importance_max=metadata_filters.get("importance_max"), + created_after=metadata_filters.get("created_after"), + created_before=metadata_filters.get("created_before"), + user_id=metadata_filters.get("user_id"), limit=self.config.search_limit * 2, # 粗筛返回更多候选 - flexible_mode=True # 使用灵活匹配模式 + flexible_mode=True, # 使用灵活匹配模式 ) logger.info(f"[JSON元数据粗筛] 完成,筛选出 {len(candidate_ids)} 个候选ID") # 如果粗筛后没有结果,回退到全部记忆搜索 if not candidate_ids: total_memories = len(self.metadata_index.index) - logger.warning(f"JSON元数据粗筛后无候选,启用回退机制:在全部 {total_memories} 条记忆中进行向量搜索") + logger.warning( + f"JSON元数据粗筛后无候选,启用回退机制:在全部 {total_memories} 条记忆中进行向量搜索" + ) logger.info("💡 提示:这可能是因为查询条件过于严格,或相关记忆的元数据与查询条件不完全匹配") candidate_ids = None # 设为None表示不限制候选ID else: - logger.debug(f"[JSON元数据粗筛] 成功筛选出候选,进入向量精筛阶段") - + logger.debug("[JSON元数据粗筛] 成功筛选出候选,进入向量精筛阶段") + # === 阶段二:向量精筛 === # 生成查询向量 query_embedding = await get_embedding(query_text) if not query_embedding: return [] - + threshold = similarity_threshold or self.config.similarity_threshold - + # 构建where条件 where_conditions = filters or {} - + # 如果有候选ID列表,添加到where条件 if candidate_ids: # ChromaDB的where条件需要使用$in操作符 @@ -665,117 +679,114 @@ class VectorMemoryStorage: logger.debug(f"[向量精筛] 限制在 {len(candidate_ids)} 个候选ID内搜索") else: logger.info("[向量精筛] 在全部记忆中搜索(元数据筛选无结果回退)") - + # 查询Vector DB logger.debug(f"[向量精筛] 开始,limit={min(limit, self.config.search_limit)}") results = vector_db_service.query( collection_name=self.config.memory_collection, query_embeddings=[query_embedding], n_results=min(limit, 
self.config.search_limit), - where=where_conditions if where_conditions else None + where=where_conditions if where_conditions else None, ) - + # 处理结果 similar_memories = [] - + if results.get("documents") and results["documents"][0]: documents = results["documents"][0] distances = results.get("distances", [[]])[0] metadatas = results.get("metadatas", [[]])[0] ids = results.get("ids", [[]])[0] - - logger.info(f"向量检索返回原始结果:documents={len(documents)}, ids={len(ids)}, metadatas={len(metadatas)}") - for i, (doc, metadata, memory_id) in enumerate(zip(documents, metadatas, ids)): + + logger.info( + f"向量检索返回原始结果:documents={len(documents)}, ids={len(ids)}, metadatas={len(metadatas)}" + ) + for i, (doc, metadata, memory_id) in enumerate(zip(documents, metadatas, ids, strict=False)): # 计算相似度 distance = distances[i] if i < len(distances) else 1.0 similarity = 1 - distance # ChromaDB返回距离,转换为相似度 - + if similarity < threshold: continue - + # 首先尝试从缓存获取 memory = self._get_from_cache(memory_id) - + if not memory: # 从Vector结果重建 memory = self._vector_result_to_memory(doc, metadata) if memory: self._add_to_cache(memory) - + if memory: similar_memories.append((memory, similarity)) # 记录单条结果的关键日志(id,相似度,简短文本) try: - short_text = (str(memory.content)[:120]) if hasattr(memory, 'content') else (doc[:120] if isinstance(doc, str) else '') + short_text = ( + (str(memory.content)[:120]) + if hasattr(memory, "content") + else (doc[:120] if isinstance(doc, str) else "") + ) except Exception: - short_text = '' + short_text = "" logger.info(f"检索结果 - id={memory_id}, similarity={similarity:.4f}, summary={short_text}") - + # 按相似度排序 similar_memories.sort(key=lambda x: x[1], reverse=True) - + self.stats["total_searches"] += 1 - logger.info(f"搜索相似记忆: query='{query_text[:60]}...', limit={limit}, threshold={threshold}, filters={where_conditions}, 返回数={len(similar_memories)}") + logger.info( + f"搜索相似记忆: query='{query_text[:60]}...', limit={limit}, threshold={threshold}, filters={where_conditions}, 返回数={len(similar_memories)}" + ) logger.debug(f"搜索相似记忆 详细结果数={len(similar_memories)}") - + return similar_memories - + except Exception as e: logger.error(f"搜索相似记忆失败: {e}") return [] - + async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]: """根据ID获取记忆""" # 首先尝试从缓存获取 memory = self._get_from_cache(memory_id) if memory: return memory - + try: # 从Vector DB获取 - results = vector_db_service.get( - collection_name=self.config.memory_collection, - ids=[memory_id] - ) - + results = vector_db_service.get(collection_name=self.config.memory_collection, ids=[memory_id]) + if results.get("documents") and results["documents"]: document = results["documents"][0] metadata = results["metadatas"][0] if results.get("metadatas") else {} - + memory = self._vector_result_to_memory(document, metadata) if memory: self._add_to_cache(memory) - + return memory - + except Exception as e: logger.error(f"获取记忆 {memory_id} 失败: {e}") - + return None - - async def get_memories_by_filters( - self, - filters: Dict[str, Any], - limit: int = 100 - ) -> List[MemoryChunk]: + + async def get_memories_by_filters(self, filters: Dict[str, Any], limit: int = 100) -> List[MemoryChunk]: """根据过滤条件获取记忆""" try: - results = vector_db_service.get( - collection_name=self.config.memory_collection, - where=filters, - limit=limit - ) - + results = vector_db_service.get(collection_name=self.config.memory_collection, where=filters, limit=limit) + memories = [] if results.get("documents"): documents = results["documents"] metadatas = results.get("metadatas", [{}] * 
len(documents)) ids = results.get("ids", []) - + logger.info(f"按过滤条件获取返回: docs={len(documents)}, ids={len(ids)}") - for i, (doc, metadata) in enumerate(zip(documents, metadatas)): + for i, (doc, metadata) in enumerate(zip(documents, metadatas, strict=False)): memory_id = ids[i] if i < len(ids) else None - + # 首先尝试从缓存获取 if memory_id: memory = self._get_from_cache(memory_id) @@ -783,7 +794,7 @@ class VectorMemoryStorage: memories.append(memory) logger.debug(f"过滤获取命中缓存: id={memory_id}") continue - + # 从Vector结果重建 memory = self._vector_result_to_memory(doc, metadata) if memory: @@ -791,137 +802,129 @@ class VectorMemoryStorage: if memory_id: self._add_to_cache(memory) logger.debug(f"过滤获取结果: id={memory_id}, meta_keys={list(metadata.keys())}") - + return memories - + except Exception as e: logger.error(f"根据过滤条件获取记忆失败: {e}") return [] - + async def update_memory(self, memory: MemoryChunk) -> bool: """更新记忆""" try: - memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None) + memory_id = getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", None) if not memory_id: logger.error("无法更新记忆:缺少memory_id") return False - + # 先删除旧记忆 await self.delete_memory(memory_id) - + # 重新存储更新后的记忆 return await self.store_memory(memory) - + except Exception as e: - memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown') + memory_id = getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", "unknown") logger.error(f"更新记忆 {memory_id} 失败: {e}") return False - + async def delete_memory(self, memory_id: str) -> bool: """删除记忆""" try: # 从Vector DB删除 - vector_db_service.delete( - collection_name=self.config.memory_collection, - ids=[memory_id] - ) - + vector_db_service.delete(collection_name=self.config.memory_collection, ids=[memory_id]) + # 从缓存删除 with self._lock: self.memory_cache.pop(memory_id, None) self.cache_timestamps.pop(memory_id, None) - + self.stats["total_memories"] = max(0, self.stats["total_memories"] - 1) logger.debug(f"删除记忆: {memory_id}") - + return True - + except Exception as e: logger.error(f"删除记忆 {memory_id} 失败: {e}") return False - + async def delete_memories_by_filters(self, filters: Dict[str, Any]) -> int: """根据过滤条件批量删除记忆""" try: # 先获取要删除的记忆ID results = vector_db_service.get( - collection_name=self.config.memory_collection, - where=filters, - include=["metadatas"] + collection_name=self.config.memory_collection, where=filters, include=["metadatas"] ) - + if not results.get("ids"): return 0 - + memory_ids = results["ids"] - + # 批量删除 - vector_db_service.delete( - collection_name=self.config.memory_collection, - where=filters - ) - + vector_db_service.delete(collection_name=self.config.memory_collection, where=filters) + # 从缓存删除 with self._lock: for memory_id in memory_ids: self.memory_cache.pop(memory_id, None) self.cache_timestamps.pop(memory_id, None) - + deleted_count = len(memory_ids) self.stats["total_memories"] = max(0, self.stats["total_memories"] - deleted_count) logger.info(f"批量删除记忆: {deleted_count} 条") - + return deleted_count - + except Exception as e: logger.error(f"批量删除记忆失败: {e}") return 0 - + async def perform_forgetting_check(self) -> Dict[str, Any]: """执行遗忘检查""" if not self.forgetting_engine: return {"error": "遗忘引擎未启用"} - + try: # 获取所有记忆进行遗忘检查 # 注意:对于大型数据集,这里应该分批处理 current_time = time.time() cutoff_time = current_time - (self.config.retention_hours * 3600) - + # 先删除明显过期的记忆 expired_filters = {"timestamp": {"$lt": cutoff_time}} expired_count = await 
self.delete_memories_by_filters(expired_filters) - + # 对剩余记忆执行智能遗忘检查 # 这里为了性能考虑,只检查一部分记忆 sample_memories = await self.get_memories_by_filters({}, limit=500) - + if sample_memories: result = await self.forgetting_engine.perform_forgetting_check(sample_memories) - + # 遗忘标记的记忆 forgetting_ids = result.get("normal_forgetting", []) + result.get("force_forgetting", []) forgotten_count = 0 - + for memory_id in forgetting_ids: if await self.delete_memory(memory_id): forgotten_count += 1 - + result["forgotten_count"] = forgotten_count result["expired_count"] = expired_count - + # 更新统计 self.stats["forgetting_stats"] = self.forgetting_engine.get_forgetting_stats() - + logger.info(f"遗忘检查完成: 过期删除 {expired_count}, 智能遗忘 {forgotten_count}") return result - + return {"expired_count": expired_count, "forgotten_count": 0} - + except Exception as e: logger.error(f"执行遗忘检查失败: {e}") return {"error": str(e)} - + def get_storage_stats(self) -> Dict[str, Any]: """获取存储统计信息""" try: @@ -929,27 +932,27 @@ class VectorMemoryStorage: self.stats["total_memories"] = current_total except Exception: pass - + return { **self.stats, "cache_size": len(self.memory_cache), "collection_name": self.config.memory_collection, "storage_type": "vector_db_v2", - "uptime": time.time() - self.stats.get("start_time", time.time()) + "uptime": time.time() - self.stats.get("start_time", time.time()), } - + def stop(self): """停止存储系统""" self._stop_cleanup = True - + if self._cleanup_task and self._cleanup_task.is_alive(): logger.info("正在停止定时清理任务...") - + # 清空缓存 with self._lock: self.memory_cache.clear() self.cache_timestamps.clear() - + logger.info("Vector记忆存储系统已停止") @@ -960,36 +963,33 @@ _global_vector_storage = None def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> VectorMemoryStorage: """获取全局Vector记忆存储实例""" global _global_vector_storage - + if _global_vector_storage is None: _global_vector_storage = VectorMemoryStorage(config) - + return _global_vector_storage # 兼容性接口 class VectorMemoryStorageAdapter: """适配器类,提供与原UnifiedMemoryStorage兼容的接口""" - + def __init__(self, config: Optional[VectorStorageConfig] = None): self.storage = VectorMemoryStorage(config) - + async def store_memories(self, memories: List[MemoryChunk]) -> int: return await self.storage.store_memories(memories) - + async def search_similar_memories( - self, - query_text: str, - limit: int = 10, - scope_id: Optional[str] = None, - filters: Optional[Dict[str, Any]] = None + self, query_text: str, limit: int = 10, scope_id: Optional[str] = None, filters: Optional[Dict[str, Any]] = None ) -> List[Tuple[str, float]]: - results = await self.storage.search_similar_memories( - query_text, limit, filters=filters - ) + results = await self.storage.search_similar_memories(query_text, limit, filters=filters) # 转换为原格式:(memory_id, similarity) - return [(getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown'), similarity) for memory, similarity in results] - + return [ + (getattr(memory.metadata, "memory_id", None) or getattr(memory, "memory_id", "unknown"), similarity) + for memory, similarity in results + ] + def get_stats(self) -> Dict[str, Any]: return self.storage.get_storage_stats() @@ -998,33 +998,34 @@ if __name__ == "__main__": # 简单测试 async def test_vector_storage(): storage = VectorMemoryStorage() - + # 创建测试记忆 from src.chat.memory_system.memory_chunk import MemoryType + test_memory = MemoryChunk( memory_id="test_001", user_id="test_user", text_content="今天天气很好,适合出门散步", memory_type=MemoryType.FACT, keywords=["天气", "散步"], - 
importance=0.7 + importance=0.7, ) - + # 存储记忆 success = await storage.store_memory(test_memory) print(f"存储结果: {success}") - + # 搜索记忆 results = await storage.search_similar_memories("天气怎么样", limit=5) print(f"搜索结果: {len(results)} 条") - + for memory, similarity in results: print(f" - {memory.text_content[:50]}... (相似度: {similarity:.3f})") - + # 获取统计信息 stats = storage.get_storage_stats() print(f"存储统计: {stats}") - + storage.stop() - - asyncio.run(test_vector_storage()) \ No newline at end of file + + asyncio.run(test_vector_storage()) diff --git a/src/chat/message_manager/__init__.py b/src/chat/message_manager/__init__.py index 3e27858a0..fe5e90785 100644 --- a/src/chat/message_manager/__init__.py +++ b/src/chat/message_manager/__init__.py @@ -5,15 +5,12 @@ from .message_manager import MessageManager, message_manager from .context_manager import SingleStreamContextManager -from .distribution_manager import ( - StreamLoopManager, - stream_loop_manager -) +from .distribution_manager import StreamLoopManager, stream_loop_manager __all__ = [ "MessageManager", "message_manager", "SingleStreamContextManager", "StreamLoopManager", - "stream_loop_manager" -] \ No newline at end of file + "stream_loop_manager", +] diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index 0f744b129..5f3212065 100644 --- a/src/chat/message_manager/context_manager.py +++ b/src/chat/message_manager/context_manager.py @@ -230,12 +230,14 @@ class SingleStreamContextManager: 异步计算消息的兴趣度。 此方法通过检查当前是否存在正在运行的 asyncio 事件循环来兼容同步和异步调用。 """ + # 内部异步函数,封装实际的计算逻辑 async def _get_score(): try: from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ( chatter_interest_scoring_system, ) + interest_score = await chatter_interest_scoring_system._calculate_single_message_score( message=message, bot_nickname=global_config.bot.nickname ) diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index f16a22b70..69f3e662d 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -34,17 +34,13 @@ class StreamLoopManager: } # 配置参数 - self.max_concurrent_streams = ( - max_concurrent_streams or global_config.chat.max_concurrent_distributions - ) + self.max_concurrent_streams = max_concurrent_streams or global_config.chat.max_concurrent_distributions # 强制分发策略 self.force_dispatch_unread_threshold: Optional[int] = getattr( global_config.chat, "force_dispatch_unread_threshold", 20 ) - self.force_dispatch_min_interval: float = getattr( - global_config.chat, "force_dispatch_min_interval", 0.1 - ) + self.force_dispatch_min_interval: float = getattr(global_config.chat, "force_dispatch_min_interval", 0.1) # Chatter管理器 self.chatter_manager: Optional[ChatterManager] = None @@ -108,7 +104,9 @@ class StreamLoopManager: if force and len(self.stream_loops) >= self.max_concurrent_streams: logger.warning( - "流 %s 未读消息积压严重(>%s),突破并发限制强制启动分发", stream_id, self.force_dispatch_unread_threshold + "流 %s 未读消息积压严重(>%s),突破并发限制强制启动分发", + stream_id, + self.force_dispatch_unread_threshold, ) # 创建流循环任务 @@ -168,9 +166,7 @@ class StreamLoopManager: if has_messages: if force_dispatch: - logger.info( - "流 %s 未读消息 %d 条,触发强制分发", stream_id, unread_count - ) + logger.info("流 %s 未读消息 %d 条,触发强制分发", stream_id, unread_count) # 3. 
激活chatter处理 success = await self._process_stream_messages(stream_id, context) diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 5b715715b..bd55bd43f 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -11,7 +11,7 @@ from typing import Dict, Optional, Any, TYPE_CHECKING, List from src.chat.message_receive.chat_stream import ChatStream from src.common.logger import get_logger from src.common.data_models.database_data_model import DatabaseMessages -from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats +from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats from src.chat.chatter_manager import ChatterManager from src.chat.planner_actions.action_manager import ChatterActionManager from .sleep_manager.sleep_manager import SleepManager @@ -21,7 +21,7 @@ from src.plugin_system.apis.chat_api import get_chat_manager from .distribution_manager import stream_loop_manager if TYPE_CHECKING: - from src.common.data_models.message_manager_data_model import StreamContext + pass logger = get_logger("message_manager") @@ -63,7 +63,7 @@ class MessageManager: stream_loop_manager.set_chatter_manager(self.chatter_manager) logger.info("🚀 消息管理器已启动 | 流循环管理器已启动") - + async def stop(self): """停止消息管理器""" if not self.is_running: @@ -88,7 +88,9 @@ class MessageManager: logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在") return await self._check_and_handle_interruption(chat_stream) - chat_stream.context_manager.context.processing_task = asyncio.create_task(chat_stream.context_manager.add_message(message)) + chat_stream.context_manager.context.processing_task = asyncio.create_task( + chat_stream.context_manager.add_message(message) + ) except Exception as e: logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}") @@ -141,11 +143,7 @@ class MessageManager: if not message_id: continue - payload = { - key: value - for key, value in item.items() - if key != "message_id" and value is not None - } + payload = {key: value for key, value in item.items() if key != "message_id" and value is not None} if not payload: continue @@ -169,9 +167,7 @@ class MessageManager: if not chat_stream: logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在") return - success = await chat_stream.context_manager.update_message( - message_id, {"actions": [action]} - ) + success = await chat_stream.context_manager.update_message(message_id, {"actions": [action]}) if success: logger.debug(f"为消息 {message_id} 添加动作 {action} 成功") else: @@ -193,7 +189,7 @@ class MessageManager: context.is_active = False # 取消处理任务 - if hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done(): + if hasattr(context, "processing_task") and context.processing_task and not context.processing_task.done(): context.processing_task.cancel() logger.info(f"停用聊天流: {stream_id}") @@ -236,7 +232,11 @@ class MessageManager: unread_count=unread_count, history_count=len(context.history_messages), last_check_time=context.last_check_time, - has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()), + has_active_task=bool( + hasattr(context, "processing_task") + and context.processing_task + and not context.processing_task.done() + ), ) except Exception as e: @@ -284,7 +284,10 @@ class MessageManager: return # 检查是否有正在进行的处理任务 - if 
chat_stream.context_manager.context.processing_task and not chat_stream.context_manager.context.processing_task.done(): + if ( + chat_stream.context_manager.context.processing_task + and not chat_stream.context_manager.context.processing_task.done() + ): # 计算打断概率 interruption_probability = chat_stream.context_manager.context.calculate_interruption_probability( global_config.chat.interruption_max_limit, global_config.chat.interruption_probability_factor @@ -310,7 +313,9 @@ class MessageManager: # 增加打断计数并应用afc阈值降低 chat_stream.context_manager.context.increment_interruption_count() - chat_stream.context_manager.context.apply_interruption_afc_reduction(global_config.chat.interruption_afc_reduction) + chat_stream.context_manager.context.apply_interruption_afc_reduction( + global_config.chat.interruption_afc_reduction + ) # 检查是否已达到最大次数 if chat_stream.context_manager.context.interruption_count >= global_config.chat.interruption_max_limit: @@ -364,7 +369,7 @@ class MessageManager: return context = chat_stream.context_manager.context - if hasattr(context, 'unread_messages') and context.unread_messages: + if hasattr(context, "unread_messages") and context.unread_messages: logger.debug(f"正在为流 {stream_id} 清除 {len(context.unread_messages)} 条未读消息") context.unread_messages.clear() else: diff --git a/src/chat/message_manager/sleep_manager/notification_sender.py b/src/chat/message_manager/sleep_manager/notification_sender.py index 07e8b09d4..d40cd612d 100644 --- a/src/chat/message_manager/sleep_manager/notification_sender.py +++ b/src/chat/message_manager/sleep_manager/notification_sender.py @@ -1,33 +1,33 @@ from src.common.logger import get_logger -#from ..hfc_context import HfcContext +# from ..hfc_context import HfcContext logger = get_logger("notification_sender") class NotificationSender: @staticmethod - async def send_goodnight_notification(context): # type: ignore + async def send_goodnight_notification(context): # type: ignore """发送晚安通知""" - #try: - #from ..proactive.events import ProactiveTriggerEvent - #from ..proactive.proactive_thinker import ProactiveThinker - - #event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight") - #proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor) - #await proactive_thinker.think(event) - #except Exception as e: - #logger.error(f"发送晚安通知失败: {e}") + # try: + # from ..proactive.events import ProactiveTriggerEvent + # from ..proactive.proactive_thinker import ProactiveThinker + + # event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight") + # proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor) + # await proactive_thinker.think(event) + # except Exception as e: + # logger.error(f"发送晚安通知失败: {e}") @staticmethod - async def send_insomnia_notification(context, reason: str): # type: ignore + async def send_insomnia_notification(context, reason: str): # type: ignore """发送失眠通知""" - #try: - #from ..proactive.events import ProactiveTriggerEvent - #from ..proactive.proactive_thinker import ProactiveThinker + # try: + # from ..proactive.events import ProactiveTriggerEvent + # from ..proactive.proactive_thinker import ProactiveThinker - #event = ProactiveTriggerEvent(source="sleep_manager", reason=reason) - #proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor) - #await proactive_thinker.think(event) - #except Exception as e: - #logger.error(f"发送失眠通知失败: {e}") \ No newline at end of file + # event = ProactiveTriggerEvent(source="sleep_manager", 
reason=reason) + # proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor) + # await proactive_thinker.think(event) + # except Exception as e: + # logger.error(f"发送失眠通知失败: {e}") diff --git a/src/chat/message_manager/sleep_manager/sleep_manager.py b/src/chat/message_manager/sleep_manager/sleep_manager.py index 0ed21e685..b0cf79b1b 100644 --- a/src/chat/message_manager/sleep_manager/sleep_manager.py +++ b/src/chat/message_manager/sleep_manager/sleep_manager.py @@ -1,6 +1,6 @@ import asyncio import random -from datetime import datetime, timedelta, date +from datetime import datetime, timedelta from typing import Optional, TYPE_CHECKING from src.common.logger import get_logger @@ -21,6 +21,7 @@ class SleepManager: 它实现了一个状态机,根据预设的时间表、睡眠压力和随机因素, 在不同的睡眠状态(如清醒、准备入睡、睡眠、失眠)之间进行切换。 """ + def __init__(self): """ 初始化睡眠管理器。 @@ -97,7 +98,7 @@ class SleepManager: logger.info(f"进入理论休眠时间 '{activity}',开始进行睡眠决策...") else: logger.info("进入理论休眠时间,开始进行睡眠决策...") - + if global_config.sleep_system.enable_flexible_sleep: # --- 新的弹性睡眠逻辑 --- if wakeup_manager: @@ -112,7 +113,7 @@ class SleepManager: pressure_diff = (pressure_threshold - sleep_pressure) / pressure_threshold # 延迟分钟数,压力越低,延迟越长 delay_minutes = int(pressure_diff * max_delay_minutes) - + # 确保总延迟不超过当日最大值 remaining_delay = max_delay_minutes - self.context.total_delayed_minutes_today delay_minutes = min(delay_minutes, remaining_delay) @@ -151,9 +152,10 @@ class SleepManager: if wakeup_manager and global_config.sleep_system.enable_pre_sleep_notification: asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context)) self.context.current_state = SleepState.SLEEPING - - def _handle_preparing_sleep(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]): + def _handle_preparing_sleep( + self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"] + ): """处理“准备入睡”状态下的逻辑。""" # 如果在准备期间离开了理论睡眠时间,则取消入睡 if not is_in_theoretical_sleep: @@ -166,16 +168,22 @@ class SleepManager: logger.info("睡眠缓冲期结束,正式进入休眠状态。") self.context.current_state = SleepState.SLEEPING self._last_fully_slept_log_time = now.timestamp() - + # 设置一个随机的延迟,用于触发“睡后失眠”检查 delay_minutes_range = global_config.sleep_system.insomnia_trigger_delay_minutes delay_minutes = random.randint(delay_minutes_range[0], delay_minutes_range[1]) self.context.sleep_buffer_end_time = now + timedelta(minutes=delay_minutes) logger.info(f"已设置睡后失眠检查,将在 {delay_minutes} 分钟后触发。") - + self.context.save() - def _handle_sleeping(self, now: datetime, is_in_theoretical_sleep: bool, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]): + def _handle_sleeping( + self, + now: datetime, + is_in_theoretical_sleep: bool, + activity: Optional[str], + wakeup_manager: Optional["WakeUpManager"], + ): """处理“正在睡觉”状态下的逻辑。""" # 如果理论睡眠时间结束,则自然醒来 if not is_in_theoretical_sleep: @@ -198,14 +206,16 @@ class SleepManager: if insomnia_reason: self.context.current_state = SleepState.INSOMNIA - + # 设置失眠的持续时间 duration_minutes_range = global_config.sleep_system.insomnia_duration_minutes duration_minutes = random.randint(*duration_minutes_range) self.context.sleep_buffer_end_time = now + timedelta(minutes=duration_minutes) - + # 发送失眠通知 - asyncio.create_task(NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason)) + asyncio.create_task( + NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason) + ) logger.info(f"进入失眠状态 (原因: {insomnia_reason}),将持续 
{duration_minutes} 分钟。") else: # 睡眠压力正常,不触发失眠,清除检查时间点 diff --git a/src/chat/message_manager/sleep_manager/sleep_state.py b/src/chat/message_manager/sleep_manager/sleep_state.py index d59f1f3d6..105302169 100644 --- a/src/chat/message_manager/sleep_manager/sleep_state.py +++ b/src/chat/message_manager/sleep_manager/sleep_state.py @@ -25,6 +25,7 @@ class SleepContext: """ 睡眠上下文,负责封装和管理所有与睡眠相关的状态,并处理其持久化。 """ + def __init__(self): """初始化睡眠上下文,并从本地存储加载初始状态。""" self.current_state: SleepState = SleepState.AWAKE @@ -83,4 +84,4 @@ class SleepContext: logger.info(f"成功从本地存储加载睡眠上下文: {state}") except Exception as e: - logger.warning(f"加载睡眠上下文失败,将使用默认值: {e}") \ No newline at end of file + logger.warning(f"加载睡眠上下文失败,将使用默认值: {e}") diff --git a/src/chat/message_manager/sleep_manager/time_checker.py b/src/chat/message_manager/sleep_manager/time_checker.py index 47376ac35..773830c3a 100644 --- a/src/chat/message_manager/sleep_manager/time_checker.py +++ b/src/chat/message_manager/sleep_manager/time_checker.py @@ -15,23 +15,25 @@ class TimeChecker: self._daily_sleep_offset: int = 0 self._daily_wake_offset: int = 0 self._offset_date = None - + def _get_daily_offsets(self): """获取当天的睡眠和起床时间偏移量,每天生成一次""" today = datetime.now().date() - + # 如果是新的一天,重新生成偏移量 if self._offset_date != today: sleep_offset_range = global_config.sleep_system.sleep_time_offset_minutes wake_offset_range = global_config.sleep_system.wake_up_time_offset_minutes - + # 生成 ±offset_range 范围内的随机偏移量 self._daily_sleep_offset = random.randint(-sleep_offset_range, sleep_offset_range) self._daily_wake_offset = random.randint(-wake_offset_range, wake_offset_range) self._offset_date = today - - logger.debug(f"生成新的每日偏移量 - 睡觉时间偏移: {self._daily_sleep_offset}分钟, 起床时间偏移: {self._daily_wake_offset}分钟") - + + logger.debug( + f"生成新的每日偏移量 - 睡觉时间偏移: {self._daily_sleep_offset}分钟, 起床时间偏移: {self._daily_wake_offset}分钟" + ) + return self._daily_sleep_offset, self._daily_wake_offset @staticmethod @@ -82,28 +84,36 @@ class TimeChecker: try: start_time_str = global_config.sleep_system.fixed_sleep_time end_time_str = global_config.sleep_system.fixed_wake_up_time - + # 获取当天的偏移量 sleep_offset, wake_offset = self._get_daily_offsets() - + # 解析基础时间 base_start_time = datetime.strptime(start_time_str, "%H:%M") base_end_time = datetime.strptime(end_time_str, "%H:%M") - + # 应用偏移量 actual_start_time = (base_start_time + timedelta(minutes=sleep_offset)).time() actual_end_time = (base_end_time + timedelta(minutes=wake_offset)).time() - - logger.debug(f"固定睡眠时间检查 - 基础时间: {start_time_str}-{end_time_str}, " - f"偏移后时间: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')}, " - f"当前时间: {now_time.strftime('%H:%M')}") + + logger.debug( + f"固定睡眠时间检查 - 基础时间: {start_time_str}-{end_time_str}, " + f"偏移后时间: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')}, " + f"当前时间: {now_time.strftime('%H:%M')}" + ) if actual_start_time <= actual_end_time: if actual_start_time <= now_time < actual_end_time: - return True, f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})" + return ( + True, + f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})", + ) else: if now_time >= actual_start_time or now_time < actual_end_time: - return True, f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})" + return ( + True, + f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})", + ) except ValueError as e: logger.error(f"固定的睡眠时间格式不正确,请使用 HH:MM 
格式: {e}") - return False, None \ No newline at end of file + return False, None diff --git a/src/chat/message_manager/sleep_manager/wakeup_context.py b/src/chat/message_manager/sleep_manager/wakeup_context.py index bfa1a62dd..efb59e392 100644 --- a/src/chat/message_manager/sleep_manager/wakeup_context.py +++ b/src/chat/message_manager/sleep_manager/wakeup_context.py @@ -1,4 +1,3 @@ -import time from src.common.logger import get_logger from src.manager.local_store_manager import local_storage @@ -9,6 +8,7 @@ class WakeUpContext: """ 唤醒上下文,负责封装和管理所有与唤醒相关的状态,并处理其持久化。 """ + def __init__(self): """初始化唤醒上下文,并从本地存储加载初始状态。""" self.wakeup_value: float = 0.0 @@ -42,4 +42,4 @@ class WakeUpContext: "sleep_pressure": self.sleep_pressure, } local_storage[self._get_storage_key()] = state - logger.debug(f"已将唤醒上下文保存到本地存储: {state}") \ No newline at end of file + logger.debug(f"已将唤醒上下文保存到本地存储: {state}") diff --git a/src/chat/message_manager/sleep_manager/wakeup_manager.py b/src/chat/message_manager/sleep_manager/wakeup_manager.py index 51ab80bb1..5fc68ff41 100644 --- a/src/chat/message_manager/sleep_manager/wakeup_manager.py +++ b/src/chat/message_manager/sleep_manager/wakeup_manager.py @@ -3,7 +3,6 @@ import time from typing import Optional, TYPE_CHECKING from src.common.logger import get_logger from src.config.config import global_config -from src.manager.local_store_manager import local_storage from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext if TYPE_CHECKING: @@ -51,7 +50,7 @@ class WakeUpManager: if not self.enabled: logger.info("唤醒度系统已禁用,跳过启动") return - + self.is_running = True if not self._decay_task or self._decay_task.done(): self._decay_task = asyncio.create_task(self._decay_loop()) @@ -88,6 +87,7 @@ class WakeUpManager: self.context.is_angry = False # 通知情绪管理系统清除愤怒状态 from src.mood.mood_manager import mood_manager + if self.angry_chat_id: mood_manager.clear_angry_from_wakeup(self.angry_chat_id) self.angry_chat_id = None @@ -104,7 +104,9 @@ class WakeUpManager: logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}") self.context.save() - def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None) -> bool: + def add_wakeup_value( + self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None + ) -> bool: """ 增加唤醒度值 @@ -173,6 +175,7 @@ class WakeUpManager: # 通知情绪管理系统进入愤怒状态 from src.mood.mood_manager import mood_manager + mood_manager.set_angry_from_wakeup(chat_id) # 通知SleepManager重置睡眠状态 @@ -194,6 +197,7 @@ class WakeUpManager: self.context.is_angry = False # 通知情绪管理系统清除愤怒状态 from src.mood.mood_manager import mood_manager + if self.angry_chat_id: mood_manager.clear_angry_from_wakeup(self.angry_chat_id) self.angry_chat_id = None diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index d30600b70..47d1f26e2 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -190,7 +190,7 @@ class ChatBot: try: # 检查聊天类型限制 if not plus_command_instance.is_chat_type_allowed(): - is_group = message.message_info.group_info + is_group = message.message_info.group_info logger.info( f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" ) @@ -420,7 +420,9 @@ class ChatBot: await message.process() # 在这里打印[所见]日志,确保在所有处理和过滤之前记录 - logger.info(f"\u001b[38;5;118m{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m") + logger.info( + 
f"\u001b[38;5;118m{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m" + ) # 过滤检查 if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore @@ -452,7 +454,7 @@ class ChatBot: result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message) if not result.all_continue_process(): raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理") - + # TODO:暂不可用 # 确认从接口发来的message是否有自定义的prompt模板信息 if message.message_info.template_info and not message.message_info.template_info.template_default: @@ -469,14 +471,14 @@ class ChatBot: async def preprocess(): # 存储消息到数据库 from .storage import MessageStorage - + try: await MessageStorage.store_message(message, message.chat_stream) logger.debug(f"消息已存储到数据库: {message.message_info.message_id}") except Exception as e: logger.error(f"存储消息到数据库失败: {e}") traceback.print_exc() - + # 使用消息管理器处理消息(保持原有功能) from src.common.data_models.database_data_model import DatabaseMessages diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index d31daeaee..559490694 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -89,7 +89,7 @@ class ChatStream: # 复制 stream_context,但跳过 processing_task new_stream.stream_context = copy.deepcopy(self.stream_context, memo) - if hasattr(new_stream.stream_context, 'processing_task'): + if hasattr(new_stream.stream_context, "processing_task"): new_stream.stream_context.processing_task = None # 复制 context_manager @@ -377,6 +377,7 @@ class ChatStream: # 默认基础分 return 0.3 + class ChatManager: """聊天管理器,管理所有聊天流""" @@ -563,9 +564,8 @@ class ChatManager: if not hasattr(stream, "context_manager"): # 创建新的单流上下文管理器 from src.chat.message_manager.context_manager import SingleStreamContextManager - stream.context_manager = SingleStreamContextManager( - stream_id=stream_id, context=stream.stream_context - ) + + stream.context_manager = SingleStreamContextManager(stream_id=stream_id, context=stream.stream_context) # 保存到内存和数据库 self.streams[stream_id] = stream @@ -721,6 +721,7 @@ class ChatManager: # 确保 ChatStream 有自己的 context_manager if not hasattr(stream, "context_manager"): from src.chat.message_manager.context_manager import SingleStreamContextManager + stream.context_manager = SingleStreamContextManager( stream_id=stream.stream_id, context=stream.stream_context ) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index f6041b7d4..fee932b62 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -108,7 +108,7 @@ class MessageRecv(Message): self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) self.raw_message = message_dict.get("raw_message") - + self.chat_stream = None self.reply = None self.processed_plain_text = message_dict.get("processed_plain_text", "") diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index a6aaa00eb..5a654e867 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -52,7 +52,11 @@ class MessageStorage: filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) else: # 如果没有设置display_message,使用processed_plain_text作为显示消息 - filtered_display_message = re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL) if 
message.processed_plain_text else "" + filtered_display_message = ( + re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL) + if message.processed_plain_text + else "" + ) interest_value = 0 is_mentioned = False reply_to = message.reply_to @@ -175,9 +179,11 @@ class MessageStorage: from src.common.database.sqlalchemy_models import get_db_session async with get_db_session() as session: - matched_message = (await session.execute( - select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) - )).scalar() + matched_message = ( + await session.execute( + select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) + ) + ).scalar() if matched_message: await session.execute( @@ -211,9 +217,11 @@ class MessageStorage: from src.common.database.sqlalchemy_models import get_db_session async with get_db_session() as session: - image_record = (await session.execute( - select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) - )).scalar() + image_record = ( + await session.execute( + select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) + ) + ).scalar() return f"[picid:{image_record.image_id}]" if image_record else match.group(0) except Exception: return match.group(0) @@ -294,15 +302,19 @@ class MessageStorage: from src.common.database.sqlalchemy_models import Messages # 查找需要修复的记录:interest_value为0、null或很小的值 - query = select(Messages).where( - (Messages.chat_id == chat_id) & - (Messages.time >= since_time) & - ( - (Messages.interest_value == 0) | - (Messages.interest_value.is_(None)) | - (Messages.interest_value < 0.1) + query = ( + select(Messages) + .where( + (Messages.chat_id == chat_id) + & (Messages.time >= since_time) + & ( + (Messages.interest_value == 0) + | (Messages.interest_value.is_(None)) + | (Messages.interest_value < 0.1) + ) ) - ).limit(50) # 限制每次修复的数量,避免性能问题 + .limit(50) + ) # 限制每次修复的数量,避免性能问题 result = await session.execute(query) messages_to_fix = result.scalars().all() @@ -314,7 +326,7 @@ class MessageStorage: default_interest = 0.3 # 默认中等兴趣度 # 如果消息内容较长,可能是重要消息,兴趣度稍高 - if hasattr(msg, 'processed_plain_text') and msg.processed_plain_text: + if hasattr(msg, "processed_plain_text") and msg.processed_plain_text: text_length = len(msg.processed_plain_text) if text_length > 50: # 长消息 default_interest = 0.4 @@ -322,13 +334,15 @@ class MessageStorage: default_interest = 0.35 # 如果是被@的消息,兴趣度更高 - if getattr(msg, 'is_mentioned', False): + if getattr(msg, "is_mentioned", False): default_interest = min(default_interest + 0.2, 0.8) # 执行更新 - update_stmt = update(Messages).where( - Messages.message_id == msg.message_id - ).values(interest_value=default_interest) + update_stmt = ( + update(Messages) + .where(Messages.message_id == msg.message_id) + .values(interest_value=default_interest) + ) result = await session.execute(update_stmt) if result.rowcount > 0: diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index f92956bc2..21a00ee52 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -40,7 +40,7 @@ class ChatterActionManager: @staticmethod def create_action( - action_name: str, + action_name: str, action_data: dict, reasoning: str, cycle_timers: dict, @@ -162,7 +162,7 @@ class ChatterActionManager: Returns: 执行结果 """ - from src.chat.message_manager.message_manager import message_manager + try: logger.debug(f"🎯 [ActionManager] execute_action接收到 
target_message: {target_message}") # 通过chat_id获取chat_stream @@ -309,9 +309,7 @@ class ChatterActionManager: # 通过message_manager更新消息的动作记录并刷新focus_energy await message_manager.add_action( - stream_id=chat_stream.stream_id, - message_id=target_message_id, - action=action_name + stream_id=chat_stream.stream_id, message_id=target_message_id, action=action_name ) logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy") @@ -321,9 +319,10 @@ class ChatterActionManager: async def _reset_interruption_count_after_action(self, stream_id: str): """在动作执行成功后重置打断计数""" - from src.chat.message_manager.message_manager import message_manager + try: from src.plugin_system.apis.chat_api import get_chat_manager + chat_manager = get_chat_manager() chat_stream = chat_manager.get_stream(stream_id) if chat_stream: @@ -332,7 +331,9 @@ class ChatterActionManager: old_count = context.context.interruption_count old_afc_adjustment = context.context.get_afc_threshold_adjustment() context.context.reset_interruption_count() - logger.debug(f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0") + logger.debug( + f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0" + ) except Exception as e: logger.warning(f"重置打断计数时出错: {e}") @@ -531,7 +532,7 @@ class ChatterActionManager: # 根据新消息数量决定是否需要引用回复 reply_text = "" is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True - + logger.debug(f"[send_response] message_data: {message_data}") first_replied = False @@ -558,7 +559,9 @@ class ChatterActionManager: # 发送第一段回复 if not first_replied: set_reply_flag = bool(message_data) - logger.debug(f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}") + logger.debug( + f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}" + ) await send_api.text_to_stream( text=data, stream_id=chat_stream.stream_id, @@ -577,4 +580,4 @@ class ChatterActionManager: typing=True, ) - return reply_text \ No newline at end of file + return reply_text diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index ecd57639b..063fc1bf1 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -29,6 +29,7 @@ from src.chat.utils.chat_message_builder import ( replace_user_references_sync, ) from src.chat.express.expression_selector import expression_selector + # 旧记忆系统已被移除 # 旧记忆系统已被移除 from src.mood.mood_manager import mood_manager @@ -562,7 +563,9 @@ class DefaultReplyer: memory_context["user_aliases"] = memory_aliases if group_info_obj is not None: - group_name = getattr(group_info_obj, "group_name", None) or getattr(group_info_obj, "group_nickname", None) + group_name = getattr(group_info_obj, "group_name", None) or getattr( + group_info_obj, "group_nickname", None + ) if group_name: memory_context["group_name"] = str(group_name) group_id = getattr(group_info_obj, "group_id", None) @@ -576,11 +579,7 @@ class DefaultReplyer: # 检索相关记忆 enhanced_memories = await memory_system.retrieve_relevant_memories( - query=target, - user_id=memory_user_id, - scope_id=stream.stream_id, - context=memory_context, - limit=10 + query=target, user_id=memory_user_id, scope_id=stream.stream_id, context=memory_context, limit=10 ) # 注意:记忆存储已迁移到回复生成完成后进行,不在查询阶段执行 @@ -591,23 +590,27 @@ class DefaultReplyer: logger.debug(f"[记忆转换] 收到 {len(enhanced_memories)} 条原始记忆") for idx, memory_chunk in enumerate(enhanced_memories, 1): # 
获取结构化内容的字符串表示 - structure_display = str(memory_chunk.content) if hasattr(memory_chunk, 'content') else "unknown" - + structure_display = str(memory_chunk.content) if hasattr(memory_chunk, "content") else "unknown" + # 获取记忆内容,优先使用display content = memory_chunk.display or memory_chunk.text_content or "" - + # 调试:记录每条记忆的内容获取情况 - logger.debug(f"[记忆转换] 第{idx}条: display={repr(memory_chunk.display)[:80]}, text_content={repr(memory_chunk.text_content)[:80]}, final_content={repr(content)[:80]}") - - running_memories.append({ - "content": content, - "memory_type": memory_chunk.memory_type.value, - "confidence": memory_chunk.metadata.confidence.value, - "importance": memory_chunk.metadata.importance.value, - "relevance": getattr(memory_chunk.metadata, 'relevance_score', 0.5), - "source": memory_chunk.metadata.source, - "structure": structure_display, - }) + logger.debug( + f"[记忆转换] 第{idx}条: display={repr(memory_chunk.display)[:80]}, text_content={repr(memory_chunk.text_content)[:80]}, final_content={repr(content)[:80]}" + ) + + running_memories.append( + { + "content": content, + "memory_type": memory_chunk.memory_type.value, + "confidence": memory_chunk.metadata.confidence.value, + "importance": memory_chunk.metadata.importance.value, + "relevance": getattr(memory_chunk.metadata, "relevance_score", 0.5), + "source": memory_chunk.metadata.source, + "structure": structure_display, + } + ) # 构建瞬时记忆字符串 if running_memories: @@ -615,7 +618,9 @@ class DefaultReplyer: if top_memory: instant_memory = top_memory[0].get("content", "") - logger.info(f"增强记忆系统检索到 {len(enhanced_memories)} 条原始记忆,转换为 {len(running_memories)} 条可用记忆") + logger.info( + f"增强记忆系统检索到 {len(enhanced_memories)} 条原始记忆,转换为 {len(running_memories)} 条可用记忆" + ) except Exception as e: logger.warning(f"增强记忆系统检索失败: {e}") @@ -632,17 +637,17 @@ class DefaultReplyer: memory_parts = ["### 🧠 相关记忆 (Relevant Memories)", ""] # 按相关度排序,并记录相关度信息用于调试 - sorted_memories = sorted(running_memories, key=lambda x: x.get('relevance', 0.0), reverse=True) + sorted_memories = sorted(running_memories, key=lambda x: x.get("relevance", 0.0), reverse=True) # 调试相关度信息 - relevance_info = [(m.get('memory_type', 'unknown'), m.get('relevance', 0.0)) for m in sorted_memories] + relevance_info = [(m.get("memory_type", "unknown"), m.get("relevance", 0.0)) for m in sorted_memories] logger.debug(f"记忆相关度信息: {relevance_info}") logger.debug(f"[记忆构建] 准备将 {len(sorted_memories)} 条记忆添加到提示词") for idx, running_memory in enumerate(sorted_memories, 1): - content = running_memory.get('content', '') - memory_type = running_memory.get('memory_type', 'unknown') - + content = running_memory.get("content", "") + memory_type = running_memory.get("memory_type", "unknown") + # 跳过空内容 if not content or not content.strip(): logger.warning(f"[记忆构建] 跳过第 {idx} 条记忆:内容为空 (type={memory_type})") @@ -801,10 +806,10 @@ class DefaultReplyer: """ try: # 从message_manager获取真实的已读/未读消息 - from src.chat.message_manager.message_manager import message_manager # 获取聊天流的上下文 from src.plugin_system.apis.chat_api import get_chat_manager + chat_manager = get_chat_manager() chat_stream = chat_manager.get_stream(chat_id) if chat_stream: @@ -1000,7 +1005,9 @@ class DefaultReplyer: interest_scores = {} try: - from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system as interest_scoring_system + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ( + chatter_interest_scoring_system as interest_scoring_system, + ) from src.common.data_models.database_data_model import 
DatabaseMessages # 转换消息格式 @@ -1145,7 +1152,7 @@ class DefaultReplyer: platform, # type: ignore reply_message.get("user_id"), # type: ignore reply_message.get("user_nickname"), - reply_message.get("user_cardname") + reply_message.get("user_cardname"), ) # 检查是否是bot自己的名字,如果是则替换为"(你)" @@ -1657,6 +1664,7 @@ class DefaultReplyer: # 创建关系追踪器实例 from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system + relationship_tracker = ChatterRelationshipTracker(chatter_interest_scoring_system) if relationship_tracker: # 获取用户信息以获取真实的user_id @@ -1699,7 +1707,7 @@ class DefaultReplyer: async def _store_chat_memory_async(self, reply_to: str, reply_message: Optional[Dict[str, Any]] = None): """ 异步存储聊天记忆(从build_memory_block迁移而来) - + Args: reply_to: 回复对象 reply_message: 回复的原始消息 @@ -1768,9 +1776,7 @@ class DefaultReplyer: memory_aliases.append(stripped) alias_values = ( - user_info_dict.get("aliases") - or user_info_dict.get("alias_names") - or user_info_dict.get("alias") + user_info_dict.get("aliases") or user_info_dict.get("alias_names") or user_info_dict.get("alias") ) if isinstance(alias_values, (list, tuple, set)): for alias in alias_values: @@ -1794,7 +1800,9 @@ class DefaultReplyer: memory_context["user_aliases"] = memory_aliases if group_info_obj is not None: - group_name = getattr(group_info_obj, "group_name", None) or getattr(group_info_obj, "group_nickname", None) + group_name = getattr(group_info_obj, "group_name", None) or getattr( + group_info_obj, "group_nickname", None + ) if group_name: memory_context["group_name"] = str(group_name) group_id = getattr(group_info_obj, "group_id", None) @@ -1826,11 +1834,11 @@ class DefaultReplyer: "conversation_text": chat_history, "user_id": memory_user_id, "scope_id": stream.stream_id, - **memory_context + **memory_context, } ) ) - + logger.debug(f"已启动记忆存储任务,用户: {memory_user_display or memory_user_id}") except Exception as e: diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 277ba3d23..8503e369a 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -13,6 +13,7 @@ from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_m from src.common.database.sqlalchemy_database_api import get_db_session from sqlalchemy import select, and_ from src.common.logger import get_logger + logger = get_logger("chat_message_builder") install(extra_lines=3) @@ -277,21 +278,52 @@ async def get_actions_by_timestamp_with_chat( async with get_db_session() as session: if limit > 0: + result = await session.execute( + select(ActionRecords) + .where( + and_( + ActionRecords.chat_id == chat_id, + ActionRecords.time >= timestamp_start, + ActionRecords.time <= timestamp_end, + ) + ) + .order_by(ActionRecords.time.desc()) + .limit(limit) + ) + actions = list(result.scalars()) + actions_result = [] + for action in reversed(actions): + action_dict = { + "id": action.id, + "action_id": action.action_id, + "time": action.time, + "action_name": action.action_name, + "action_data": action.action_data, + "action_done": action.action_done, + "action_build_into_prompt": action.action_build_into_prompt, + "action_prompt_display": action.action_prompt_display, + "chat_id": action.chat_id, + "chat_info_stream_id": action.chat_info_stream_id, + "chat_info_platform": action.chat_info_platform, + } + actions_result.append(action_dict) + actions_result.append(action_dict) + else: # earliest result = await session.execute( select(ActionRecords) .where( 
and_( ActionRecords.chat_id == chat_id, - ActionRecords.time >= timestamp_start, - ActionRecords.time <= timestamp_end, + ActionRecords.time > timestamp_start, + ActionRecords.time < timestamp_end, ) ) - .order_by(ActionRecords.time.desc()) + .order_by(ActionRecords.time.asc()) .limit(limit) ) actions = list(result.scalars()) actions_result = [] - for action in reversed(actions): + for action in actions: action_dict = { "id": action.id, "action_id": action.action_id, @@ -306,37 +338,6 @@ async def get_actions_by_timestamp_with_chat( "chat_info_platform": action.chat_info_platform, } actions_result.append(action_dict) - actions_result.append(action_dict) - else: # earliest - result = await session.execute( - select(ActionRecords) - .where( - and_( - ActionRecords.chat_id == chat_id, - ActionRecords.time > timestamp_start, - ActionRecords.time < timestamp_end, - ) - ) - .order_by(ActionRecords.time.asc()) - .limit(limit) - ) - actions = list(result.scalars()) - actions_result = [] - for action in actions: - action_dict = { - "id": action.id, - "action_id": action.action_id, - "time": action.time, - "action_name": action.action_name, - "action_data": action.action_data, - "action_done": action.action_done, - "action_build_into_prompt": action.action_build_into_prompt, - "action_prompt_display": action.action_prompt_display, - "chat_id": action.chat_id, - "chat_info_stream_id": action.chat_info_stream_id, - "chat_info_platform": action.chat_info_platform, - } - actions_result.append(action_dict) else: result = await session.execute( select(ActionRecords) @@ -460,7 +461,9 @@ async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp_with_chat( + chat_id: str, timestamp: float, limit: int = 0 +) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ @@ -469,7 +472,9 @@ async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -async def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp_with_users( + timestamp: float, person_ids: list, limit: int = 0 +) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ @@ -478,7 +483,9 @@ async def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: +async def num_new_messages_since( + chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None +) -> int: """ 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。 @@ -828,7 +835,7 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: async with get_db_session() as session: result = await session.execute(select(Images).where(Images.image_id == pic_id)) image = result.scalar_one_or_none() - if image and hasattr(image, 'description') and image.description: + if image and hasattr(image, "description") and image.description: description = 
image.description except Exception as e: # 如果查询失败,保持默认描述 @@ -1015,23 +1022,29 @@ async def build_readable_messages( async with get_db_session() as session: # 获取这个时间范围内的动作记录,并匹配chat_id - actions_in_range = (await session.execute( - select(ActionRecords) - .where( - and_( - ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id + actions_in_range = ( + await session.execute( + select(ActionRecords) + .where( + and_( + ActionRecords.time >= min_time, + ActionRecords.time <= max_time, + ActionRecords.chat_id == chat_id, + ) ) + .order_by(ActionRecords.time) ) - .order_by(ActionRecords.time) - )).scalars() + ).scalars() # 获取最新消息之后的第一个动作记录 - action_after_latest = (await session.execute( - select(ActionRecords) - .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id)) - .order_by(ActionRecords.time) - .limit(1) - )).scalars() + action_after_latest = ( + await session.execute( + select(ActionRecords) + .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id)) + .order_by(ActionRecords.time) + .limit(1) + ) + ).scalars() # 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError actions = [ @@ -1222,9 +1235,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: except Exception: return "?" - content = await replace_user_references_async( - content, platform, anon_name_resolver, replace_bot_name=False - ) + content = await replace_user_references_async(content, platform, anon_name_resolver, replace_bot_name=False) header = f"{anon_name}说 " output_lines.append(header) diff --git a/src/chat/utils/memory_mappings.py b/src/chat/utils/memory_mappings.py index 79ce50ade..4da20fdb5 100644 --- a/src/chat/utils/memory_mappings.py +++ b/src/chat/utils/memory_mappings.py @@ -17,7 +17,7 @@ MEMORY_TYPE_CHINESE_MAPPING = { "goal": "目标计划", "experience": "经验教训", "contextual": "上下文信息", - "unknown": "未知" + "unknown": "未知", } # 置信度等级到中文标签的映射表 @@ -30,7 +30,7 @@ CONFIDENCE_LEVEL_CHINESE_MAPPING = { "MEDIUM": "中等置信度", "HIGH": "高置信度", "VERIFIED": "已验证", - "unknown": "未知" + "unknown": "未知", } # 重要性等级到中文标签的映射表 @@ -43,7 +43,7 @@ IMPORTANCE_LEVEL_CHINESE_MAPPING = { "NORMAL": "一般重要性", "HIGH": "高重要性", "CRITICAL": "关键重要性", - "unknown": "未知" + "unknown": "未知", } @@ -69,7 +69,7 @@ def get_confidence_level_chinese_label(level) -> str: str: 对应的中文标签,如果找不到则返回"未知" """ # 处理枚举实例 - if hasattr(level, 'value'): + if hasattr(level, "value"): level = level.value # 处理数字 @@ -94,7 +94,7 @@ def get_importance_level_chinese_label(level) -> str: str: 对应的中文标签,如果找不到则返回"未知" """ # 处理枚举实例 - if hasattr(level, 'value'): + if hasattr(level, "value"): level = level.value # 处理数字 @@ -106,4 +106,4 @@ def get_importance_level_chinese_label(level) -> str: level_upper = level.upper() return IMPORTANCE_LEVEL_CHINESE_MAPPING.get(level_upper, "未知") - return "未知" \ No newline at end of file + return "未知" diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 3e011bb15..baf77a143 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -377,12 +377,12 @@ class Prompt: # 性能优化 - 为不同任务设置不同的超时时间 task_timeouts = { - "memory_block": 15.0, # 记忆系统 - 降低超时时间,鼓励预构建 - "tool_info": 15.0, # 工具信息 - "relation_info": 10.0, # 关系信息 - "knowledge_info": 10.0, # 知识库查询 - "cross_context": 10.0, # 上下文处理 - "expression_habits": 10.0, # 表达习惯 + "memory_block": 15.0, # 记忆系统 - 降低超时时间,鼓励预构建 + "tool_info": 15.0, # 工具信息 + "relation_info": 10.0, # 关系信息 + "knowledge_info": 10.0, # 知识库查询 + "cross_context": 10.0, # 上下文处理 + "expression_habits": 10.0, # 表达习惯 } # 分别处理每个任务,避免慢任务影响快任务 
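[Editor's sketch, not part of the patch] The prompt.py hunk above defines per-task timeouts (memory_block, tool_info, relation_info, ...) so that one slow prompt-building source cannot hold up the fast ones. A small sketch of that pattern under the same assumption; gather_with_per_task_timeouts and the degrade-to-empty behaviour are illustrative choices, not the project's API:

import asyncio
from typing import Any, Awaitable, Callable, Dict, Tuple

async def gather_with_per_task_timeouts(
    tasks: Dict[str, Callable[[], Awaitable[Any]]],
    timeouts: Dict[str, float],
    default_timeout: float = 10.0,
) -> Dict[str, Any]:
    async def run(name: str, factory: Callable[[], Awaitable[Any]]) -> Tuple[str, Any]:
        try:
            # Each task gets its own budget, taken from the timeouts table.
            return name, await asyncio.wait_for(factory(), timeouts.get(name, default_timeout))
        except asyncio.TimeoutError:
            return name, ""   # a timed-out block degrades to empty text instead of blocking the rest
    results = await asyncio.gather(*(run(n, f) for n, f in tasks.items()))
    return dict(results)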
@@ -559,7 +559,7 @@ class Prompt: ), enhanced_memory_activator.get_instant_memory( target_message=self.parameters.target, chat_id=self.parameters.chat_id - ) + ), ] try: @@ -602,26 +602,27 @@ class Prompt: "opinion": "opinion", "personal_fact": "personal_fact", "preference": "preference", - "event": "event" + "event": "event", } mapped_type = memory_type_mapping.get(topic, "personal_fact") - formatted_memories.append({ - "display": display_text, - "memory_type": mapped_type, - "metadata": { - "confidence": memory.get("confidence", "未知"), - "importance": memory.get("importance", "一般"), - "timestamp": memory.get("timestamp", ""), - "source": memory.get("source", "unknown"), - "relevance_score": memory.get("relevance_score", 0.0) + formatted_memories.append( + { + "display": display_text, + "memory_type": mapped_type, + "metadata": { + "confidence": memory.get("confidence", "未知"), + "importance": memory.get("importance", "一般"), + "timestamp": memory.get("timestamp", ""), + "source": memory.get("source", "unknown"), + "relevance_score": memory.get("relevance_score", 0.0), + }, } - }) + ) # 使用方括号格式格式化记忆 memory_block = format_memories_bracket_style( - formatted_memories, - query_context=self.parameters.target + formatted_memories, query_context=self.parameters.target ) except Exception as e: logger.warning(f"记忆格式化失败,使用简化格式: {e}") @@ -829,7 +830,8 @@ class Prompt: "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), "safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), - "chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", + "chat_scene": self.parameters.chat_scene + or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", } def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: @@ -856,7 +858,8 @@ class Prompt: "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), "safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), - "chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", + "chat_scene": self.parameters.chat_scene + or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", } def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index ed8530387..1c879a01b 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -293,11 +293,14 @@ class StatisticOutputTask(AsyncTask): # 以最早的时间戳为起始时间获取记录 query_start_time = collect_period[-1][1] - records = await db_get( - model_class=LLMUsage, - filters={"timestamp": {"$gte": query_start_time}}, - order_by="-timestamp", - ) or [] + records = ( + await db_get( + model_class=LLMUsage, + filters={"timestamp": {"$gte": query_start_time}}, + order_by="-timestamp", + ) + or [] + ) for record in records: if not isinstance(record, dict): @@ -389,7 +392,9 @@ class StatisticOutputTask(AsyncTask): return stats @staticmethod - async def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]: + async def _collect_online_time_for_period( + collect_period: List[Tuple[str, datetime]], now: datetime + ) -> Dict[str, Any]: """ 收集指定时间段的在线时间统计数据 @@ -408,11 +413,14 @@ class StatisticOutputTask(AsyncTask): } query_start_time = collect_period[-1][1] - records = await db_get( - model_class=OnlineTime, - 
filters={"end_timestamp": {"$gte": query_start_time}}, - order_by="-end_timestamp", - ) or [] + records = ( + await db_get( + model_class=OnlineTime, + filters={"end_timestamp": {"$gte": query_start_time}}, + order_by="-end_timestamp", + ) + or [] + ) for record in records: if not isinstance(record, dict): @@ -464,11 +472,14 @@ class StatisticOutputTask(AsyncTask): } query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) - records = await db_get( - model_class=Messages, - filters={"time": {"$gte": query_start_timestamp}}, - order_by="-time", - ) or [] + records = ( + await db_get( + model_class=Messages, + filters={"time": {"$gte": query_start_timestamp}}, + order_by="-time", + ) + or [] + ) for message in records: if not isinstance(message, dict): @@ -1026,11 +1037,14 @@ class StatisticOutputTask(AsyncTask): interval_seconds = interval_minutes * 60 # 单次查询 LLMUsage - llm_records = await db_get( - model_class=LLMUsage, - filters={"timestamp": {"$gte": start_time}}, - order_by="-timestamp", - ) or [] + llm_records = ( + await db_get( + model_class=LLMUsage, + filters={"timestamp": {"$gte": start_time}}, + order_by="-timestamp", + ) + or [] + ) for record in llm_records: if not isinstance(record, dict) or not record.get("timestamp"): continue @@ -1056,11 +1070,14 @@ class StatisticOutputTask(AsyncTask): cost_by_module[module_name][idx] += cost # 单次查询 Messages - msg_records = await db_get( - model_class=Messages, - filters={"time": {"$gte": start_time.timestamp()}}, - order_by="-time", - ) or [] + msg_records = ( + await db_get( + model_class=Messages, + filters={"time": {"$gte": start_time.timestamp()}}, + order_by="-time", + ) + or [] + ) for msg in msg_records: if not isinstance(msg, dict) or not msg.get("time"): continue @@ -1363,4 +1380,4 @@ class StatisticOutputTask(AsyncTask): }}); - """ \ No newline at end of file + """ diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 989428677..ea3bdc89f 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -670,7 +670,6 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: if loop.is_running(): # 如果事件循环在运行,从其他线程提交并等待结果 try: - from concurrent.futures import TimeoutError fut = asyncio.run_coroutine_threadsafe( person_info_manager.get_value(person_id, "person_name"), loop diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 948b6f237..ab0915842 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -81,14 +81,16 @@ class ImageManager: """ try: async with get_db_session() as session: - record = (await session.execute( - select(ImageDescriptions).where( - and_( - ImageDescriptions.image_description_hash == image_hash, - ImageDescriptions.type == description_type, + record = ( + await session.execute( + select(ImageDescriptions).where( + and_( + ImageDescriptions.image_description_hash == image_hash, + ImageDescriptions.type == description_type, + ) ) ) - )).scalar() + ).scalar() return record.description if record else None except Exception as e: logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}") @@ -107,14 +109,16 @@ class ImageManager: current_timestamp = time.time() async with get_db_session() as session: # 查找现有记录 - existing = (await session.execute( - select(ImageDescriptions).where( - and_( - ImageDescriptions.image_description_hash == image_hash, - ImageDescriptions.type == description_type, + existing = ( + await session.execute( + select(ImageDescriptions).where( + and_( + 
ImageDescriptions.image_description_hash == image_hash, + ImageDescriptions.type == description_type, + ) ) ) - )).scalar() + ).scalar() if existing: # 更新现有记录 @@ -262,9 +266,11 @@ class ImageManager: from src.common.database.sqlalchemy_models import get_db_session async with get_db_session() as session: - existing_img = (await session.execute( - select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji")) - )).scalar() + existing_img = ( + await session.execute( + select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji")) + ) + ).scalar() if existing_img: existing_img.path = file_path diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index bbe489336..19ec72cb6 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -29,7 +29,7 @@ logger = get_logger("utils_video") # Rust模块可用性检测 RUST_VIDEO_AVAILABLE = False try: - import rust_video # pyright: ignore[reportMissingImports] + import rust_video # pyright: ignore[reportMissingImports] RUST_VIDEO_AVAILABLE = True logger.info("✅ Rust 视频处理模块加载成功") @@ -220,7 +220,7 @@ class VideoAnalyzer: return None async def _store_video_result( - self, video_hash: str, description: str, metadata: Optional[Dict] = None + self, video_hash: str, description: str, metadata: Optional[Dict] = None ) -> Optional[Videos]: """存储视频分析结果到数据库""" # 检查描述是否为错误信息,如果是则不保存 diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index 6ebe5c3d8..c77d9e8bd 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -77,7 +77,10 @@ class CacheManager: embedding_array = embedding_array.flatten() # 检查维度是否符合预期 - expected_dim = getattr(CacheManager, "embedding_dimension", None) or global_config.lpmm_knowledge.embedding_dimension + expected_dim = ( + getattr(CacheManager, "embedding_dimension", None) + or global_config.lpmm_knowledge.embedding_dimension + ) if embedding_array.shape[0] != expected_dim: logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}") return None diff --git a/src/common/database/db_migration.py b/src/common/database/db_migration.py index 085c277a3..2ab7ba13e 100644 --- a/src/common/database/db_migration.py +++ b/src/common/database/db_migration.py @@ -82,22 +82,26 @@ async def check_and_migrate_database(): if dialect.name == "sqlite" and isinstance(default_arg, bool): # SQLite 将布尔值存储为 0 或 1 default_value = "1" if default_arg else "0" - elif hasattr(compiler, 'render_literal_value'): + elif hasattr(compiler, "render_literal_value"): try: # 尝试使用 render_literal_value default_value = compiler.render_literal_value(default_arg, column.type) except AttributeError: # 如果失败,则回退到简单的字符串转换 - default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) + default_value = ( + f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) + ) else: # 对于没有 render_literal_value 的旧版或特定方言 - default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) - + default_value = ( + f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) + ) + sql += f" DEFAULT {default_value}" if not column.nullable: sql += " NOT NULL" - + conn.execute(text(sql)) logger.info(f"成功向表 '{table_name}' 添加列 '{column_name}'。") @@ -131,4 +135,3 @@ async def check_and_migrate_database(): continue logger.info("数据库结构检查与自动迁移完成。") - diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 210882dc8..330846983 100644 --- 
a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -423,4 +423,4 @@ async def store_action_info( except Exception as e: logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}") traceback.print_exc() - return None \ No newline at end of file + return None diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 64c1fd66a..2f78e56d0 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -781,6 +781,7 @@ async def get_db_session() -> AsyncGenerator[Optional[AsyncSession], None]: session = SessionLocal() # 对于 SQLite,在会话开始时设置 PRAGMA from src.config.config import global_config + if global_config.database.database_type == "sqlite": await session.execute(text("PRAGMA busy_timeout = 60000")) await session.execute(text("PRAGMA foreign_keys = ON")) diff --git a/src/common/logger.py b/src/common/logger.py index fa5e27d04..2830c127d 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -781,23 +781,23 @@ class ModuleColoredConsoleRenderer: thought_color = "\033[38;5;218m" # 分割消息内容 prefix, thought = event_content.split("内心思考:", 1) - + # 前缀部分(“决定进行回复,”)使用模块颜色 if module_color: prefix_colored = f"{module_color}{prefix.strip()}{RESET_COLOR}" else: prefix_colored = prefix.strip() - + # “内心思考”部分换行并使用专属颜色 thought_colored = f"\n\n{thought_color}内心思考:{thought.strip()}{RESET_COLOR}\n" - + # 重新组合 # parts.append(prefix_colored + thought_colored) # 将前缀和思考内容作为独立的part添加,避免它们之间出现多余的空格 if prefix_colored: parts.append(prefix_colored) parts.append(thought_colored) - + elif module_color: event_content = f"{module_color}{event_content}{RESET_COLOR}" parts.append(event_content) diff --git a/src/common/vector_db/chromadb_impl.py b/src/common/vector_db/chromadb_impl.py index d50ddf0de..a0267dfed 100644 --- a/src/common/vector_db/chromadb_impl.py +++ b/src/common/vector_db/chromadb_impl.py @@ -98,13 +98,13 @@ class ChromaDBImpl(VectorDBBase): "n_results": n_results, **kwargs, } - + # 修复ChromaDB的where条件格式 if where: processed_where = self._process_where_condition(where) if processed_where: query_params["where"] = processed_where - + return collection.query(**query_params) except Exception as e: logger.error(f"查询集合 '{collection_name}' 失败: {e}") @@ -114,7 +114,7 @@ class ChromaDBImpl(VectorDBBase): "query_embeddings": query_embeddings, "n_results": n_results, } - logger.warning(f"使用回退查询模式(无where条件)") + logger.warning("使用回退查询模式(无where条件)") return collection.query(**fallback_params) except Exception as fallback_e: logger.error(f"回退查询也失败: {fallback_e}") @@ -124,19 +124,19 @@ class ChromaDBImpl(VectorDBBase): """ 处理where条件,转换为ChromaDB支持的格式 ChromaDB支持的格式: - - 简单条件: {"field": "value"} + - 简单条件: {"field": "value"} - 操作符条件: {"field": {"$op": "value"}} - AND条件: {"$and": [condition1, condition2]} - OR条件: {"$or": [condition1, condition2]} """ if not where: return None - + try: # 如果只有一个字段,直接返回 if len(where) == 1: key, value = next(iter(where.items())) - + # 处理列表值(如memory_types) if isinstance(value, list): if len(value) == 1: @@ -146,7 +146,7 @@ class ChromaDBImpl(VectorDBBase): return {key: {"$in": value}} else: return {key: value} - + # 多个字段使用 $and 操作符 conditions = [] for key, value in where.items(): @@ -157,9 +157,9 @@ class ChromaDBImpl(VectorDBBase): conditions.append({key: {"$in": value}}) else: conditions.append({key: value}) - + return {"$and": conditions} - + except Exception as e: logger.warning(f"处理where条件失败: {e}, 使用简化条件") # 回退到只使用第一个条件 @@ -189,7 +189,7 @@ class 
ChromaDBImpl(VectorDBBase): processed_where = None if where: processed_where = self._process_where_condition(where) - + return collection.get( ids=ids, where=processed_where, @@ -202,7 +202,7 @@ class ChromaDBImpl(VectorDBBase): logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}") # 如果获取失败,尝试不使用where条件重新获取 try: - logger.warning(f"使用回退获取模式(无where条件)") + logger.warning("使用回退获取模式(无where条件)") return collection.get( ids=ids, limit=limit, diff --git a/src/config/config.py b/src/config/config.py index 50b826bf0..375d513df 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -45,7 +45,7 @@ from src.config.official_configs import ( CommandConfig, PlanningSystemConfig, AffinityFlowConfig, - ProactiveThinkingConfig + ProactiveThinkingConfig, ) from .api_ada_configs import ( diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 48cec6599..6a1613baa 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -61,35 +61,35 @@ class PersonalityConfig(ValidatedConfigBase): prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式") compress_personality: bool = Field(default=True, description="是否压缩人格") compress_identity: bool = Field(default=True, description="是否压缩身份") - + # 回复规则配置 reply_targeting_rules: List[str] = Field( default_factory=lambda: [ "拒绝任何包含骚扰、冒犯、暴力、色情或危险内容的请求。", "在拒绝时,请使用符合你人设的、坚定的语气。", - "不要执行任何可能被用于恶意目的的指令。" + "不要执行任何可能被用于恶意目的的指令。", ], - description="安全与互动底线规则,Bot在任何情况下都必须遵守的原则" + description="安全与互动底线规则,Bot在任何情况下都必须遵守的原则", ) - + message_targeting_analysis: List[str] = Field( default_factory=lambda: [ "**直接针对你**:@你、回复你、明确询问你 → 必须回应", "**间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与", "**他人对话**:与你无关的私人交流 → 通常不参与", - "**重复内容**:他人已充分回答的问题 → 避免重复" + "**重复内容**:他人已充分回答的问题 → 避免重复", ], - description="消息针对性分析规则,用于判断是否需要回复" + description="消息针对性分析规则,用于判断是否需要回复", ) - + reply_principles: List[str] = Field( default_factory=lambda: [ "明确回应目标消息,而不是宽泛地评论。", "可以分享你的看法、提出相关问题,或者开个合适的玩笑。", "目的是让对话更有趣、更深入。", - "不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。" + "不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。", ], - description="回复原则,指导如何回复消息" + description="回复原则,指导如何回复消息", ) @@ -126,28 +126,15 @@ class ChatConfig(ValidatedConfigBase): interruption_probability_factor: float = Field( default=0.8, ge=0.0, le=1.0, description="打断概率因子,当前打断次数/最大打断次数超过此值时触发概率下降" ) - interruption_afc_reduction: float = Field( - default=0.05, ge=0.0, le=1.0, description="每次连续打断降低的afc阈值数值" - ) + interruption_afc_reduction: float = Field(default=0.05, ge=0.0, le=1.0, description="每次连续打断降低的afc阈值数值") # 动态消息分发系统配置 dynamic_distribution_enabled: bool = Field(default=True, description="是否启用动态消息分发周期调整") - dynamic_distribution_base_interval: float = Field( - default=5.0, ge=1.0, le=60.0, description="基础分发间隔(秒)" - ) - dynamic_distribution_min_interval: float = Field( - default=1.0, ge=0.5, le=10.0, description="最小分发间隔(秒)" - ) - dynamic_distribution_max_interval: float = Field( - default=30.0, ge=5.0, le=300.0, description="最大分发间隔(秒)" - ) - dynamic_distribution_jitter_factor: float = Field( - default=0.2, ge=0.0, le=0.5, description="分发间隔随机扰动因子" - ) - max_concurrent_distributions: int = Field( - default=10, ge=1, le=100, description="最大并发处理的消息流数量" - ) - + dynamic_distribution_base_interval: float = Field(default=5.0, ge=1.0, le=60.0, description="基础分发间隔(秒)") + dynamic_distribution_min_interval: float = Field(default=1.0, ge=0.5, le=10.0, description="最小分发间隔(秒)") + dynamic_distribution_max_interval: float = Field(default=30.0, ge=5.0, le=300.0, 
description="最大分发间隔(秒)") + dynamic_distribution_jitter_factor: float = Field(default=0.2, ge=0.0, le=0.5, description="分发间隔随机扰动因子") + max_concurrent_distributions: int = Field(default=10, ge=1, le=100, description="最大并发处理的消息流数量") class MessageReceiveConfig(ValidatedConfigBase): @@ -309,11 +296,13 @@ class MemoryConfig(ValidatedConfigBase): enable_vector_memory_storage: bool = Field(default=True, description="启用Vector DB记忆存储") enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆") enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆") - + # Vector DB配置 vector_db_memory_collection: str = Field(default="unified_memory_v2", description="Vector DB记忆集合名称") vector_db_metadata_collection: str = Field(default="memory_metadata_v2", description="Vector DB元数据集合名称") - vector_db_similarity_threshold: float = Field(default=0.5, description="Vector DB相似度阈值(推荐0.5-0.6,过高会导致检索不到结果)") + vector_db_similarity_threshold: float = Field( + default=0.5, description="Vector DB相似度阈值(推荐0.5-0.6,过高会导致检索不到结果)" + ) vector_db_search_limit: int = Field(default=20, description="Vector DB搜索限制") vector_db_batch_size: int = Field(default=100, description="批处理大小") vector_db_enable_caching: bool = Field(default=True, description="启用内存缓存") @@ -327,23 +316,23 @@ class MemoryConfig(ValidatedConfigBase): base_forgetting_days: float = Field(default=30.0, description="基础遗忘天数") min_forgetting_days: float = Field(default=7.0, description="最小遗忘天数") max_forgetting_days: float = Field(default=365.0, description="最大遗忘天数") - + # 重要程度权重 critical_importance_bonus: float = Field(default=45.0, description="关键重要性额外天数") high_importance_bonus: float = Field(default=30.0, description="高重要性额外天数") normal_importance_bonus: float = Field(default=15.0, description="一般重要性额外天数") low_importance_bonus: float = Field(default=0.0, description="低重要性额外天数") - + # 置信度权重 verified_confidence_bonus: float = Field(default=30.0, description="已验证置信度额外天数") high_confidence_bonus: float = Field(default=20.0, description="高置信度额外天数") medium_confidence_bonus: float = Field(default=10.0, description="中等置信度额外天数") low_confidence_bonus: float = Field(default=0.0, description="低置信度额外天数") - + # 激活频率权重 activation_frequency_weight: float = Field(default=0.5, description="每次激活增加的天数权重") max_frequency_bonus: float = Field(default=10.0, description="最大激活频率奖励天数") - + # 休眠机制 dormant_threshold_days: int = Field(default=90, description="休眠状态判定天数") @@ -596,6 +585,7 @@ class SleepSystemConfig(ValidatedConfigBase): default="我准备睡觉了,请生成一句简短自然的晚安问候。", description="用于生成睡前消息的提示" ) + class ContextGroup(ValidatedConfigBase): """上下文共享组配置""" @@ -658,6 +648,7 @@ class AffinityFlowConfig(ValidatedConfigBase): mention_bot_interest_score: float = Field(default=0.6, description="提及bot的兴趣分") base_relationship_score: float = Field(default=0.5, description="基础人物关系分") + class ProactiveThinkingConfig(ValidatedConfigBase): """主动思考(主动发起对话)功能配置""" @@ -666,24 +657,26 @@ class ProactiveThinkingConfig(ValidatedConfigBase): # --- 触发时机 --- interval: int = Field(default=1500, description="基础触发间隔(秒),AI会围绕这个时间点主动发起对话") - interval_sigma: int = Field(default=120, description="间隔随机化标准差(秒),让触发时间更自然。设为0则为固定间隔。") + interval_sigma: int = Field( + default=120, description="间隔随机化标准差(秒),让触发时间更自然。设为0则为固定间隔。" + ) talk_frequency_adjust: list[list[str]] = Field( - default_factory=lambda: [['', '8:00,1', '12:00,1.2', '18:00,1.5', '01:00,0.6']], - description='每日活跃度调整,格式:[["", "HH:MM,factor", ...], ["stream_id", ...]]' + default_factory=lambda: [["", "8:00,1", "12:00,1.2", "18:00,1.5", 
"01:00,0.6"]], + description='每日活跃度调整,格式:[["", "HH:MM,factor", ...], ["stream_id", ...]]', ) # --- 作用范围 --- enable_in_private: bool = Field(default=True, description="是否允许在私聊中主动发起对话") enable_in_group: bool = Field(default=True, description="是否允许在群聊中主动发起对话") enabled_private_chats: List[str] = Field( - default_factory=list, - description='私聊白名单,为空则对所有私聊生效。格式: ["platform:user_id", ...]' + default_factory=list, description='私聊白名单,为空则对所有私聊生效。格式: ["platform:user_id", ...]' ) enabled_group_chats: List[str] = Field( - default_factory=list, - description='群聊白名单,为空则对所有群聊生效。格式: ["platform:group_id", ...]' + default_factory=list, description='群聊白名单,为空则对所有群聊生效。格式: ["platform:group_id", ...]' ) # --- 冷启动配置 (针对私聊) --- enable_cold_start: bool = Field(default=True, description="对于白名单中不活跃的私聊,是否允许进行一次“冷启动”问候") - cold_start_cooldown: int = Field(default=86400, description="冷启动后,该私聊的下一次主动思考需要等待的最小时间(秒)") + cold_start_cooldown: int = Field( + default=86400, description="冷启动后,该私聊的下一次主动思考需要等待的最小时间(秒)" + ) diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index a4a106387..4716921f9 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -84,7 +84,9 @@ class Individuality: full_personality = f"{personality_result},{identity_result}" # 获取全局兴趣评分系统实例 - from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system as interest_scoring_system + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ( + chatter_interest_scoring_system as interest_scoring_system, + ) # 初始化智能兴趣系统 await interest_scoring_system.initialize_smart_interests( @@ -114,7 +116,7 @@ class Individuality: @staticmethod def _get_config_hash( - bot_nickname: str, personality_core: str, personality_side: str, identity: str + bot_nickname: str, personality_core: str, personality_side: str, identity: str ) -> tuple[str, str]: """获取personality和identity配置的哈希值 diff --git a/src/llm_models/model_client/aiohttp_gemini_client.py b/src/llm_models/model_client/aiohttp_gemini_client.py index 7b997b680..3d4dd8ca1 100644 --- a/src/llm_models/model_client/aiohttp_gemini_client.py +++ b/src/llm_models/model_client/aiohttp_gemini_client.py @@ -453,9 +453,7 @@ class AiohttpGeminiClient(BaseClient): # 构建请求体 request_data = { "contents": contents, - "generationConfig": _build_generation_config( - max_tokens, temperature, tb, response_format, extra_params - ), + "generationConfig": _build_generation_config(max_tokens, temperature, tb, response_format, extra_params), } # 添加系统指令 diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py index 4253efab0..17d1fa30b 100644 --- a/src/llm_models/payload_content/message.py +++ b/src/llm_models/payload_content/message.py @@ -58,7 +58,7 @@ class MessageBuilder: self, image_format: str, image_base64: str, - support_formats=None, # 默认支持格式 + support_formats=None, # 默认支持格式 ) -> "MessageBuilder": """ 添加图片内容 diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index bd895693b..a8a68c2fb 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -18,6 +18,7 @@ - **LLMRequest (主接口)**: 作为模块的统一入口(Facade),为上层业务逻辑提供了简洁的接口来发起文本、图像、语音等不同类型的LLM请求。 """ + import re import asyncio import time @@ -26,14 +27,13 @@ import string from enum import Enum from rich.traceback import install -from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator +from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine from 
src.common.logger import get_logger from src.config.config import model_config from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig from .payload_content.message import MessageBuilder, Message -from .payload_content.resp_format import RespFormat -from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType +from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord from .utils import compress_messages, llm_usage_recorder from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException @@ -46,6 +46,7 @@ logger = get_logger("model_utils") # Standalone Utility Functions # ============================================================================== + def _normalize_image_format(image_format: str) -> str: """ 标准化图片格式名称,确保与各种API的兼容性 @@ -57,17 +58,26 @@ def _normalize_image_format(image_format: str) -> str: str: 标准化后的图片格式 """ format_mapping = { - "jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg", - "png": "png", "PNG": "png", - "webp": "webp", "WEBP": "webp", - "gif": "gif", "GIF": "gif", - "heic": "heic", "HEIC": "heic", - "heif": "heif", "HEIF": "heif", + "jpg": "jpeg", + "JPG": "jpeg", + "JPEG": "jpeg", + "jpeg": "jpeg", + "png": "png", + "PNG": "png", + "webp": "webp", + "WEBP": "webp", + "gif": "gif", + "GIF": "gif", + "heic": "heic", + "HEIC": "heic", + "heif": "heif", + "HEIF": "heif", } normalized = format_mapping.get(image_format, image_format.lower()) logger.debug(f"图片格式标准化: {image_format} -> {normalized}") return normalized + async def execute_concurrently( coro_callable: Callable[..., Coroutine[Any, Any, Any]], concurrency_count: int, @@ -103,25 +113,29 @@ async def execute_concurrently( for i, res in enumerate(results): if isinstance(res, Exception): logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}") - + first_exception = next((res for res in results if isinstance(res, Exception)), None) if first_exception: raise first_exception raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息") + class RequestType(Enum): """请求类型枚举""" + RESPONSE = "response" EMBEDDING = "embedding" AUDIO = "audio" + # ============================================================================== # Helper Classes for LLMRequest Refactoring # ============================================================================== + class _ModelSelector: """负责模型选择、负载均衡和动态故障切换的策略。""" - + CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数 DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量 @@ -168,16 +182,18 @@ class _ModelSelector: # - `usage_penalty * 1000`: 短期使用惩罚项。每次被选中后会增加,完成后会减少。高权重确保在多个模型都健康的情况下,请求会均匀分布(轮询)。 least_used_model_name = min( candidate_models_usage, - key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000, + key=lambda k: candidate_models_usage[k][0] + + candidate_models_usage[k][1] * 300 + + candidate_models_usage[k][2] * 1000, ) - + model_info = model_config.get_model_info(least_used_model_name) api_provider = model_config.get_provider(model_info.api_provider) # 特殊处理:对于 embedding 任务,强制创建新的 aiohttp.ClientSession。 # 这是为了避免在某些高并发场景下,共享的ClientSession可能引发的事件循环相关问题。 force_new_client = request_type == "embedding" client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - + logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}") # 增加所选模型的请求使用惩罚值,以实现动态负载均衡。 self.update_usage_penalty(model_info.name, 
increase=True) @@ -214,26 +230,32 @@ class _ModelSelector: if isinstance(e, (NetworkConnectionError, ReqAbortException)): # 网络连接错误或请求被中断,通常是基础设施问题,应重罚 penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER - logger.warning(f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}") + logger.warning( + f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}" + ) elif isinstance(e, RespNotOkException): # 对于HTTP响应错误,重点关注服务器端错误 if e.status_code >= 500: # 5xx 错误表明服务器端出现问题,应重罚 penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER - logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}") + logger.warning( + f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}" + ) else: # 4xx 客户端错误通常不代表模型本身不可用,给予基础惩罚 - logger.warning(f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") + logger.warning( + f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}" + ) else: # 其他未知异常,给予基础惩罚 logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}") - + self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty) class _PromptProcessor: """封装所有与提示词和响应内容的预处理和后处理逻辑。""" - + def __init__(self): """ 初始化提示处理器。 @@ -276,18 +298,18 @@ class _PromptProcessor: """ # 步骤1: 根据API提供商的配置应用内容混淆 processed_prompt = self._apply_content_obfuscation(prompt, api_provider) - + # 步骤2: 检查模型是否需要注入反截断指令 if getattr(model_info, "use_anti_truncation", False): processed_prompt += self.anti_truncation_instruction logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。") - + return processed_prompt def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]: """ 处理响应内容,提取思维链并检查截断。 - + Returns: Tuple[str, str, bool]: (处理后的内容, 思维链内容, 是否被截断) """ @@ -317,14 +339,14 @@ class _PromptProcessor: # 检查当前API提供商是否启用了内容混淆功能 if not getattr(api_provider, "enable_content_obfuscation", False): return text - + # 获取混淆强度,默认为1 intensity = getattr(api_provider, "obfuscation_intensity", 1) logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") - + # 将抗审查指令和原始文本拼接 processed_text = self.noise_instruction + "\n\n" + text - + # 在拼接后的文本中注入随机噪音 return self._inject_random_noise(processed_text, intensity) @@ -346,12 +368,12 @@ class _PromptProcessor: # 定义不同强度级别的噪音参数:概率和长度范围 params = { 1: {"probability": 15, "length": (3, 6)}, # 低强度 - 2: {"probability": 25, "length": (5, 10)}, # 中强度 - 3: {"probability": 35, "length": (8, 15)}, # 高强度 + 2: {"probability": 25, "length": (5, 10)}, # 中强度 + 3: {"probability": 35, "length": (8, 15)}, # 高强度 } # 根据传入的强度选择配置,如果强度无效则使用默认值 config = params.get(intensity, params[1]) - + words = text.split() result = [] # 遍历每个单词 @@ -366,7 +388,7 @@ class _PromptProcessor: # 生成噪音字符串 noise = "".join(random.choice(chars) for _ in range(noise_length)) result.append(noise) - + # 将处理后的单词列表重新组合成字符串 return " ".join(result) @@ -396,7 +418,7 @@ class _PromptProcessor: else: reasoning = "" clean_content = content.strip() - + return clean_content, reasoning @@ -441,7 +463,7 @@ class _RequestExecutor: """ retry_remain = api_provider.max_retry compressed_messages: Optional[List[Message]] = None - + while retry_remain > 0: try: # 优先使用压缩后的消息列表 @@ -451,11 +473,11 @@ class _RequestExecutor: # 根据请求类型调用不同的客户端方法 if request_type == RequestType.RESPONSE: assert current_messages is not None, "message_list cannot be None for response requests" - + # 修复: 防止 'message_list' 在 kwargs 中重复传递 request_params = 
kwargs.copy() request_params.pop("message_list", None) - + return await client.get_response( model_info=model_info, message_list=current_messages, **request_params ) @@ -463,15 +485,19 @@ class _RequestExecutor: return await client.get_embedding(model_info=model_info, **kwargs) elif request_type == RequestType.AUDIO: return await client.get_audio_transcriptions(model_info=model_info, **kwargs) - + except Exception as e: logger.debug(f"请求失败: {str(e)}") # 记录失败并更新模型的惩罚值 self.model_selector.update_failure_penalty(model_info.name, e) - + # 处理异常,决定是否重试以及等待多久 wait_interval, new_compressed_messages = self._handle_exception( - e, model_info, api_provider, retry_remain, (kwargs.get("message_list"), compressed_messages is not None) + e, + model_info, + api_provider, + retry_remain, + (kwargs.get("message_list"), compressed_messages is not None), ) if new_compressed_messages: compressed_messages = new_compressed_messages # 更新为压缩后的消息 @@ -482,7 +508,7 @@ class _RequestExecutor: await asyncio.sleep(wait_interval) # 等待指定时间后重试 finally: retry_remain -= 1 - + logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") raise RuntimeError("请求失败,已达到最大重试次数") @@ -491,7 +517,7 @@ class _RequestExecutor: ) -> Tuple[int, Optional[List[Message]]]: """ 默认异常处理函数,决定是否重试。 - + Returns: (等待间隔(-1表示不再重试), 新的消息列表(适用于压缩消息)) """ @@ -534,7 +560,9 @@ class _RequestExecutor: model_name = model_info.name # 处理客户端错误 (400-404),这些错误通常是请求本身有问题,不应重试 if e.status_code in [400, 401, 402, 403, 404]: - logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。") + logger.warning( + f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。" + ) return -1, None # 处理请求体过大的情况 elif e.status_code == 413: @@ -570,9 +598,11 @@ class _RequestExecutor: """ # 只有在剩余重试次数大于1时才进行下一次重试(因为当前这次失败已经消耗掉一次) if remain_try > 1: - logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。") + logger.warning( + f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。" + ) return interval, None - + # 如果已无剩余重试次数,则记录错误并返回-1表示放弃 logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},已达最大重试次数,放弃。") return -1, None @@ -585,7 +615,14 @@ class _RequestStrategy: 即使在单个模型或API端点失败的情况下也能正常工作。 """ - def __init__(self, model_selector: _ModelSelector, prompt_processor: _PromptProcessor, executor: _RequestExecutor, model_list: List[str], task_name: str): + def __init__( + self, + model_selector: _ModelSelector, + prompt_processor: _PromptProcessor, + executor: _RequestExecutor, + model_list: List[str], + task_name: str, + ): """ 初始化请求策略。 @@ -616,11 +653,13 @@ class _RequestStrategy: last_exception: Optional[Exception] = None for attempt in range(max_attempts): - selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request, str(request_type.value)) + selection_result = self.model_selector.select_best_available_model( + failed_models_in_this_request, str(request_type.value) + ) if selection_result is None: logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。") break - + model_info, api_provider, client = selection_result logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_info.name}'...") @@ -637,32 +676,36 @@ class _RequestStrategy: # 合并模型特定的额外参数 if model_info.extra_params: - request_kwargs["extra_params"] = {**model_info.extra_params, **request_kwargs.get("extra_params", {})} + request_kwargs["extra_params"] = { + **model_info.extra_params, + 
**request_kwargs.get("extra_params", {}), + } + + response = await self._try_model_request( + model_info, api_provider, client, request_type, **request_kwargs + ) - response = await self._try_model_request(model_info, api_provider, client, request_type, **request_kwargs) - # 成功,立即返回 logger.debug(f"模型 '{model_info.name}' 成功生成了回复。") self.model_selector.update_usage_penalty(model_info.name, increase=False) return response, model_info - + except Exception as e: logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。") failed_models_in_this_request.add(model_info.name) last_exception = e # 使用惩罚值已在 select 时增加,失败后不减少,以降低其后续被选中的概率 - + logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。") if raise_when_empty: if last_exception: raise RuntimeError("所有模型均未能生成响应。") from last_exception raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。") - + # 如果不抛出异常,返回一个备用响应 fallback_model_info = model_config.get_model_info(self.model_list[0]) return APIResponse(content="所有模型都请求失败"), fallback_model_info - async def _try_model_request( self, model_info: ModelInfo, api_provider: APIProvider, client: BaseClient, request_type: RequestType, **kwargs ) -> APIResponse: @@ -684,46 +727,49 @@ class _RequestStrategy: RuntimeError: 如果在达到最大重试次数后仍然收到空回复或截断的响应。 """ max_empty_retry = api_provider.max_retry - + for i in range(max_empty_retry + 1): - response = await self.executor.execute_request( - api_provider, client, request_type, model_info, **kwargs - ) + response = await self.executor.execute_request(api_provider, client, request_type, model_info, **kwargs) if request_type != RequestType.RESPONSE: - return response # 对于非响应类型,直接返回 + return response # 对于非响应类型,直接返回 # --- 响应内容处理和空回复/截断检查 --- content = response.content or "" use_anti_truncation = getattr(model_info, "use_anti_truncation", False) - processed_content, reasoning, is_truncated = self.prompt_processor.process_response(content, use_anti_truncation) - + processed_content, reasoning, is_truncated = self.prompt_processor.process_response( + content, use_anti_truncation + ) + # 更新响应对象 response.content = processed_content response.reasoning_content = response.reasoning_content or reasoning is_empty_reply = not response.tool_calls and not (response.content and response.content.strip()) - + if not is_empty_reply and not is_truncated: - return response # 成功获取有效响应 + return response # 成功获取有效响应 if i < max_empty_retry: reason = "空回复" if is_empty_reply else "截断" - logger.warning(f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})...") + logger.warning( + f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})..." 
+ ) if api_provider.retry_interval > 0: await asyncio.sleep(api_provider.retry_interval) else: reason = "空回复" if is_empty_reply else "截断" logger.error(f"模型 '{model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。") raise RuntimeError(f"模型 '{model_info.name}' 已达到空回复/截断的最大内部重试次数。") - - raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里 + + raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里 # ============================================================================== # Main Facade Class # ============================================================================== + class LLMRequest: """ LLM请求协调器。 @@ -745,7 +791,7 @@ class LLMRequest: model: (0, 0, 0) for model in self.model_for_task.model_list } """模型使用量记录,(total_tokens, penalty, usage_penalty)""" - + # 初始化辅助类 self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage) self._prompt_processor = _PromptProcessor() @@ -769,36 +815,44 @@ class LLMRequest: prompt (str): 提示词 image_base64 (str): 图像的Base64编码字符串 image_format (str): 图像格式(如 'png', 'jpeg' 等) - + Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ start_time = time.time() - + # 图像请求目前不使用复杂的故障转移策略,直接选择模型并执行 selection_result = self._model_selector.select_best_available_model(set(), "response") if not selection_result: raise RuntimeError("无法为图像响应选择可用模型。") model_info, api_provider, client = selection_result - + normalized_format = _normalize_image_format(image_format) - message = MessageBuilder().add_text_content(prompt).add_image_content( - image_base64=image_base64, - image_format=normalized_format, - support_formats=client.get_support_image_formats(), - ).build() + message = ( + MessageBuilder() + .add_text_content(prompt) + .add_image_content( + image_base64=image_base64, + image_format=normalized_format, + support_formats=client.get_support_image_formats(), + ) + .build() + ) response = await self._executor.execute_request( - api_provider, client, RequestType.RESPONSE, model_info, + api_provider, + client, + RequestType.RESPONSE, + model_info, message_list=[message], temperature=temperature, max_tokens=max_tokens, ) - + await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions") content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False) reasoning = response.reasoning_content or reasoning - + return content, (reasoning, model_info.name, response.tool_calls) async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: @@ -812,9 +866,7 @@ class LLMRequest: Returns: Optional[str]: 语音转换后的文本内容,如果所有模型都失败则返回None。 """ - response, _ = await self._strategy.execute_with_failover( - RequestType.AUDIO, audio_base64=voice_base64 - ) + response, _ = await self._strategy.execute_with_failover(RequestType.AUDIO, audio_base64=voice_base64) return response.content or None async def generate_response_async( @@ -834,7 +886,7 @@ class LLMRequest: max_tokens (int, optional): 最大token数 tools: 工具配置 raise_when_empty (bool): 是否在空回复时抛出异常 - + Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ @@ -842,12 +894,16 @@ class LLMRequest: if concurrency_count <= 1: return await self._execute_single_text_request(prompt, temperature, max_tokens, tools, raise_when_empty) - + try: return await execute_concurrently( self._execute_single_text_request, concurrency_count, - prompt, temperature, max_tokens, tools, raise_when_empty=False + prompt, + temperature, + max_tokens, + tools, + raise_when_empty=False, ) except Exception as e: 
logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}") @@ -885,7 +941,7 @@ class LLMRequest: response, model_info = await self._strategy.execute_with_failover( RequestType.RESPONSE, raise_when_empty=raise_when_empty, - prompt=prompt, # 传递原始prompt,由strategy处理 + prompt=prompt, # 传递原始prompt,由strategy处理 tool_options=tool_options, temperature=self.model_for_task.temperature if temperature is None else temperature, max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, @@ -906,21 +962,20 @@ class LLMRequest: Args: embedding_input (str): 获取嵌入的目标 - + Returns: (Tuple[List[float], str]): (嵌入向量,使用的模型名称) """ start_time = time.time() response, model_info = await self._strategy.execute_with_failover( - RequestType.EMBEDDING, - embedding_input=embedding_input + RequestType.EMBEDDING, embedding_input=embedding_input ) - + await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings") - + if not response.embedding: raise RuntimeError("获取embedding失败") - + return response.embedding, model_info.name async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str): @@ -940,16 +995,18 @@ class LLMRequest: # 步骤1: 更新内存中的token计数,用于负载均衡 total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty) - + # 步骤2: 创建一个后台任务,将用量数据异步写入数据库 - asyncio.create_task(llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=usage, - user_id="system", # 此处可根据业务需求修改 - time_cost=time_cost, - request_type=self.task_name, - endpoint=endpoint, - )) + asyncio.create_task( + llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, + user_id="system", # 此处可根据业务需求修改 + time_cost=time_cost, + request_type=self.task_name, + endpoint=endpoint, + ) + ) @staticmethod def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: @@ -970,14 +1027,14 @@ class LLMRequest: # 如果没有提供工具,直接返回 None if not tools: return None - + tool_options: List[ToolOption] = [] # 遍历每个工具定义 for tool in tools: try: # 使用建造者模式创建 ToolOption builder = ToolOptionBuilder().set_name(tool["name"]).set_description(tool.get("description", "")) - + # 遍历工具的参数 for param in tool.get("parameters", []): # 严格验证参数格式是否为包含5个元素的元组 @@ -994,6 +1051,6 @@ class LLMRequest: except (KeyError, IndexError, TypeError, AssertionError) as e: # 如果构建过程中出现任何错误,记录日志并跳过该工具 logger.error(f"构建工具 '{tool.get('name', 'N/A')}' 失败: {e}") - + # 如果列表非空则返回列表,否则返回 None return tool_options or None diff --git a/src/main.py b/src/main.py index 44084509a..4e91f1419 100644 --- a/src/main.py +++ b/src/main.py @@ -78,6 +78,7 @@ class MainSystem: logger.info("收到退出信号,正在优雅关闭系统...") import asyncio + try: loop = asyncio.get_event_loop() if loop.is_running(): @@ -106,6 +107,7 @@ class MainSystem: # 停止消息管理器 try: from src.chat.message_manager import message_manager + await message_manager.stop() logger.info("🛑 消息管理器已停止") except Exception as e: @@ -241,7 +243,6 @@ MoFox_Bot(第三方修改版) # 处理所有缓存的事件订阅(插件加载完成后) event_manager.process_all_pending_subscriptions() - # 初始化表情管理器 get_emoji_manager().initialize() logger.info("表情包管理器初始化成功") diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py index e0d753842..ba8ee54eb 100644 --- a/src/mais4u/mais4u_chat/s4u_msg_processor.py +++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py @@ -50,13 +50,13 @@ async def _calculate_interest(message: MessageRecv) -> 
Tuple[float, bool]: query_text=message.processed_plain_text, user_id=str(message.user_info.user_id), scope_id=message.chat_id, - limit=5 + limit=5, ) # 基于检索结果计算兴趣度 if enhanced_memories: # 有相关记忆,兴趣度基于相似度计算 - max_score = max(getattr(memory, 'relevance_score', 0.5) for memory in enhanced_memories) + max_score = max(getattr(memory, "relevance_score", 0.5) for memory in enhanced_memories) interested_rate = min(max_score, 1.0) # 限制在0-1之间 else: # 没有相关记忆,给予基础兴趣度 diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index a43d1186d..1c8782d23 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -4,6 +4,7 @@ from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat import time from src.chat.utils.utils import get_recent_group_speaker + # 旧的Hippocampus系统已被移除,现在使用增强记忆系统 # from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager import random @@ -181,14 +182,16 @@ class PromptBuilder: query_text=text, user_id="system", # 系统查询 scope_id="system", - limit=5 + limit=5, ) related_memory_info = "" if enhanced_memories: for memory_chunk in enhanced_memories: related_memory_info += memory_chunk.display or memory_chunk.text_content or "" - return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info.strip()) + return await global_prompt_manager.format_prompt( + "memory_prompt", memory_info=related_memory_info.strip() + ) return "" except Exception as e: diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index dc7f0f24b..66fcee96f 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -149,7 +149,7 @@ class ChatMood: self.mood_state = response self.last_change_time = message_time - + async def regress_mood(self): message_time = time.time() message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 4cfb75f94..478d4c9fb 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -1,4 +1,3 @@ -import asyncio import copy import datetime import hashlib @@ -145,7 +144,7 @@ class PersonInfoManager: except Exception as e: logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}") return "" - + @staticmethod async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str): """判断是否认识某人""" @@ -166,7 +165,7 @@ class PersonInfoManager: await person_info_manager.update_one_field( person_id=person_id, field_name="nickname", value=user_nickname, data=data ) - + @staticmethod async def create_person_info(person_id: str, data: Optional[dict] = None): """创建一个项""" @@ -491,7 +490,9 @@ class PersonInfoManager: async def _db_check_name_exists_async(name_to_check): async with get_db_session() as session: - result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)) + result = await session.execute( + select(PersonInfo).where(PersonInfo.person_name == name_to_check) + ) record = result.scalar() return record is not None @@ -552,7 +553,6 @@ class PersonInfoManager: else: logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行") - @staticmethod async def get_value(person_id: str, field_name: str) -> Any: """获取单个字段值(同步版本)""" @@ -623,6 +623,7 @@ class PersonInfoManager: result[field_name] = copy.deepcopy(person_info_default.get(field_name)) return result + 
@staticmethod async def get_specific_value_list( field_name: str, @@ -694,7 +695,7 @@ class PersonInfoManager: return record, False # 其他协程已创建,返回现有记录 # 如果仍然失败,重新抛出异常 raise e - + unique_nickname = await self._generate_unique_person_name(nickname) initial_data = { "person_id": person_id, diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index df0b25e77..4dc478f6c 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -303,12 +303,14 @@ class RelationshipBuilder: if not self.person_engaged_cache: return f"{self.log_prefix} 关系缓存为空" - status_lines = [f"{self.log_prefix} 关系缓存状态:", - f"最后处理消息时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_processed_message_time)) if self.last_processed_message_time > 0 else '未设置'}", - f"最后清理时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_cleanup_time)) if self.last_cleanup_time > 0 else '未执行'}", - f"总用户数:{len(self.person_engaged_cache)}", - f"清理配置:{'启用' if SEGMENT_CLEANUP_CONFIG['enable_cleanup'] else '禁用'} (最大保存{SEGMENT_CLEANUP_CONFIG['max_segment_age_days']}天, 每用户最多{SEGMENT_CLEANUP_CONFIG['max_segments_per_user']}段)", - ""] + status_lines = [ + f"{self.log_prefix} 关系缓存状态:", + f"最后处理消息时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_processed_message_time)) if self.last_processed_message_time > 0 else '未设置'}", + f"最后清理时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_cleanup_time)) if self.last_cleanup_time > 0 else '未执行'}", + f"总用户数:{len(self.person_engaged_cache)}", + f"清理配置:{'启用' if SEGMENT_CLEANUP_CONFIG['enable_cleanup'] else '禁用'} (最大保存{SEGMENT_CLEANUP_CONFIG['max_segment_age_days']}天, 每用户最多{SEGMENT_CLEANUP_CONFIG['max_segments_per_user']}段)", + "", + ] for person_id, segments in self.person_engaged_cache.items(): total_count = self._get_total_message_count(person_id) @@ -369,7 +371,7 @@ class RelationshipBuilder: for person_id, segments in self.person_engaged_cache.items(): total_message_count = self._get_total_message_count(person_id) person_name = get_person_info_manager().get_value(person_id, "person_name") or person_id - + if total_message_count >= max_build_threshold or ( total_message_count >= 5 and (immediate_build == person_id or immediate_build == "all") ): @@ -428,7 +430,9 @@ class RelationshipBuilder: start_date = time.strftime("%Y-%m-%d %H:%M", time.localtime(start_time)) # 获取该段的消息(包含边界) - segment_messages = await get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time) + segment_messages = await get_raw_msg_by_timestamp_with_chat_inclusive( + self.chat_id, start_time, end_time + ) logger.debug( f"消息段: {start_date} - {time.strftime('%Y-%m-%d %H:%M', time.localtime(end_time))}, 消息数: {len(segment_messages)}" ) diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 93269123b..90a353291 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -1,7 +1,6 @@ import time import traceback import orjson -import random from typing import List, Dict, Any from json_repair import repair_json diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py index 411a2d326..c80c5942c 100644 --- a/src/plugin_system/apis/__init__.py +++ b/src/plugin_system/apis/__init__.py @@ -19,7 +19,7 @@ from src.plugin_system.apis import ( send_api, tool_api, permission_api, - schedule_api + schedule_api, ) from src.plugin_system.apis.chat_api import ChatManager as context_api from 
.logging_api import get_logger diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index 612c243a3..baf6418dd 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -8,7 +8,7 @@ readable_text = message_api.build_readable_messages(messages) """ -from typing import List, Dict, Any, Tuple, Optional, Coroutine +from typing import List, Dict, Any, Tuple, Optional from src.config.config import global_config import time from src.chat.utils.chat_message_builder import ( @@ -181,9 +181,7 @@ async def get_messages_by_time_in_chat_for_users( raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - return await get_raw_msg_by_timestamp_with_chat_users( - chat_id, start_time, end_time, person_ids, limit, limit_mode - ) + return await get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode) async def get_random_chat_messages( @@ -384,9 +382,7 @@ async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Op return await num_new_messages_since(chat_id, start_time, end_time) -async def count_new_messages_for_users( - chat_id: str, start_time: float, end_time: float, person_ids: List[str] -) -> int: +async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int: """ 计算指定聊天中指定用户从开始时间到结束时间的新消息数量 diff --git a/src/plugin_system/apis/permission_api.py b/src/plugin_system/apis/permission_api.py index 97fde236c..61b4ca40f 100644 --- a/src/plugin_system/apis/permission_api.py +++ b/src/plugin_system/apis/permission_api.py @@ -61,8 +61,7 @@ class PermissionAPI: def __init__(self): self._permission_manager: Optional[IPermissionManager] = None # 需要保留的前缀(视为绝对节点名,不再自动加 plugins.. 前缀) - self.RESERVED_PREFIXES: tuple[str, ...] = ( - "system.") + self.RESERVED_PREFIXES: tuple[str, ...] = "system." 
# 系统节点列表 (name, description, default_granted) self._SYSTEM_NODES: list[tuple[str, str, bool]] = [ ("system.superuser", "系统超级管理员:拥有所有权限", False), diff --git a/src/plugin_system/apis/schedule_api.py b/src/plugin_system/apis/schedule_api.py index f1d81049e..e3e759968 100644 --- a/src/plugin_system/apis/schedule_api.py +++ b/src/plugin_system/apis/schedule_api.py @@ -28,6 +28,7 @@ asyncio.run(main()) """ + from datetime import datetime from typing import List, Dict, Any, Optional @@ -176,4 +177,4 @@ async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool: async def archive_monthly_plans(target_month: Optional[str] = None) -> bool: """(异步) 归档指定月份的月度计划的便捷函数""" - return await ScheduleAPI.archive_monthly_plans(target_month) \ No newline at end of file + return await ScheduleAPI.archive_monthly_plans(target_month) diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index ad6621b3a..c770db78b 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -80,7 +80,9 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa message_info = { "platform": message_dict.get("chat_info_platform", ""), - "message_id": message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id"), + "message_id": message_dict.get("message_id") + or message_dict.get("chat_info_message_id") + or message_dict.get("id"), "time": message_dict.get("time"), "group_info": group_info, "user_info": user_info, diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 26b79d4df..d3f012be5 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -2,7 +2,7 @@ import time import asyncio from abc import ABC, abstractmethod -from typing import Tuple, Optional, List, Dict, Any +from typing import Tuple, Optional, List, Dict from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream @@ -27,7 +27,7 @@ class BaseAction(ABC): - parallel_action: 是否允许并行执行 - random_activation_probability: 随机激活概率 - llm_judge_prompt: LLM判断提示词 - + 二步Action相关属性: - is_two_step_action: 是否为二步Action - step_one_description: 第一步的描述 @@ -434,7 +434,9 @@ class BaseAction(ABC): # 确保获取的是Action组件 if component_info.component_type != ComponentType.ACTION: - logger.error(f"{log_prefix} 尝试调用的组件 '{action_name}' 不是一个Action,而是一个 '{component_info.component_type.value}'") + logger.error( + f"{log_prefix} 尝试调用的组件 '{action_name}' 不是一个Action,而是一个 '{component_info.component_type.value}'" + ) return False, f"组件 '{action_name}' 不是一个有效的Action" plugin_config = component_registry.get_plugin_config(component_info.plugin_name) @@ -527,20 +529,20 @@ class BaseAction(ABC): # 第一步:展示可用的子Action available_actions = [sub_action[0] for sub_action in self.sub_actions] description = self.step_one_description or f"{self.action_name}支持以下操作" - + actions_list = "\n".join([f"- {action}: {desc}" for action, desc, _ in self.sub_actions]) response = f"{description}\n\n可用操作:\n{actions_list}\n\n请选择要执行的操作。" - + return True, response else: # 验证选择的子Action是否有效 valid_actions = [sub_action[0] for sub_action in self.sub_actions] if selected_action not in valid_actions: return False, f"无效的操作选择: {selected_action}。可用操作: {valid_actions}" - + # 保存选择的子Action self._selected_sub_action = selected_action - + # 调用第二步执行 return await self.execute_step_two(selected_action) @@ -571,7 +573,7 @@ class BaseAction(ABC): # 如果是二步Action,自动处理第一步 if self.is_two_step_action: return await 
self.handle_step_one() - + # 普通Action由子类实现 pass diff --git a/src/plugin_system/base/base_chatter.py b/src/plugin_system/base/base_chatter.py index 1bdb79c31..1dd225252 100644 --- a/src/plugin_system/base/base_chatter.py +++ b/src/plugin_system/base/base_chatter.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod -from typing import List, Optional, TYPE_CHECKING +from typing import List, TYPE_CHECKING from src.common.data_models.message_manager_data_model import StreamContext from .component_types import ChatType from src.plugin_system.base.component_types import ChatterInfo, ComponentType if TYPE_CHECKING: from src.chat.planner_actions.action_manager import ChatterActionManager - from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner + class BaseChatter(ABC): chatter_name: str = "" @@ -15,7 +15,7 @@ class BaseChatter(ABC): """Chatter组件的描述""" chat_types: List[ChatType] = [ChatType.PRIVATE, ChatType.GROUP] - def __init__(self, stream_id: str, action_manager: 'ChatterActionManager'): + def __init__(self, stream_id: str, action_manager: "ChatterActionManager"): """ 初始化聊天处理器 @@ -45,11 +45,10 @@ class BaseChatter(ABC): Returns: ChatterInfo对象 """ - + return ChatterInfo( name=cls.chatter_name, description=cls.chatter_description or "No description provided.", chat_type_allow=cls.chat_types[0], component_type=ComponentType.CHATTER, ) - diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py index 84dc8b150..229cadb63 100644 --- a/src/plugin_system/base/base_tool.py +++ b/src/plugin_system/base/base_tool.py @@ -64,7 +64,15 @@ class BaseTool(ABC): return { "name": cls.name, "description": cls.step_one_description or cls.description, - "parameters": [("action", ToolParamType.STRING, "选择要执行的操作", True, [sub_tool[0] for sub_tool in cls.sub_tools])] + "parameters": [ + ( + "action", + ToolParamType.STRING, + "选择要执行的操作", + True, + [sub_tool[0] for sub_tool in cls.sub_tools], + ) + ], } else: # 普通工具需要parameters @@ -88,12 +96,8 @@ class BaseTool(ABC): # 查找对应的子工具 for sub_name, sub_desc, sub_params in cls.sub_tools: if sub_name == sub_tool_name: - return { - "name": f"{cls.name}_{sub_tool_name}", - "description": sub_desc, - "parameters": sub_params - } - + return {"name": f"{cls.name}_{sub_tool_name}", "description": sub_desc, "parameters": sub_params} + raise ValueError(f"未找到子工具: {sub_tool_name}") @classmethod @@ -105,14 +109,10 @@ class BaseTool(ABC): """ if not cls.is_two_step_tool: return [] - + definitions = [] for sub_name, sub_desc, sub_params in cls.sub_tools: - definitions.append({ - "name": f"{cls.name}_{sub_name}", - "description": sub_desc, - "parameters": sub_params - }) + definitions.append({"name": f"{cls.name}_{sub_name}", "description": sub_desc, "parameters": sub_params}) return definitions @classmethod @@ -144,7 +144,7 @@ class BaseTool(ABC): # 如果是二步工具,处理第一步调用 if self.is_two_step_tool and "action" in function_args: return await self._handle_step_one(function_args) - + raise NotImplementedError("子类必须实现execute方法") async def _handle_step_one(self, function_args: dict[str, Any]) -> dict[str, Any]: @@ -174,17 +174,13 @@ class BaseTool(ABC): sub_name, sub_desc, sub_params = sub_tool_found # 返回第二步工具定义 - step_two_definition = { - "name": f"{self.name}_{sub_name}", - "description": sub_desc, - "parameters": sub_params - } + step_two_definition = {"name": f"{self.name}_{sub_name}", "description": sub_desc, "parameters": sub_params} return { "type": "two_step_tool_step_one", "content": f"已选择操作: {action}。请使用以下工具进行具体调用:", 
"next_tool_definition": step_two_definition, - "selected_action": action + "selected_action": action, } async def execute_step_two(self, sub_tool_name: str, function_args: dict[str, Any]) -> dict[str, Any]: diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 6d0590d43..2b1122b9f 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -40,7 +40,7 @@ class ActionActivationType(Enum): # 聊天模式枚举 class ChatMode(Enum): """聊天模式枚举""" - + FOCUS = "focus" # 专注模式 NORMAL = "normal" # Normal聊天模式 PROACTIVE = "proactive" # 主动思考模式 diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index 12797bafd..a61b8e04c 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -294,9 +294,7 @@ class PluginBase(ABC): changed = False # 内部递归函数 - def _sync_dicts( - schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = "" - ) -> Dict[str, Any]: + def _sync_dicts(schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, Any]: nonlocal changed synced_dict = schema_dict.copy() diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 63594c53e..9c82553f8 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -1,7 +1,7 @@ from pathlib import Path import re -from typing import TYPE_CHECKING, Dict, List, Optional, Any, Pattern, Tuple, Union, Type +from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type from src.common.logger import get_logger from src.plugin_system.base.component_types import ( @@ -34,44 +34,46 @@ class ComponentRegistry: def __init__(self): # 命名空间式组件名构成法 f"{component_type}.{component_name}" - self._components: Dict[str, 'ComponentInfo'] = {} + self._components: Dict[str, "ComponentInfo"] = {} """组件注册表 命名空间式组件名 -> 组件信息""" - self._components_by_type: Dict['ComponentType', Dict[str, 'ComponentInfo']] = {types: {} for types in ComponentType} + self._components_by_type: Dict["ComponentType", Dict[str, "ComponentInfo"]] = { + types: {} for types in ComponentType + } """类型 -> 组件原名称 -> 组件信息""" self._components_classes: Dict[ - str, Type[Union['BaseCommand', 'BaseAction', 'BaseTool', 'BaseEventHandler', 'PlusCommand', 'BaseChatter']] + str, Type[Union["BaseCommand", "BaseAction", "BaseTool", "BaseEventHandler", "PlusCommand", "BaseChatter"]] ] = {} """命名空间式组件名 -> 组件类""" # 插件注册表 - self._plugins: Dict[str, 'PluginInfo'] = {} + self._plugins: Dict[str, "PluginInfo"] = {} """插件名 -> 插件信息""" # Action特定注册表 - self._action_registry: Dict[str, Type['BaseAction']] = {} + self._action_registry: Dict[str, Type["BaseAction"]] = {} """Action注册表 action名 -> action类""" - self._default_actions: Dict[str, 'ActionInfo'] = {} + self._default_actions: Dict[str, "ActionInfo"] = {} """默认动作集,即启用的Action集,用于重置ActionManager状态""" # Command特定注册表 - self._command_registry: Dict[str, Type['BaseCommand']] = {} + self._command_registry: Dict[str, Type["BaseCommand"]] = {} """Command类注册表 command名 -> command类""" self._command_patterns: Dict[Pattern, str] = {} """编译后的正则 -> command名""" # 工具特定注册表 - self._tool_registry: Dict[str, Type['BaseTool']] = {} # 工具名 -> 工具类 - self._llm_available_tools: Dict[str, Type['BaseTool']] = {} # llm可用的工具名 -> 工具类 + self._tool_registry: Dict[str, Type["BaseTool"]] = {} # 工具名 -> 工具类 + self._llm_available_tools: Dict[str, Type["BaseTool"]] = {} # llm可用的工具名 -> 工具类 # 
EventHandler特定注册表 - self._event_handler_registry: Dict[str, Type['BaseEventHandler']] = {} + self._event_handler_registry: Dict[str, Type["BaseEventHandler"]] = {} """event_handler名 -> event_handler类""" - self._enabled_event_handlers: Dict[str, Type['BaseEventHandler']] = {} + self._enabled_event_handlers: Dict[str, Type["BaseEventHandler"]] = {} """启用的事件处理器 event_handler名 -> event_handler类""" - self._chatter_registry: Dict[str, Type['BaseChatter']] = {} + self._chatter_registry: Dict[str, Type["BaseChatter"]] = {} """chatter名 -> chatter类""" - self._enabled_chatter_registry: Dict[str, Type['BaseChatter']] = {} + self._enabled_chatter_registry: Dict[str, Type["BaseChatter"]] = {} """启用的chatter名 -> chatter类""" logger.info("组件注册中心初始化完成") @@ -99,7 +101,7 @@ class ComponentRegistry: def register_component( self, component_info: ComponentInfo, - component_class: Type[Union['BaseCommand', 'BaseAction', 'BaseEventHandler', 'BaseTool', 'BaseChatter']], + component_class: Type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]], ) -> bool: """注册组件 @@ -172,7 +174,7 @@ class ComponentRegistry: ) return True - def _register_action_component(self, action_info: 'ActionInfo', action_class: Type['BaseAction']) -> bool: + def _register_action_component(self, action_info: "ActionInfo", action_class: Type["BaseAction"]) -> bool: """注册Action组件到Action特定注册表""" if not (action_name := action_info.name): logger.error(f"Action组件 {action_class.__name__} 必须指定名称") @@ -192,7 +194,7 @@ class ComponentRegistry: return True - def _register_command_component(self, command_info: 'CommandInfo', command_class: Type['BaseCommand']) -> bool: + def _register_command_component(self, command_info: "CommandInfo", command_class: Type["BaseCommand"]) -> bool: """注册Command组件到Command特定注册表""" if not (command_name := command_info.name): logger.error(f"Command组件 {command_class.__name__} 必须指定名称") @@ -219,7 +221,7 @@ class ComponentRegistry: return True def _register_plus_command_component( - self, plus_command_info: 'PlusCommandInfo', plus_command_class: Type['PlusCommand'] + self, plus_command_info: "PlusCommandInfo", plus_command_class: Type["PlusCommand"] ) -> bool: """注册PlusCommand组件到特定注册表""" plus_command_name = plus_command_info.name @@ -233,7 +235,7 @@ class ComponentRegistry: # 创建专门的PlusCommand注册表(如果还没有) if not hasattr(self, "_plus_command_registry"): - self._plus_command_registry: Dict[str, Type['PlusCommand']] = {} + self._plus_command_registry: Dict[str, Type["PlusCommand"]] = {} plus_command_class.plugin_name = plus_command_info.plugin_name # 设置插件配置 @@ -243,7 +245,7 @@ class ComponentRegistry: logger.debug(f"已注册PlusCommand组件: {plus_command_name}") return True - def _register_tool_component(self, tool_info: 'ToolInfo', tool_class: Type['BaseTool']) -> bool: + def _register_tool_component(self, tool_info: "ToolInfo", tool_class: Type["BaseTool"]) -> bool: """注册Tool组件到Tool特定注册表""" tool_name = tool_info.name @@ -259,7 +261,7 @@ class ComponentRegistry: return True def _register_event_handler_component( - self, handler_info: 'EventHandlerInfo', handler_class: Type['BaseEventHandler'] + self, handler_info: "EventHandlerInfo", handler_class: Type["BaseEventHandler"] ) -> bool: if not (handler_name := handler_info.name): logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称") @@ -285,7 +287,7 @@ class ComponentRegistry: handler_class, self.get_plugin_config(handler_info.plugin_name) or {} ) - def _register_chatter_component(self, chatter_info: 'ChatterInfo', chatter_class: Type['BaseChatter']) -> 
bool: + def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: Type["BaseChatter"]) -> bool: """注册Chatter组件到Chatter特定注册表""" chatter_name = chatter_info.name @@ -312,7 +314,7 @@ class ComponentRegistry: # === 组件移除相关 === - async def remove_component(self, component_name: str, component_type: 'ComponentType', plugin_name: str) -> bool: + async def remove_component(self, component_name: str, component_type: "ComponentType", plugin_name: str) -> bool: target_component_class = self.get_component_class(component_name, component_type) if not target_component_class: logger.warning(f"组件 {component_name} 未注册,无法移除") @@ -362,7 +364,7 @@ class ComponentRegistry: case ComponentType.CHATTER: # 移除Chatter注册 - if hasattr(self, '_chatter_registry'): + if hasattr(self, "_chatter_registry"): self._chatter_registry.pop(component_name, None) logger.debug(f"已移除Chatter组件: {component_name}") @@ -484,8 +486,8 @@ class ComponentRegistry: # === 组件查询方法 === def get_component_info( - self, component_name: str, component_type: Optional['ComponentType'] = None - ) -> Optional['ComponentInfo']: + self, component_name: str, component_type: Optional["ComponentType"] = None + ) -> Optional["ComponentInfo"]: # sourcery skip: class-extract-method """获取组件信息,支持自动命名空间解析 @@ -529,8 +531,8 @@ class ComponentRegistry: def get_component_class( self, component_name: str, - component_type: Optional['ComponentType'] = None, - ) -> Optional[Union[Type['BaseCommand'], Type['BaseAction'], Type['BaseEventHandler'], Type['BaseTool']]]: + component_type: Optional["ComponentType"] = None, + ) -> Optional[Union[Type["BaseCommand"], Type["BaseAction"], Type["BaseEventHandler"], Type["BaseTool"]]]: """获取组件类,支持自动命名空间解析 Args: @@ -572,22 +574,22 @@ class ComponentRegistry: # 4. 都没找到 return None - def get_components_by_type(self, component_type: 'ComponentType') -> Dict[str, 'ComponentInfo']: + def get_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]: """获取指定类型的所有组件""" return self._components_by_type.get(component_type, {}).copy() - def get_enabled_components_by_type(self, component_type: 'ComponentType') -> Dict[str, 'ComponentInfo']: + def get_enabled_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]: """获取指定类型的所有启用组件""" components = self.get_components_by_type(component_type) return {name: info for name, info in components.items() if info.enabled} # === Action特定查询方法 === - def get_action_registry(self) -> Dict[str, Type['BaseAction']]: + def get_action_registry(self) -> Dict[str, Type["BaseAction"]]: """获取Action注册表""" return self._action_registry.copy() - def get_registered_action_info(self, action_name: str) -> Optional['ActionInfo']: + def get_registered_action_info(self, action_name: str) -> Optional["ActionInfo"]: """获取Action信息""" info = self.get_component_info(action_name, ComponentType.ACTION) return info if isinstance(info, ActionInfo) else None @@ -598,11 +600,11 @@ class ComponentRegistry: # === Command特定查询方法 === - def get_command_registry(self) -> Dict[str, Type['BaseCommand']]: + def get_command_registry(self) -> Dict[str, Type["BaseCommand"]]: """获取Command注册表""" return self._command_registry.copy() - def get_registered_command_info(self, command_name: str) -> Optional['CommandInfo']: + def get_registered_command_info(self, command_name: str) -> Optional["CommandInfo"]: """获取Command信息""" info = self.get_component_info(command_name, ComponentType.COMMAND) return info if isinstance(info, CommandInfo) else None @@ -611,7 +613,7 @@ class 
ComponentRegistry: """获取Command模式注册表""" return self._command_patterns.copy() - def find_command_by_text(self, text: str) -> Optional[Tuple[Type['BaseCommand'], dict, 'CommandInfo']]: + def find_command_by_text(self, text: str) -> Optional[Tuple[Type["BaseCommand"], dict, "CommandInfo"]]: # sourcery skip: use-named-expression, use-next """根据文本查找匹配的命令 @@ -638,15 +640,15 @@ class ComponentRegistry: return None # === Tool 特定查询方法 === - def get_tool_registry(self) -> Dict[str, Type['BaseTool']]: + def get_tool_registry(self) -> Dict[str, Type["BaseTool"]]: """获取Tool注册表""" return self._tool_registry.copy() - def get_llm_available_tools(self) -> Dict[str, Type['BaseTool']]: + def get_llm_available_tools(self) -> Dict[str, Type["BaseTool"]]: """获取LLM可用的Tool列表""" return self._llm_available_tools.copy() - def get_registered_tool_info(self, tool_name: str) -> Optional['ToolInfo']: + def get_registered_tool_info(self, tool_name: str) -> Optional["ToolInfo"]: """获取Tool信息 Args: @@ -659,13 +661,13 @@ class ComponentRegistry: return info if isinstance(info, ToolInfo) else None # === PlusCommand 特定查询方法 === - def get_plus_command_registry(self) -> Dict[str, Type['PlusCommand']]: + def get_plus_command_registry(self) -> Dict[str, Type["PlusCommand"]]: """获取PlusCommand注册表""" if not hasattr(self, "_plus_command_registry"): self._plus_command_registry: Dict[str, Type[PlusCommand]] = {} return self._plus_command_registry.copy() - def get_registered_plus_command_info(self, command_name: str) -> Optional['PlusCommandInfo']: + def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]: """获取PlusCommand信息 Args: @@ -679,44 +681,44 @@ class ComponentRegistry: # === EventHandler 特定查询方法 === - def get_event_handler_registry(self) -> Dict[str, Type['BaseEventHandler']]: + def get_event_handler_registry(self) -> Dict[str, Type["BaseEventHandler"]]: """获取事件处理器注册表""" return self._event_handler_registry.copy() - def get_registered_event_handler_info(self, handler_name: str) -> Optional['EventHandlerInfo']: + def get_registered_event_handler_info(self, handler_name: str) -> Optional["EventHandlerInfo"]: """获取事件处理器信息""" info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER) return info if isinstance(info, EventHandlerInfo) else None - def get_enabled_event_handlers(self) -> Dict[str, Type['BaseEventHandler']]: + def get_enabled_event_handlers(self) -> Dict[str, Type["BaseEventHandler"]]: """获取启用的事件处理器""" return self._enabled_event_handlers.copy() # === Chatter 特定查询方法 === - def get_chatter_registry(self) -> Dict[str, Type['BaseChatter']]: + def get_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]: """获取Chatter注册表""" - if not hasattr(self, '_chatter_registry'): + if not hasattr(self, "_chatter_registry"): self._chatter_registry: Dict[str, Type[BaseChatter]] = {} return self._chatter_registry.copy() - - def get_enabled_chatter_registry(self) -> Dict[str, Type['BaseChatter']]: + + def get_enabled_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]: """获取启用的Chatter注册表""" - if not hasattr(self, '_enabled_chatter_registry'): + if not hasattr(self, "_enabled_chatter_registry"): self._enabled_chatter_registry: Dict[str, Type[BaseChatter]] = {} return self._enabled_chatter_registry.copy() - - def get_registered_chatter_info(self, chatter_name: str) -> Optional['ChatterInfo']: + + def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]: """获取Chatter信息""" info = self.get_component_info(chatter_name, ComponentType.CHATTER) return info if 
isinstance(info, ChatterInfo) else None - + # === 插件查询方法 === - def get_plugin_info(self, plugin_name: str) -> Optional['PluginInfo']: + def get_plugin_info(self, plugin_name: str) -> Optional["PluginInfo"]: """获取插件信息""" return self._plugins.get(plugin_name) - def get_all_plugins(self) -> Dict[str, 'PluginInfo']: + def get_all_plugins(self) -> Dict[str, "PluginInfo"]: """获取所有插件""" return self._plugins.copy() @@ -724,7 +726,7 @@ class ComponentRegistry: # """获取所有启用的插件""" # return {name: info for name, info in self._plugins.items() if info.enabled} - def get_plugin_components(self, plugin_name: str) -> List['ComponentInfo']: + def get_plugin_components(self, plugin_name: str) -> List["ComponentInfo"]: """获取插件的所有组件""" plugin_info = self.get_plugin_info(plugin_name) return plugin_info.components if plugin_info else [] diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index c79012e13..0bb22afdf 100644 --- a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -95,17 +95,16 @@ class PermissionManager(IPermissionManager): # 检查用户是否有明确的权限设置 result = await session.execute( - select(UserPermissions) - .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) + select(UserPermissions).filter_by( + platform=user.platform, user_id=user.user_id, permission_node=permission_node + ) ) user_perm = result.scalar_one_or_none() if user_perm: # 有明确设置,返回设置的值 res = user_perm.granted - logger.debug( - f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {res}" - ) + logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {res}") return res else: # 没有明确设置,使用默认值 @@ -191,8 +190,9 @@ class PermissionManager(IPermissionManager): # 检查是否已有权限记录 result = await session.execute( - select(UserPermissions) - .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) + select(UserPermissions).filter_by( + platform=user.platform, user_id=user.user_id, permission_node=permission_node + ) ) existing_perm = result.scalar_one_or_none() @@ -244,8 +244,9 @@ class PermissionManager(IPermissionManager): # 检查是否已有权限记录 result = await session.execute( - select(UserPermissions) - .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) + select(UserPermissions).filter_by( + platform=user.platform, user_id=user.user_id, permission_node=permission_node + ) ) existing_perm = result.scalar_one_or_none() @@ -303,8 +304,9 @@ class PermissionManager(IPermissionManager): for node in all_nodes: # 检查用户是否有明确的权限设置 result = await session.execute( - select(UserPermissions) - .filter_by(platform=user.platform, user_id=user.user_id, permission_node=node.node_name) + select(UserPermissions).filter_by( + platform=user.platform, user_id=user.user_id, permission_node=node.node_name + ) ) user_perm = result.scalar_one_or_none() @@ -408,8 +410,7 @@ class PermissionManager(IPermissionManager): # 删除用户权限记录 result = await session.execute( - delete(UserPermissions) - .where(UserPermissions.permission_node.in_(node_names)) + delete(UserPermissions).where(UserPermissions.permission_node.in_(node_names)) ) deleted_user_perms = result.rowcount diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 9ec216a9b..2950101a9 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -1,7 +1,5 @@ import asyncio import os -import shutil -import hashlib import 
traceback import importlib @@ -106,7 +104,6 @@ class PluginManager: if not plugin_dir: return False, 1 - plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败) if not plugin_instance: logger.error(f"插件 {plugin_name} 实例化失败") @@ -545,9 +542,7 @@ class PluginManager: try: loop = asyncio.get_event_loop() if loop.is_running(): - fut = asyncio.run_coroutine_threadsafe( - component_registry.unregister_plugin(plugin_name), loop - ) + fut = asyncio.run_coroutine_threadsafe(component_registry.unregister_plugin(plugin_name), loop) fut.result(timeout=5) else: asyncio.run(component_registry.unregister_plugin(plugin_name)) diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index daa8244cf..e666e32d4 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -116,17 +116,17 @@ class ToolExecutor: def _get_tool_definitions(self) -> List[Dict[str, Any]]: all_tools = get_llm_available_tool_definitions() user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) - + # 获取基础工具定义(包括二步工具的第一步) tool_definitions = [definition for name, definition in all_tools if name not in user_disabled_tools] - + # 检查是否有待处理的二步工具第二步调用 - pending_step_two = getattr(self, '_pending_step_two_tools', {}) + pending_step_two = getattr(self, "_pending_step_two_tools", {}) if pending_step_two: # 添加第二步工具定义 for tool_name, step_two_def in pending_step_two.items(): tool_definitions.append(step_two_def) - + return tool_definitions async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]: @@ -266,7 +266,7 @@ class ToolExecutor: f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}" ) function_args["llm_called"] = True # 标记为LLM调用 - + # 检查是否是二步工具的第二步调用 if "_" in function_name and function_name.count("_") >= 1: # 可能是二步工具的第二步调用,格式为 "tool_name_sub_tool_name" @@ -274,14 +274,14 @@ class ToolExecutor: if len(parts) == 2: base_tool_name, sub_tool_name = parts base_tool_instance = get_tool_instance(base_tool_name) - + if base_tool_instance and base_tool_instance.is_two_step_tool: logger.info(f"{self.log_prefix}执行二步工具第二步: {base_tool_name}.{sub_tool_name}") result = await base_tool_instance.execute_step_two(sub_tool_name, function_args) - + # 清理待处理的第二步工具 self._pending_step_two_tools.pop(base_tool_name, None) - + if result: logger.debug(f"{self.log_prefix}二步工具第二步 {function_name} 执行成功") return { @@ -291,7 +291,7 @@ class ToolExecutor: "type": "function", "content": result.get("content", ""), } - + # 获取对应工具实例 tool_instance = tool_instance or get_tool_instance(function_name) if not tool_instance: @@ -301,7 +301,7 @@ class ToolExecutor: # 执行工具并记录日志 logger.debug(f"{self.log_prefix}执行工具 {function_name},参数: {function_args}") result = await tool_instance.execute(function_args) - + # 检查是否是二步工具的第一步结果 if result and result.get("type") == "two_step_tool_step_one": logger.info(f"{self.log_prefix}二步工具第一步完成: {function_name}") @@ -310,7 +310,7 @@ class ToolExecutor: if next_tool_def: self._pending_step_two_tools[function_name] = next_tool_def logger.debug(f"{self.log_prefix}已保存第二步工具定义: {next_tool_def['name']}") - + if result: logger.debug(f"{self.log_prefix}工具 {function_name} 执行成功,结果: {result}") return { diff --git a/src/plugin_system/utils/permission_decorators.py b/src/plugin_system/utils/permission_decorators.py index 990f1c91c..278ab2068 100644 --- a/src/plugin_system/utils/permission_decorators.py +++ b/src/plugin_system/utils/permission_decorators.py @@ 
-79,9 +79,11 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None) if not iscoroutinefunction(func): logger.warning(f"函数 {func.__name__} 使用 require_permission 但非异步,已强制阻止执行") + async def blocked(*_a, **_k): logger.error("同步函数不再支持权限装饰器,请改为 async def") return None + return blocked return async_wrapper @@ -146,9 +148,11 @@ def require_master(deny_message: Optional[str] = None): if not iscoroutinefunction(func): logger.warning(f"函数 {func.__name__} 使用 require_master 但非异步,已强制阻止执行") + async def blocked(*_a, **_k): logger.error("同步函数不再支持 require_master,请改为 async def") return None + return blocked return async_wrapper @@ -164,7 +168,9 @@ class PermissionChecker: @staticmethod def check_permission(chat_stream: ChatStream, permission_node: str) -> bool: - raise RuntimeError("PermissionChecker.check_permission 已移除同步支持,请直接 await permission_api.check_permission") + raise RuntimeError( + "PermissionChecker.check_permission 已移除同步支持,请直接 await permission_api.check_permission" + ) @staticmethod def is_master(chat_stream: ChatStream) -> bool: diff --git a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py index 8a331f99e..1bb60146b 100644 --- a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py +++ b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py @@ -12,7 +12,7 @@ from src.common.data_models.info_data_model import InterestScore from src.chat.interest_system import bot_interest_manager from src.common.logger import get_logger from src.config.config import global_config -from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker + logger = get_logger("chatter_interest_scoring") # 定义颜色 @@ -45,7 +45,7 @@ class ChatterInterestScoringSystem: self.probability_boost_per_no_reply = ( affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count ) # 每次不回复增加的概率 - + # 用户关系数据 self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py index cf14c221e..8d322c880 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py @@ -65,7 +65,7 @@ class ChatterPlanExecutor: if not plan.decided_actions: logger.info("没有需要执行的动作。") return {"executed_count": 0, "results": []} - + # 像hfc一样,提前打印将要执行的动作 action_types = [action.action_type for action in plan.decided_actions] logger.info(f"选择动作: {', '.join(action_types) if action_types else '无'}") @@ -150,17 +150,19 @@ class ChatterPlanExecutor: for i, action_info in enumerate(unique_actions): is_last_action = i == total_actions - 1 if total_actions > 1: - logger.info(f"[多重回复] 正在执行第 {i+1}/{total_actions} 个回复...") + logger.info(f"[多重回复] 正在执行第 {i + 1}/{total_actions} 个回复...") # 传递 clear_unread 参数 result = await self._execute_single_reply_action(action_info, plan, clear_unread=is_last_action) results.append(result) if total_actions > 1: - logger.info(f"[多重回复] 所有回复任务执行完毕。") + logger.info("[多重回复] 所有回复任务执行完毕。") return {"results": results} - async def _execute_single_reply_action(self, action_info: ActionPlannerInfo, plan: Plan, clear_unread: bool = True) -> Dict[str, any]: + async def _execute_single_reply_action( + self, action_info: ActionPlannerInfo, plan: Plan, clear_unread: bool = True + ) -> Dict[str, any]: """执行单个回复动作""" start_time = time.time() success = False @@ 
-201,7 +203,7 @@ class ChatterPlanExecutor: execution_result = await self.action_manager.execute_action( action_name=action_info.action_type, **action_params ) - + # 从返回结果中提取真正的回复文本 if isinstance(execution_result, dict): reply_content = execution_result.get("reply_text", "") @@ -233,7 +235,9 @@ class ChatterPlanExecutor: "error_message": error_message, "execution_time": execution_time, "reasoning": action_info.reasoning, - "reply_content": reply_content[:200] + "..." if reply_content and len(reply_content) > 200 else reply_content, + "reply_content": reply_content[:200] + "..." + if reply_content and len(reply_content) > 200 + else reply_content, } async def _execute_other_actions(self, other_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]: diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py index f84bdec46..1bc153fad 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py @@ -100,7 +100,7 @@ class ChatterPlanFilter: # 预解析 action_type 来进行判断 thinking = item.get("thinking", "未提供思考过程") actions_obj = item.get("actions", {}) - + # 处理actions字段可能是字典或列表的情况 if isinstance(actions_obj, dict): action_type = actions_obj.get("action_type", "no_action") @@ -116,14 +116,12 @@ class ChatterPlanFilter: if action_type in reply_action_types: if not reply_action_added: - final_actions.extend( - await self._parse_single_action(item, used_message_id_list, plan) - ) + final_actions.extend(await self._parse_single_action(item, used_message_id_list, plan)) reply_action_added = True else: # 非回复类动作直接添加 final_actions.extend(await self._parse_single_action(item, used_message_id_list, plan)) - + if thinking and thinking != "未提供思考过程": logger.info(f"\n{SAKURA_PINK}思考: {thinking}{RESET_COLOR}\n") plan.decided_actions = self._filter_no_actions(final_actions) @@ -154,6 +152,7 @@ class ChatterPlanFilter: schedule_block = "" # 优先检查是否被吵醒 from src.chat.message_manager.message_manager import message_manager + angry_prompt_addition = "" wakeup_mgr = message_manager.wakeup_manager @@ -161,7 +160,7 @@ class ChatterPlanFilter: # 检查1: 直接从 wakeup_manager 获取 if wakeup_mgr.is_in_angry_state(): angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition() - + # 检查2: 如果上面没获取到,再从 mood_manager 确认 if not angry_prompt_addition: chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id) @@ -274,7 +273,9 @@ class ChatterPlanFilter: is_group_chat = plan.chat_type == ChatType.GROUP chat_context_description = "你现在正在一个群聊中" if not is_group_chat and plan.target_info: - chat_target_name = plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方" + chat_target_name = ( + plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方" + ) chat_context_description = f"你正在和 {chat_target_name} 私聊" action_options_block = await self._build_action_options(plan.available_actions) @@ -315,12 +316,12 @@ class ChatterPlanFilter: """构建已读/未读历史消息块""" try: # 从message_manager获取真实的已读/未读消息 - from src.chat.message_manager.message_manager import message_manager from src.chat.utils.utils import assign_message_ids from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat # 获取聊天流的上下文 from src.plugin_system.apis.chat_api import get_chat_manager + chat_manager = get_chat_manager() chat_stream = chat_manager.get_stream(plan.chat_id) if not chat_stream: @@ -333,6 +334,7 @@ class ChatterPlanFilter: read_messages = 
stream_context.context.history_messages # 已读消息存储在history_messages中 if not read_messages: from src.common.data_models.database_data_model import DatabaseMessages + # 如果内存中没有已读消息(比如刚启动),则从数据库加载最近的上下文 fallback_messages_dicts = await get_raw_msg_before_timestamp_with_chat( chat_id=plan.chat_id, @@ -414,7 +416,7 @@ class ChatterPlanFilter: processed_plain_text=msg_dict.get("processed_plain_text", ""), key_words=msg_dict.get("key_words", "[]"), is_mentioned=msg_dict.get("is_mentioned", False), - **{"user_info": user_info_dict} # 通过kwargs传入user_info + **{"user_info": user_info_dict}, # 通过kwargs传入user_info ) else: # 如果没有user_info字段,使用平铺的字段(flatten()方法返回的格式) @@ -425,13 +427,12 @@ class ChatterPlanFilter: user_platform=msg_dict.get("user_platform", ""), processed_plain_text=msg_dict.get("processed_plain_text", ""), key_words=msg_dict.get("key_words", "[]"), - is_mentioned=msg_dict.get("is_mentioned", False) + is_mentioned=msg_dict.get("is_mentioned", False), ) # 计算消息兴趣度 interest_score_obj = await chatter_interest_scoring_system._calculate_single_message_score( - message=db_message, - bot_nickname=global_config.bot.nickname + message=db_message, bot_nickname=global_config.bot.nickname ) interest_score = interest_score_obj.total_score @@ -454,7 +455,7 @@ class ChatterPlanFilter: try: # 从新的actions结构中获取动作信息 actions_obj = action_json.get("actions", {}) - + # 处理actions字段可能是字典或列表的情况 actions_to_process = [] if isinstance(actions_obj, dict): @@ -463,19 +464,23 @@ class ChatterPlanFilter: actions_to_process.extend(actions_obj) if not actions_to_process: - actions_to_process.append({"action_type": "no_action", "reason": "actions格式错误"}) + actions_to_process.append({"action_type": "no_action", "reason": "actions格式错误"}) for single_action_obj in actions_to_process: if not isinstance(single_action_obj, dict): continue action = single_action_obj.get("action_type", "no_action") - reasoning = single_action_obj.get("reasoning", "未提供原因") # 兼容旧的reason字段 + reasoning = single_action_obj.get("reasoning", "未提供原因") # 兼容旧的reason字段 action_data = single_action_obj.get("action_data", {}) - + # 为了向后兼容,如果action_data不存在,则从顶层字段获取 if not action_data: - action_data = {k: v for k, v in single_action_obj.items() if k not in ["action_type", "reason", "reasoning", "thinking"]} + action_data = { + k: v + for k, v in single_action_obj.items() + if k not in ["action_type", "reason", "reasoning", "thinking"] + } # 保留原始的thinking字段(如果有) thinking = action_json.get("thinking", "") @@ -501,7 +506,9 @@ class ChatterPlanFilter: # reply动作必须有目标消息,使用最新消息作为兜底 target_message_dict = self._get_latest_message(message_id_list) if target_message_dict: - logger.info(f"[{action}] 使用最新消息作为目标: {target_message_dict.get('message_id')}") + logger.info( + f"[{action}] 使用最新消息作为目标: {target_message_dict.get('message_id')}" + ) else: logger.error(f"[{action}] 无法找到任何目标消息,降级为no_action") action = "no_action" @@ -509,15 +516,21 @@ class ChatterPlanFilter: elif action in ["poke_user", "set_emoji_like"]: # 这些动作可以尝试其他策略 - target_message_dict = self._find_poke_notice(message_id_list) or self._get_latest_message(message_id_list) + target_message_dict = self._find_poke_notice( + message_id_list + ) or self._get_latest_message(message_id_list) if target_message_dict: - logger.info(f"[{action}] 使用替代消息作为目标: {target_message_dict.get('message_id')}") + logger.info( + f"[{action}] 使用替代消息作为目标: {target_message_dict.get('message_id')}" + ) else: # 其他动作使用最新消息或跳过 target_message_dict = self._get_latest_message(message_id_list) if target_message_dict: - logger.info(f"[{action}] 使用最新消息作为目标: 
{target_message_dict.get('message_id')}") + logger.info( + f"[{action}] 使用最新消息作为目标: {target_message_dict.get('message_id')}" + ) else: # 如果LLM没有指定target_message_id,进行特殊处理 if action == "poke_user": @@ -614,7 +627,7 @@ class ChatterPlanFilter: query_text=query, user_id="system", # 系统查询 scope_id="system", - limit=5 + limit=5, ) if not enhanced_memories: @@ -627,7 +640,9 @@ class ChatterPlanFilter: memory_type = memory_chunk.memory_type.value if memory_chunk.memory_type else "unknown" retrieved_memories.append((memory_type, content)) - memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories] + memory_statements = [ + f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories + ] except Exception as e: logger.warning(f"增强记忆系统检索失败,使用默认回复: {e}") @@ -648,12 +663,17 @@ class ChatterPlanFilter: if action_name == "set_emoji_like" and p_name == "emoji": # 特殊处理set_emoji_like的emoji参数 from src.plugins.built_in.social_toolkit_plugin.qq_emoji_list import qq_face - emoji_options = [re.search(r"\[表情:(.+?)\]", name).group(1) for name in qq_face.values() if re.search(r"\[表情:(.+?)\]", name)] + + emoji_options = [ + re.search(r"\[表情:(.+?)\]", name).group(1) + for name in qq_face.values() + if re.search(r"\[表情:(.+?)\]", name) + ] example_value = f"<从'{', '.join(emoji_options[:10])}...'中选择一个>" else: example_value = f"<{p_desc}>" params_json_list.append(f' "{p_name}": "{example_value}"') - + # 基础动作信息 action_description = action_info.description action_require = "\n".join(f"- {req}" for req in action_info.action_require) @@ -666,11 +686,11 @@ class ChatterPlanFilter: # 将参数列表合并到JSON示例中 if params_json_list: # 移除最后一行的逗号 - json_example_lines.extend([line.rstrip(',') for line in params_json_list]) + json_example_lines.extend([line.rstrip(",") for line in params_json_list]) json_example_lines.append(' "reason": "<执行该动作的详细原因>"') json_example_lines.append(" }") - + # 使用逗号连接内部元素,除了最后一个 json_parts = [] for i, line in enumerate(json_example_lines): @@ -678,14 +698,14 @@ class ChatterPlanFilter: if line.strip() in ["{", "}"]: json_parts.append(line) continue - + # 检查是否是最后一个需要逗号的元素 is_last_item = True - for next_line in json_example_lines[i+1:]: + for next_line in json_example_lines[i + 1 :]: if next_line.strip() not in ["}"]: is_last_item = False break - + if not is_last_item: json_parts.append(f"{line},") else: @@ -713,7 +733,7 @@ class ChatterPlanFilter: # 1. 标准化处理:去除可能的格式干扰 original_id = str(message_id).strip() - normalized_id = original_id.strip('<>"\'').strip() + normalized_id = original_id.strip("<>\"'").strip() if not normalized_id: return None @@ -731,12 +751,13 @@ class ChatterPlanFilter: # 处理包含在文本中的ID格式 (如 "消息m123" -> 提取 m123) import re + # 尝试提取各种格式的ID id_patterns = [ - r'm\d+', # m123格式 - r'\d+', # 纯数字格式 - r'buffered-[a-f0-9-]+', # buffered-xxxx格式 - r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', # UUID格式 + r"m\d+", # m123格式 + r"\d+", # 纯数字格式 + r"buffered-[a-f0-9-]+", # buffered-xxxx格式 + r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", # UUID格式 ] for pattern in id_patterns: @@ -771,12 +792,12 @@ class ChatterPlanFilter: # 4. 
尝试模糊匹配(数字部分匹配) for candidate in candidate_ids: # 提取数字部分进行模糊匹配 - number_part = re.sub(r'[^0-9]', '', candidate) + number_part = re.sub(r"[^0-9]", "", candidate) if number_part: for item in message_id_list: if isinstance(item, dict): item_id = item.get("id", "") - item_number = re.sub(r'[^0-9]', '', item_id) + item_number = re.sub(r"[^0-9]", "", item_id) # 数字部分匹配 if item_number == number_part: @@ -787,7 +808,7 @@ class ChatterPlanFilter: message_obj = item.get("message") if isinstance(message_obj, dict): orig_mid = message_obj.get("message_id") or message_obj.get("id") - orig_number = re.sub(r'[^0-9]', '', str(orig_mid)) if orig_mid else "" + orig_number = re.sub(r"[^0-9]", "", str(orig_mid)) if orig_mid else "" if orig_number == number_part: logger.debug(f"模糊匹配成功(消息对象): {candidate} -> {orig_mid}") return message_obj diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py index 420487d4f..b6be512b9 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner.py @@ -4,7 +4,6 @@ """ from dataclasses import asdict -import time from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor @@ -90,15 +89,16 @@ class ChatterActionPlanner: try: # 在规划前,先进行动作修改 from src.chat.planner_actions.action_modifier import ActionModifier + action_modifier = ActionModifier(self.action_manager, self.chat_id) await action_modifier.modify_actions() - + # 1. 生成初始 Plan initial_plan = await self.generator.generate(context.chat_mode) # 确保Plan中包含所有当前可用的动作 initial_plan.available_actions = self.action_manager.get_using_actions() - + unread_messages = context.get_unread_messages() if context else [] # 2. 
使用新的兴趣度管理系统进行评分 score = 0.0 @@ -117,7 +117,9 @@ class ChatterActionPlanner: message_interest = interest_score.total_score message.interest_value = message_interest - message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold + message.should_reply = ( + message_interest > global_config.affinity_flow.non_reply_action_interest_threshold + ) interest_updates.append( { diff --git a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py index 66b3ca31f..2320670a0 100644 --- a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py +++ b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py @@ -596,7 +596,7 @@ class ChatterRelationshipTracker: quality = response_data.get("interaction_quality", "medium") # 更新数据库 - await self._update_user_relationship_in_db(user_id, new_text, new_score) + await self._update_user_relationship_in_db(user_id, new_text, new_score) # 更新缓存 self.user_relationship_cache[user_id] = { diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index f043e4f49..a477fdf0a 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -268,7 +268,7 @@ class EmojiAction(BaseAction): if not success: logger.error(f"{self.log_prefix} 表情包发送失败") await self.store_action_info( - action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包,但失败了", action_done=False + action_build_into_prompt=True, action_prompt_display="发送了一个表情包,但失败了", action_done=False ) return False, "表情包发送失败" @@ -279,7 +279,7 @@ class EmojiAction(BaseAction): logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}") await self.store_action_info( - action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包", action_done=True + action_build_into_prompt=True, action_prompt_display="发送了一个表情包", action_done=True ) return True, f"发送表情包: {emoji_description}" diff --git a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py index 91dea3105..194a2c5ef 100644 --- a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py +++ b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py @@ -47,9 +47,9 @@ class SearchKnowledgeFromLPMMTool(BaseTool): knowledge_parts = [] for i, item in enumerate(knowledge_info["knowledge_items"]): knowledge_parts.append(f"- {item.get('content', 'N/A')}") - + knowledge_text = "\n".join(knowledge_parts) - summary = knowledge_info.get('summary', '无总结') + summary = knowledge_info.get("summary", "无总结") content = f"关于 '{query}', 你知道以下信息:\n{knowledge_text}\n\n总结: {summary}" else: content = f"关于 '{query}',你的知识库里好像没有相关的信息呢" diff --git a/src/plugins/built_in/maizone_refactored/plugin.py b/src/plugins/built_in/maizone_refactored/plugin.py index 4ea543f68..e8259b5cb 100644 --- a/src/plugins/built_in/maizone_refactored/plugin.py +++ b/src/plugins/built_in/maizone_refactored/plugin.py @@ -87,6 +87,7 @@ class MaiZoneRefactoredPlugin(BasePlugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + async def on_plugin_loaded(self): """插件加载完成后的回调,初始化服务并启动后台任务""" # --- 注册权限节点 --- diff --git a/src/plugins/built_in/maizone_refactored/services/cookie_service.py b/src/plugins/built_in/maizone_refactored/services/cookie_service.py index e3692f883..9da05582c 100644 --- a/src/plugins/built_in/maizone_refactored/services/cookie_service.py +++ b/src/plugins/built_in/maizone_refactored/services/cookie_service.py @@ -140,5 
+140,7 @@ class CookieService: self._save_cookies_to_file(qq_account, cookies) return cookies - logger.error(f"为 {qq_account} 获取Cookie的所有方法均失败。请确保Napcat HTTP服务或Adapter连接至少有一个正常工作,或存在有效的本地Cookie文件。") + logger.error( + f"为 {qq_account} 获取Cookie的所有方法均失败。请确保Napcat HTTP服务或Adapter连接至少有一个正常工作,或存在有效的本地Cookie文件。" + ) return None diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index 186079965..c0e00b80d 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -290,9 +290,7 @@ class QZoneService: comment_content = comment.get("content", "") try: - reply_content = await self.content_service.generate_comment_reply( - content, comment_content, nickname - ) + reply_content = await self.content_service.generate_comment_reply(content, comment_content, nickname) if reply_content: success = await api_client["reply"](fid, qq_account, nickname, reply_content, comment_tid) if success: @@ -532,7 +530,9 @@ class QZoneService: async def _get_api_client(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict]: cookies = await self.cookie_service.get_cookies(qq_account, stream_id) if not cookies: - logger.error("获取API客户端失败:未能获取到Cookie。请检查Napcat连接是否正常,或是否存在有效的本地Cookie文件。") + logger.error( + "获取API客户端失败:未能获取到Cookie。请检查Napcat连接是否正常,或是否存在有效的本地Cookie文件。" + ) return None p_skey = cookies.get("p_skey") or cookies.get("p_skey".upper()) @@ -726,7 +726,8 @@ class QZoneService: return {"pic_bo": picbo, "richval": richval} except Exception as e: logger.error( - f"从上传结果中提取图片参数失败: {e}, 上传结果: {upload_result}", exc_info=True + f"从上传结果中提取图片参数失败: {e}, 上传结果: {upload_result}", + exc_info=True, ) return None else: @@ -764,7 +765,9 @@ class QZoneService: json_data = orjson.loads(res_text) if json_data.get("code") != 0: - logger.warning(f"获取说说列表API返回错误: code={json_data.get('code')}, message={json_data.get('message')}") + logger.warning( + f"获取说说列表API返回错误: code={json_data.get('code')}, message={json_data.get('message')}" + ) return [] feeds_list = [] @@ -797,7 +800,7 @@ class QZoneService: for c in commentlist: if not isinstance(c, dict): continue - + # 添加主评论 comments.append( { @@ -822,7 +825,7 @@ class QZoneService: "parent_tid": c.get("tid"), # 父ID是主评论的ID } ) - + feeds_list.append( { "tid": msg.get("tid", ""), diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py index 8725edeff..2c1c4d861 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -835,7 +835,7 @@ class MessageHandler: if music: tag = music.get("tag", "未知来源") logger.debug(f"检测到【{tag}】音乐分享消息 (music view),开始提取信息") - + title = music.get("title", "未知歌曲") desc = music.get("desc", "未知艺术家") jump_url = music.get("jumpUrl", "") @@ -853,7 +853,7 @@ class MessageHandler: artist = parts[1] else: artist = desc - + formatted_content = ( f"这是一张来自【{tag}】的音乐分享卡片:\n" f"歌曲: {song_title}\n" @@ -870,12 +870,12 @@ class MessageHandler: if news and "网易云音乐" in news.get("tag", ""): tag = news.get("tag") logger.debug(f"检测到【{tag}】音乐分享消息 (news view),开始提取信息") - + title = news.get("title", "未知歌曲") desc = news.get("desc", "未知艺术家") jump_url = news.get("jumpUrl", "") preview_url = news.get("preview", "") - + formatted_content = ( f"这是一张来自【{tag}】的音乐分享卡片:\n" f"标题: 
{title}\n" diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py index 411b957dc..11e4275f4 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py @@ -3,7 +3,6 @@ import time import random import websockets as Server import uuid -import asyncio from maim_message import ( UserInfo, GroupInfo, @@ -96,7 +95,9 @@ class SendHandler: logger.error("无法识别的消息类型") return logger.info("尝试发送到napcat") - logger.debug(f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'") + logger.debug( + f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'" + ) response = await self.send_message_to_napcat( action, { diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py index 1a8c3a95d..3ec4ca181 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py @@ -6,7 +6,6 @@ import urllib3 import ssl import io -import asyncio import time from asyncio import Lock @@ -75,7 +74,7 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d except Exception as e: logger.error(f"获取群信息失败: {e}") return None - + data = socket_response.get("data") if data: await set_to_cache(cache_key, data) @@ -114,7 +113,7 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use cached_data = await get_from_cache(cache_key) if cached_data: return cached_data - + logger.debug(f"获取群成员信息中 (无缓存): group={group_id}, user={user_id}") request_uuid = str(uuid.uuid4()) payload = json.dumps( @@ -133,7 +132,7 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use except Exception as e: logger.error(f"获取成员信息失败: {e}") return None - + data = socket_response.get("data") if data: await set_to_cache(cache_key, data) @@ -203,7 +202,7 @@ async def get_self_info(websocket: Server.ServerConnection) -> dict | None: except Exception as e: logger.error(f"获取自身信息失败: {e}") return None - + data = response.get("data") if data: await set_to_cache(cache_key, data) diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index 0f575e13f..5061cf496 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -260,7 +260,6 @@ class ManagementCommand(PlusCommand): except Exception as e: await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}") - async def _add_dir(self, dir_path: str): """添加插件目录""" await self.send_text(f"📁 正在添加插件目录: `{dir_path}`") @@ -501,13 +500,13 @@ class PluginManagementPlugin(BasePlugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # 注册权限节点 - + async def on_plugin_loaded(self): await permission_api.register_permission_node( - "plugin.management.admin", - "插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作", - "plugin_management", - False, + "plugin.management.admin", + "插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作", + "plugin_management", + False, ) def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]: diff --git a/src/plugins/built_in/proactive_thinker/plugin.py b/src/plugins/built_in/proactive_thinker/plugin.py index c04f82927..5e55e9101 100644 --- a/src/plugins/built_in/proactive_thinker/plugin.py +++ b/src/plugins/built_in/proactive_thinker/plugin.py @@ -1,27 
+1,23 @@ -from typing import List, Tuple, Union, Type, Optional +from typing import List, Tuple, Type from src.common.logger import get_logger -from src.config.official_configs import AffinityFlowConfig from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system import ( BasePlugin, ConfigField, register_plugin, - plugin_manage_api, - component_manage_api, - ComponentInfo, - ComponentType, EventHandlerInfo, - EventType, BaseEventHandler, ) from .proacive_thinker_event import ProactiveThinkerEventHandler logger = get_logger(__name__) + @register_plugin class ProactiveThinkerPlugin(BasePlugin): """一个主动思考的插件,但现在还只是个空壳子""" + plugin_name: str = "proactive_thinker" enable_plugin: bool = False dependencies: list[str] = [] @@ -33,6 +29,7 @@ class ProactiveThinkerPlugin(BasePlugin): "config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"), }, } + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -42,4 +39,3 @@ class ProactiveThinkerPlugin(BasePlugin): (ProactiveThinkerEventHandler.get_handler_info(), ProactiveThinkerEventHandler) ] return components - diff --git a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py index 55492c7f9..5ad560243 100644 --- a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py +++ b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py @@ -2,7 +2,7 @@ import asyncio import random import time from datetime import datetime -from typing import List, Union, Type, Optional +from typing import List, Union from maim_message import UserInfo @@ -69,7 +69,7 @@ class ColdStartTask(AsyncTask): # 创建 UserInfo 对象,这是创建聊天流的必要信息 user_info = UserInfo(platform=platform, user_id=str(user_id), user_nickname=user_nickname) - + # 【关键步骤】主动创建聊天流。 # 创建后,该用户就进入了机器人的“好友列表”,后续将由 ProactiveThinkingTask 接管 stream = await self.chat_manager.get_or_create_stream(platform, user_info) @@ -175,10 +175,12 @@ class ProactiveThinkingTask(AsyncTask): # 2. 【核心逻辑】检查聊天冷却时间是否足够长 time_since_last_active = time.time() - stream.last_active_time if time_since_last_active > next_interval: - logger.info(f"【日常唤醒】聊天流 {stream.stream_id} 已冷却 {time_since_last_active:.2f} 秒,触发主动对话。") - + logger.info( + f"【日常唤醒】聊天流 {stream.stream_id} 已冷却 {time_since_last_active:.2f} 秒,触发主动对话。" + ) + await self.executor.execute(stream_id=stream.stream_id, start_mode="wake_up") - + # 【关键步骤】在触发后,立刻更新活跃时间并保存。 # 这可以防止在同一个检查周期内,对同一个目标因为意外的延迟而发送多条消息。 stream.update_active_time() diff --git a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py index b5f280fee..ab3631450 100644 --- a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py +++ b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py @@ -3,7 +3,16 @@ from typing import Optional, Dict, Any from datetime import datetime from src.common.logger import get_logger -from src.plugin_system.apis import chat_api, person_api, schedule_api, send_api, llm_api, message_api, generator_api, database_api +from src.plugin_system.apis import ( + chat_api, + person_api, + schedule_api, + send_api, + llm_api, + message_api, + generator_api, + database_api, +) from src.config.config import global_config, model_config from src.person_info.person_info import get_person_info_manager @@ -39,17 +48,16 @@ class ProactiveThinkerExecutor: # 2. 
决策阶段 decision_result = await self._make_decision(context, start_mode) - if not decision_result or not decision_result.get("should_reply"): reason = decision_result.get("reason", "未提供") if decision_result else "决策过程返回None" logger.info(f"决策结果为:不回复。原因: {reason}") await database_api.store_action_info( - chat_stream=self._get_stream_from_id(stream_id), - action_name="proactive_decision", - action_prompt_display=f"主动思考决定不回复,原因: {reason}", - action_done = True, - action_data=decision_result - ) + chat_stream=self._get_stream_from_id(stream_id), + action_name="proactive_decision", + action_prompt_display=f"主动思考决定不回复,原因: {reason}", + action_done=True, + action_data=decision_result, + ) return # 3. 规划与执行阶段 @@ -59,15 +67,17 @@ class ProactiveThinkerExecutor: chat_stream=self._get_stream_from_id(stream_id), action_name="proactive_decision", action_prompt_display=f"主动思考决定回复,原因: {reason},话题:{topic}", - action_done = True, - action_data=decision_result + action_done=True, + action_data=decision_result, ) logger.info(f"决策结果为:回复。话题: {topic}") - + plan_prompt = self._build_plan_prompt(context, start_mode, topic, reason) - - is_success, response, _, _ = await llm_api.generate_with_model(prompt=plan_prompt, model_config=model_config.model_task_config.utils) - + + is_success, response, _, _ = await llm_api.generate_with_model( + prompt=plan_prompt, model_config=model_config.model_task_config.utils + ) + if is_success and response: stream = self._get_stream_from_id(stream_id) if stream: @@ -104,33 +114,41 @@ class ProactiveThinkerExecutor: if not user_info or not user_info.platform or not user_info.user_id: logger.warning(f"Stream {stream_id} 的 user_info 不完整") return None - + person_id = person_api.get_person_id(user_info.platform, int(user_info.user_id)) person_info_manager = get_person_info_manager() # 获取日程 schedules = await schedule_api.ScheduleAPI.get_today_schedule() - schedule_context = "\n".join([f"- {s['title']} ({s['start_time']}-{s['end_time']})" for s in schedules]) if schedules else "今天没有日程安排。" + schedule_context = ( + "\n".join([f"- {s['title']} ({s['start_time']}-{s['end_time']})" for s in schedules]) + if schedules + else "今天没有日程安排。" + ) # 获取关系信息 short_impression = await person_info_manager.get_value(person_id, "short_impression") or "无" impression = await person_info_manager.get_value(person_id, "impression") or "无" attitude = await person_info_manager.get_value(person_id, "attitude") or 50 - + # 获取最近聊天记录 recent_messages = await message_api.get_recent_messages(stream_id, limit=10) - recent_chat_history = await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else "无" - + recent_chat_history = ( + await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else "无" + ) + # 获取最近的动作历史 action_history = await database_api.db_query( database_api.MODEL_MAPPING["ActionRecords"], filters={"chat_id": stream_id, "action_name": "proactive_decision"}, limit=3, - order_by=["-time"] + order_by=["-time"], ) action_history_context = "无" if isinstance(action_history, list): - action_history_context = "\n".join([f"- {a['action_data']}" for a in action_history if isinstance(a, dict)]) or "无" + action_history_context = ( + "\n".join([f"- {a['action_data']}" for a in action_history if isinstance(a, dict)]) or "无" + ) return { "person_id": person_id, @@ -138,47 +156,43 @@ class ProactiveThinkerExecutor: "schedule_context": schedule_context, "recent_chat_history": recent_chat_history, "action_history_context": action_history_context, - "relationship": { - 
"short_impression": short_impression, - "impression": impression, - "attitude": attitude - }, + "relationship": {"short_impression": short_impression, "impression": impression, "attitude": attitude}, "persona": { "core": global_config.personality.personality_core, "side": global_config.personality.personality_side, "identity": global_config.personality.identity, }, - "current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), } async def _make_decision(self, context: Dict[str, Any], start_mode: str) -> Optional[Dict[str, Any]]: """ 决策模块:判断是否应该主动发起对话,以及聊什么话题 """ - persona = context['persona'] - user_info = context['user_info'] - relationship = context['relationship'] + persona = context["persona"] + user_info = context["user_info"] + relationship = context["relationship"] prompt = f""" # 角色 你的名字是{global_config.bot.nickname},你的人设如下: -- 核心人设: {persona['core']} -- 侧面人设: {persona['side']} -- 身份: {persona['identity']} +- 核心人设: {persona["core"]} +- 侧面人设: {persona["side"]} +- 身份: {persona["identity"]} # 任务 -现在是 {context['current_time']},你需要根据当前的情境,决定是否要主动向用户 '{user_info.user_nickname}' 发起对话。 +现在是 {context["current_time"]},你需要根据当前的情境,决定是否要主动向用户 '{user_info.user_nickname}' 发起对话。 # 情境分析 -1. **启动模式**: {start_mode} ({'初次见面/很久未见' if start_mode == 'cold_start' else '日常唤醒'}) +1. **启动模式**: {start_mode} ({"初次见面/很久未见" if start_mode == "cold_start" else "日常唤醒"}) 2. **你的日程**: -{context['schedule_context']} +{context["schedule_context"]} 3. **你和Ta的关系**: - - 简短印象: {relationship['short_impression']} - - 详细印象: {relationship['impression']} - - 好感度: {relationship['attitude']}/100 + - 简短印象: {relationship["short_impression"]} + - 详细印象: {relationship["impression"]} + - 好感度: {relationship["attitude"]}/100 4. **最近的聊天摘要**: -{context['recent_chat_history']} +{context["recent_chat_history"]} # 决策指令 请综合以上所有信息,做出决策。你的决策需要以JSON格式输出,包含以下字段: @@ -204,9 +218,11 @@ class ProactiveThinkerExecutor: 请输出你的决策: """ - - is_success, response, _, _ = await llm_api.generate_with_model(prompt=prompt, model_config=model_config.model_task_config.utils) - + + is_success, response, _, _ = await llm_api.generate_with_model( + prompt=prompt, model_config=model_config.model_task_config.utils + ) + if not is_success: return {"should_reply": False, "reason": "决策模型生成失败"} @@ -222,17 +238,17 @@ class ProactiveThinkerExecutor: """ 根据启动模式和决策话题,构建最终的规划提示词 """ - persona = context['persona'] - user_info = context['user_info'] - relationship = context['relationship'] + persona = context["persona"] + user_info = context["user_info"] + relationship = context["relationship"] if start_mode == "cold_start": prompt = f""" # 角色 你的名字是{global_config.bot.nickname},你的人设如下: -- 核心人设: {persona['core']} -- 侧面人设: {persona['side']} -- 身份: {persona['identity']} +- 核心人设: {persona["core"]} +- 侧面人设: {persona["side"]} +- 身份: {persona["identity"]} # 任务 你需要主动向一个新朋友 '{user_info.user_nickname}' 发起对话。这是你们的第一次交流,或者很久没聊了。 @@ -240,9 +256,9 @@ class ProactiveThinkerExecutor: # 决策上下文 - **决策理由**: {reason} - **你和Ta的关系**: - - 简短印象: {relationship['short_impression']} - - 详细印象: {relationship['impression']} - - 好感度: {relationship['attitude']}/100 + - 简短印象: {relationship["short_impression"]} + - 详细印象: {relationship["impression"]} + - 好感度: {relationship["attitude"]}/100 # 对话指引 - 你的目标是“破冰”,让对话自然地开始。 @@ -254,26 +270,26 @@ class ProactiveThinkerExecutor: prompt = f""" # 角色 你的名字是{global_config.bot.nickname},你的人设如下: -- 核心人设: {persona['core']} -- 侧面人设: {persona['side']} -- 身份: {persona['identity']} +- 核心人设: {persona["core"]} +- 侧面人设: 
{persona["side"]} +- 身份: {persona["identity"]} # 任务 -现在是 {context['current_time']},你需要主动向你的朋友 '{user_info.user_nickname}' 发起对话。 +现在是 {context["current_time"]},你需要主动向你的朋友 '{user_info.user_nickname}' 发起对话。 # 决策上下文 - **决策理由**: {reason} # 情境分析 1. **你的日程**: -{context['schedule_context']} +{context["schedule_context"]} 2. **你和Ta的关系**: - - 详细印象: {relationship['impression']} - - 好感度: {relationship['attitude']}/100 + - 详细印象: {relationship["impression"]} + - 好感度: {relationship["attitude"]}/100 3. **最近的聊天摘要**: -{context['recent_chat_history']} +{context["recent_chat_history"]} 4. **你最近的相关动作**: -{context['action_history_context']} +{context["action_history_context"]} # 对话指引 - 你决定和Ta聊聊关于“{topic}”的话题。 diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index 1dcd5c08e..a26879da7 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -16,8 +16,6 @@ from src.person_info.person_info import get_person_info_manager from dateutil.parser import parse as parse_datetime from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api, llm_api, generator_api -from src.plugin_system.base.component_types import ComponentType -from typing import Optional from src.chat.message_receive.chat_stream import ChatStream import asyncio import datetime diff --git a/src/schedule/database.py b/src/schedule/database.py index 9117b9586..b33bfb953 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -95,9 +95,7 @@ async def mark_plans_completed(plan_ids: List[int]): plan_details = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans_to_mark)]) logger.info(f"以下 {len(plans_to_mark)} 条月度计划将被标记为已完成:\n{plan_details}") - await session.execute( - update(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)).values(status="completed") - ) + await session.execute(update(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)).values(status="completed")) await session.commit() except Exception as e: logger.error(f"标记月度计划为完成时发生错误: {e}") @@ -184,9 +182,7 @@ async def update_plan_usage(plan_ids: List[int], used_date: str): raise -async def get_smart_plans_for_daily_schedule( - month: str, max_count: int = 3, avoid_days: int = 7 -) -> List[MonthlyPlan]: +async def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_days: int = 7) -> List[MonthlyPlan]: """ 智能抽取月度计划用于每日日程生成。 @@ -208,14 +204,10 @@ async def get_smart_plans_for_daily_schedule( avoid_date = (datetime.now() - timedelta(days=avoid_days)).strftime("%Y-%m-%d") # 查询符合条件的计划 - query = select(MonthlyPlan).where( - MonthlyPlan.target_month == month, MonthlyPlan.status == "active" - ) + query = select(MonthlyPlan).where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") # 排除最近使用过的计划 - query = query.where( - (MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date) - ) + query = query.where((MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date)) # 按使用次数升序排列,优先选择使用次数少的 result = await session.execute(query.order_by(MonthlyPlan.usage_count.asc())) @@ -274,9 +266,7 @@ async def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]: async with get_db_session() as session: try: result = await session.execute( - select(MonthlyPlan).where( - MonthlyPlan.target_month == month, MonthlyPlan.status == "archived" - ) + select(MonthlyPlan).where(MonthlyPlan.target_month == month, 
MonthlyPlan.status == "archived") ) return result.scalars().all() except Exception as e: From 7923eafef389a7047755989325fb64b4ac4ad604 Mon Sep 17 00:00:00 2001 From: John Richard Date: Thu, 2 Oct 2025 20:26:01 +0800 Subject: [PATCH 2/3] =?UTF-8?q?re-style:=20=E6=A0=BC=E5=BC=8F=E5=8C=96?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __main__.py | 2 +- bot.py | 24 ++-- plugins/bilibli/__init__.py | 1 - plugins/bilibli/bilibli_base.py | 27 ++-- plugins/bilibli/plugin.py | 21 +-- plugins/echo_example/plugin.py | 25 ++-- plugins/hello_world_plugin/plugin.py | 32 ++--- pyproject.toml | 49 ++++--- scripts/expression_stats.py | 13 +- scripts/interest_value_analysis.py | 15 +-- scripts/log_viewer_optimized.py | 27 ++-- scripts/lpmm_learning_tool.py | 39 +++--- scripts/manifest_tool.py | 8 +- scripts/mongodb_to_sqlite.py | 82 +++++------ scripts/rebuild_metadata_index.py | 5 +- scripts/run_multi_stage_smoke.py | 3 +- scripts/text_length_analysis.py | 17 ++- scripts/update_prompt_imports.py | 2 +- src/__init__.py | 8 +- src/chat/__init__.py | 2 +- src/chat/antipromptinjector/__init__.py | 27 ++-- src/chat/antipromptinjector/anti_injector.py | 30 +++-- src/chat/antipromptinjector/core/__init__.py | 3 +- src/chat/antipromptinjector/core/detector.py | 21 ++- src/chat/antipromptinjector/core/shield.py | 7 +- src/chat/antipromptinjector/counter_attack.py | 6 +- .../antipromptinjector/decision/__init__.py | 5 +- .../decision/counter_attack.py | 6 +- .../decision/decision_maker.py | 2 +- src/chat/antipromptinjector/decision_maker.py | 2 +- src/chat/antipromptinjector/detector.py | 21 ++- .../antipromptinjector/management/__init__.py | 1 - .../management/statistics.py | 7 +- .../antipromptinjector/management/user_ban.py | 7 +- .../antipromptinjector/processors/__init__.py | 1 - .../processors/message_processor.py | 6 +- src/chat/antipromptinjector/types.py | 6 +- src/chat/chatter_manager.py | 23 ++-- src/chat/emoji_system/emoji_history.py | 6 +- src/chat/emoji_system/emoji_manager.py | 91 ++++++------- src/chat/energy_system/__init__.py | 20 +-- src/chat/energy_system/energy_manager.py | 36 ++--- src/chat/express/expression_learner.py | 41 +++--- src/chat/express/expression_selector.py | 32 ++--- src/chat/frequency_analyzer/analyzer.py | 11 +- src/chat/frequency_analyzer/tracker.py | 10 +- src/chat/frequency_analyzer/trigger.py | 7 +- src/chat/interest_system/__init__.py | 5 +- .../interest_system/bot_interest_manager.py | 49 +++---- src/chat/knowledge/embedding_store.py | 66 +++++---- src/chat/knowledge/ie_process.py | 19 +-- src/chat/knowledge/kg_manager.py | 42 +++--- src/chat/knowledge/knowledge_lib.py | 11 +- src/chat/knowledge/open_ie.py | 17 +-- src/chat/knowledge/qa_manager.py | 15 ++- src/chat/knowledge/utils/dyn_topk.py | 6 +- src/chat/memory_system/__init__.py | 28 ++-- .../enhanced_memory_adapter.py | 44 +++--- .../enhanced_memory_hooks.py | 12 +- .../enhanced_memory_integration.py | 22 +-- .../deprecated_backup/enhanced_reranker.py | 27 ++-- .../deprecated_backup/integration_layer.py | 26 ++-- .../memory_integration_hooks.py | 28 ++-- .../deprecated_backup/metadata_index.py | 86 ++++++------ .../multi_stage_retrieval.py | 113 ++++++++-------- .../deprecated_backup/vector_storage.py | 77 +++++------ .../enhanced_memory_activator.py | 21 ++- .../memory_system/memory_activator_new.py | 21 ++- src/chat/memory_system/memory_builder.py | 73 +++++----- src/chat/memory_system/memory_chunk.py | 77 +++++------ 
.../memory_system/memory_forgetting_engine.py | 22 ++- src/chat/memory_system/memory_fusion.py | 38 +++--- src/chat/memory_system/memory_manager.py | 38 +++--- .../memory_system/memory_metadata_index.py | 98 +++++++------- .../memory_system/memory_query_planner.py | 33 +++-- src/chat/memory_system/memory_system.py | 106 +++++++-------- .../memory_system/vector_memory_storage_v2.py | 66 ++++----- src/chat/message_manager/__init__.py | 4 +- src/chat/message_manager/context_manager.py | 23 ++-- .../message_manager/distribution_manager.py | 24 ++-- src/chat/message_manager/message_manager.py | 23 ++-- .../sleep_manager/sleep_manager.py | 9 +- .../sleep_manager/sleep_state.py | 9 +- .../sleep_manager/time_checker.py | 12 +- .../sleep_manager/wakeup_manager.py | 13 +- src/chat/message_receive/__init__.py | 5 +- src/chat/message_receive/bot.py | 39 +++--- src/chat/message_receive/chat_stream.py | 54 ++++---- src/chat/message_receive/message.py | 37 +++-- src/chat/message_receive/storage.py | 18 +-- .../message_receive/uni_message_sender.py | 9 +- src/chat/planner_actions/action_manager.py | 38 +++--- src/chat/planner_actions/action_modifier.py | 37 +++-- src/chat/replyer/default_generator.py | 84 ++++++------ src/chat/replyer/replyer_manager.py | 12 +- src/chat/utils/chat_message_builder.py | 81 ++++++----- src/chat/utils/memory_mappings.py | 1 - src/chat/utils/prompt.py | 91 ++++++------- src/chat/utils/statistic.py | 40 +++--- src/chat/utils/timer_calculator.py | 14 +- src/chat/utils/typo_generator.py | 12 +- src/chat/utils/utils.py | 28 ++-- src/chat/utils/utils_image.py | 40 +++--- src/chat/utils/utils_video.py | 55 ++++---- src/chat/utils/utils_video_legacy.py | 34 ++--- src/chat/utils/utils_voice.py | 8 +- src/common/cache_manager.py | 42 +++--- src/common/config_helpers.py | 8 +- .../data_models/bot_interest_data_model.py | 28 ++-- src/common/data_models/database_data_model.py | 52 +++---- src/common/data_models/info_data_model.py | 28 ++-- src/common/data_models/llm_data_model.py | 16 +-- .../data_models/message_manager_data_model.py | 27 ++-- src/common/database/database.py | 5 +- .../database/sqlalchemy_database_api.py | 70 +++++----- src/common/database/sqlalchemy_init.py | 6 +- src/common/database/sqlalchemy_models.py | 15 ++- src/common/logger.py | 14 +- src/common/message/__init__.py | 1 - src/common/message/api.py | 10 +- src/common/message_repository.py | 16 +-- src/common/remote.py | 6 +- src/common/server.py | 14 +- src/common/tcp_connector.py | 3 +- src/common/vector_db/__init__.py | 2 +- src/common/vector_db/base.py | 34 ++--- src/common/vector_db/chromadb_impl.py | 41 +++--- src/config/api_ada_configs.py | 17 +-- src/config/config.py | 92 +++++++------ src/config/config_base.py | 10 +- src/config/official_configs.py | 45 ++++--- src/individuality/individuality.py | 19 +-- src/individuality/not_using/offline_llm.py | 20 +-- src/individuality/not_using/per_bf_gen.py | 30 ++--- src/individuality/not_using/scene.py | 5 +- src/llm_models/exceptions.py | 1 - .../model_client/aiohttp_gemini_client.py | 40 +++--- src/llm_models/model_client/base_client.py | 15 ++- src/llm_models/model_client/openai_client.py | 53 ++++---- src/llm_models/payload_content/message.py | 1 - src/llm_models/payload_content/resp_format.py | 13 +- src/llm_models/utils.py | 15 ++- src/llm_models/utils_model.py | 83 ++++++------ src/main.py | 56 ++++---- src/mais4u/mai_think.py | 13 +- .../body_emotion_action_manager.py | 12 +- src/mais4u/mais4u_chat/context_web_manager.py | 12 +- 
src/mais4u/mais4u_chat/gift_manager.py | 12 +- src/mais4u/mais4u_chat/internal_manager.py | 2 +- src/mais4u/mais4u_chat/s4u_chat.py | 44 +++--- src/mais4u/mais4u_chat/s4u_mood_manager.py | 11 +- src/mais4u/mais4u_chat/s4u_msg_processor.py | 14 +- src/mais4u/mais4u_chat/s4u_prompt.py | 34 ++--- .../mais4u_chat/s4u_stream_generator.py | 14 +- src/mais4u/mais4u_chat/screen_manager.py | 2 +- src/mais4u/mais4u_chat/super_chat_manager.py | 14 +- src/mais4u/mais4u_chat/yes_or_no.py | 2 +- src/mais4u/openai_client.py | 29 ++-- src/mais4u/s4u_config.py | 21 +-- src/manager/async_task_manager.py | 11 +- src/manager/local_store_manager.py | 7 +- src/mood/mood_manager.py | 11 +- src/person_info/person_info.py | 19 +-- src/person_info/relationship_builder.py | 21 +-- .../relationship_builder_manager.py | 11 +- src/person_info/relationship_fetcher.py | 17 ++- src/person_info/relationship_manager.py | 23 ++-- src/plugin_system/__init__.py | 70 +++++----- src/plugin_system/apis/__init__.py | 15 ++- src/plugin_system/apis/chat_api.py | 38 +++--- .../apis/component_manage_api.py | 23 ++-- src/plugin_system/apis/config_api.py | 1 + src/plugin_system/apis/cross_context_api.py | 20 +-- src/plugin_system/apis/database_api.py | 4 +- src/plugin_system/apis/emoji_api.py | 13 +- src/plugin_system/apis/generator_api.py | 44 +++--- src/plugin_system/apis/llm_api.py | 29 ++-- src/plugin_system/apis/message_api.py | 66 ++++----- src/plugin_system/apis/permission_api.py | 18 +-- src/plugin_system/apis/person_api.py | 7 +- src/plugin_system/apis/plugin_manage_api.py | 11 +- src/plugin_system/apis/plugin_register_api.py | 2 +- src/plugin_system/apis/schedule_api.py | 22 +-- src/plugin_system/apis/send_api.py | 37 +++-- src/plugin_system/apis/tool_api.py | 8 +- src/plugin_system/base/__init__.py | 20 +-- src/plugin_system/base/base_action.py | 39 +++--- src/plugin_system/base/base_chatter.py | 8 +- src/plugin_system/base/base_command.py | 16 +-- src/plugin_system/base/base_event.py | 18 +-- src/plugin_system/base/base_events_handler.py | 8 +- src/plugin_system/base/base_plugin.py | 20 ++- src/plugin_system/base/base_tool.py | 13 +- src/plugin_system/base/command_args.py | 5 +- src/plugin_system/base/component_types.py | 63 ++++----- src/plugin_system/base/config_types.py | 6 +- src/plugin_system/base/plugin_base.py | 51 +++---- src/plugin_system/base/plus_command.py | 31 +++-- src/plugin_system/core/__init__.py | 4 +- src/plugin_system/core/component_registry.py | 127 +++++++++--------- src/plugin_system/core/event_manager.py | 46 +++---- .../core/global_announcement_manager.py | 18 ++- src/plugin_system/core/permission_manager.py | 22 +-- src/plugin_system/core/plugin_manager.py | 46 +++---- src/plugin_system/core/tool_use.py | 43 +++--- src/plugin_system/utils/dependency_alias.py | 1 - src/plugin_system/utils/dependency_config.py | 3 +- src/plugin_system/utils/dependency_manager.py | 25 ++-- src/plugin_system/utils/manifest_utils.py | 17 +-- .../utils/permission_decorators.py | 16 +-- .../affinity_flow_chatter/affinity_chatter.py | 20 +-- .../affinity_flow_chatter/interest_scoring.py | 22 +-- .../affinity_flow_chatter/plan_executor.py | 19 ++- .../affinity_flow_chatter/plan_filter.py | 27 ++-- .../affinity_flow_chatter/plan_generator.py | 5 +- .../built_in/affinity_flow_chatter/planner.py | 34 +++-- .../built_in/affinity_flow_chatter/plugin.py | 6 +- .../relationship_tracker.py | 44 +++--- .../core_actions/anti_injector_manager.py | 6 +- src/plugins/built_in/core_actions/emoji.py | 19 ++- 
src/plugins/built_in/core_actions/plugin.py | 15 +-- .../built_in/knowledge/lpmm_get_knowledge.py | 10 +- .../built_in/maizone_refactored/__init__.py | 5 +- .../actions/read_feed_action.py | 8 +- .../actions/send_feed_action.py | 8 +- .../commands/send_feed_command.py | 10 +- .../built_in/maizone_refactored/plugin.py | 20 ++- .../services/content_service.py | 28 ++-- .../services/cookie_service.py | 22 +-- .../services/image_service.py | 3 +- .../maizone_refactored/services/manager.py | 7 +- .../services/monitor_service.py | 4 +- .../services/qzone_service.py | 53 ++++---- .../services/reply_tracker_service.py | 16 +-- .../services/scheduler_service.py | 10 +- .../maizone_refactored/utils/history_utils.py | 14 +- .../built_in/permission_management/plugin.py | 28 ++-- .../built_in/plugin_management/plugin.py | 31 +++-- .../built_in/proactive_thinker/plugin.py | 13 +- .../proacive_thinker_event.py | 8 +- .../proactive_thinker_executor.py | 23 ++-- .../built_in/social_toolkit_plugin/plugin.py | 59 ++++---- src/plugins/built_in/tts_plugin/plugin.py | 9 +- .../built_in/web_search_tool/engines/base.py | 4 +- .../web_search_tool/engines/bing_engine.py | 18 +-- .../web_search_tool/engines/ddg_engine.py | 6 +- .../web_search_tool/engines/exa_engine.py | 8 +- .../web_search_tool/engines/tavily_engine.py | 8 +- .../built_in/web_search_tool/plugin.py | 18 ++- .../web_search_tool/tools/url_parser.py | 15 ++- .../web_search_tool/tools/web_search.py | 26 ++-- .../web_search_tool/utils/api_key_manager.py | 14 +- .../web_search_tool/utils/formatters.py | 8 +- .../web_search_tool/utils/url_utils.py | 5 +- src/schedule/database.py | 19 +-- src/schedule/llm_generator.py | 16 ++- src/schedule/monthly_plan_manager.py | 4 +- src/schedule/plan_manager.py | 16 +-- src/schedule/schedule_manager.py | 17 +-- src/schedule/schemas.py | 6 +- src/utils/message_chunker.py | 16 ++- src/utils/timing_utils.py | 12 +- ui_log_adapter.py | 4 +- 263 files changed, 3103 insertions(+), 3123 deletions(-) diff --git a/__main__.py b/__main__.py index 15bf83a4e..f6d2a3178 100644 --- a/__main__.py +++ b/__main__.py @@ -12,7 +12,7 @@ if __name__ == "__main__": # 执行bot.py的代码 bot_file = current_dir / "bot.py" - with open(bot_file, "r", encoding="utf-8") as f: + with open(bot_file, encoding="utf-8") as f: exec(f.read()) diff --git a/bot.py b/bot.py index 798247c96..985065b99 100644 --- a/bot.py +++ b/bot.py @@ -1,30 +1,30 @@ # import asyncio import asyncio import os +import platform import sys import time -import platform import traceback from pathlib import Path -from rich.traceback import install -from colorama import init, Fore + +from colorama import Fore, init from dotenv import load_dotenv # 处理.env文件 +from rich.traceback import install # maim_message imports for console input - # 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式 -from src.common.logger import initialize_logging, get_logger, shutdown_logging +from src.common.logger import get_logger, initialize_logging, shutdown_logging # UI日志适配器 initialize_logging() from src.main import MainSystem # noqa -from src import BaseMain # noqa -from src.manager.async_task_manager import async_task_manager # noqa -from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge # noqa -from src.config.config import global_config # noqa -from src.common.database.database import initialize_sql_database # noqa -from src.common.database.sqlalchemy_models import initialize_database as init_db # noqa +from src import BaseMain +from src.manager.async_task_manager import async_task_manager +from 
src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge +from src.config.config import global_config +from src.common.database.database import initialize_sql_database +from src.common.database.sqlalchemy_models import initialize_database as init_db logger = get_logger("main") @@ -247,7 +247,7 @@ if __name__ == "__main__": # The actual shutdown logic is now in the finally block. except Exception as e: - logger.error(f"主程序发生异常: {str(e)} {str(traceback.format_exc())}") + logger.error(f"主程序发生异常: {e!s} {traceback.format_exc()!s}") exit_code = 1 # 标记发生错误 finally: # 确保 loop 在任何情况下都尝试关闭(如果存在且未关闭) diff --git a/plugins/bilibli/__init__.py b/plugins/bilibli/__init__.py index ca649acac..7f6e5c3c2 100644 --- a/plugins/bilibli/__init__.py +++ b/plugins/bilibli/__init__.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ Bilibili 插件包 提供B站视频观看体验功能,像真实用户一样浏览和评价视频 diff --git a/plugins/bilibli/bilibli_base.py b/plugins/bilibli/bilibli_base.py index 34e794fd7..c35538dba 100644 --- a/plugins/bilibli/bilibli_base.py +++ b/plugins/bilibli/bilibli_base.py @@ -1,16 +1,17 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ Bilibili 工具基础模块 提供 B 站视频信息获取和视频分析功能 """ -import re -import aiohttp import asyncio -from typing import Optional, Dict, Any -from src.common.logger import get_logger +import re +from typing import Any + +import aiohttp + from src.chat.utils.utils_video import get_video_analyzer +from src.common.logger import get_logger logger = get_logger("bilibili_tool") @@ -25,7 +26,7 @@ class BilibiliVideoAnalyzer: "Referer": "https://www.bilibili.com/", } - def extract_bilibili_url(self, text: str) -> Optional[str]: + def extract_bilibili_url(self, text: str) -> str | None: """从文本中提取哔哩哔哩视频链接""" # 哔哩哔哩短链接模式 short_pattern = re.compile(r"https?://b23\.tv/[\w]+", re.IGNORECASE) @@ -44,7 +45,7 @@ class BilibiliVideoAnalyzer: return None - async def get_video_info(self, url: str) -> Optional[Dict[str, Any]]: + async def get_video_info(self, url: str) -> dict[str, Any] | None: """获取哔哩哔哩视频基本信息""" try: logger.info(f"🔍 解析视频URL: {url}") @@ -127,7 +128,7 @@ class BilibiliVideoAnalyzer: logger.exception("详细错误信息:") return None - async def get_video_stream_url(self, aid: int, cid: int) -> Optional[str]: + async def get_video_stream_url(self, aid: int, cid: int) -> str | None: """获取视频流URL""" try: logger.info(f"🎥 获取视频流URL: aid={aid}, cid={cid}") @@ -164,7 +165,7 @@ class BilibiliVideoAnalyzer: return stream_url # 降级到FLV格式 - if "durl" in play_data and play_data["durl"]: + if play_data.get("durl"): logger.info("📹 使用FLV格式视频流") stream_url = play_data["durl"][0].get("url") if stream_url: @@ -185,7 +186,7 @@ class BilibiliVideoAnalyzer: logger.exception("详细错误信息:") return None - async def download_video_bytes(self, stream_url: str, max_size_mb: int = 100) -> Optional[bytes]: + async def download_video_bytes(self, stream_url: str, max_size_mb: int = 100) -> bytes | None: """下载视频字节数据 Args: @@ -244,7 +245,7 @@ class BilibiliVideoAnalyzer: logger.exception("详细错误信息:") return None - async def analyze_bilibili_video(self, url: str, prompt: str = None) -> Dict[str, Any]: + async def analyze_bilibili_video(self, url: str, prompt: str = None) -> dict[str, Any]: """分析哔哩哔哩视频并返回详细信息和AI分析结果""" try: logger.info(f"🎬 开始分析哔哩哔哩视频: {url}") @@ -322,10 +323,10 @@ class BilibiliVideoAnalyzer: return result except Exception as e: - error_msg = f"分析哔哩哔哩视频时发生异常: {str(e)}" + error_msg = f"分析哔哩哔哩视频时发生异常: {e!s}" logger.error(f"❌ {error_msg}") logger.exception("详细错误信息:") # 记录完整的异常堆栈 - return {"error": f"分析失败: {str(e)}"} + return 
{"error": f"分析失败: {e!s}"} # 全局实例 diff --git a/plugins/bilibli/plugin.py b/plugins/bilibli/plugin.py index 72129c034..41f97bdeb 100644 --- a/plugins/bilibli/plugin.py +++ b/plugins/bilibli/plugin.py @@ -1,14 +1,15 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ Bilibili 视频观看体验工具 支持哔哩哔哩视频链接解析和AI视频内容分析 """ -from typing import Dict, Any, List, Tuple, Type -from src.plugin_system import BaseTool, ToolParamType, BasePlugin, register_plugin, ComponentInfo, ConfigField -from .bilibli_base import get_bilibili_analyzer +from typing import Any + from src.common.logger import get_logger +from src.plugin_system import BasePlugin, BaseTool, ComponentInfo, ConfigField, ToolParamType, register_plugin + +from .bilibli_base import get_bilibili_analyzer logger = get_logger("bilibili_tool") @@ -41,7 +42,7 @@ class BilibiliTool(BaseTool): super().__init__(plugin_config) self.analyzer = get_bilibili_analyzer() - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行哔哩哔哩视频观看体验""" try: url = function_args.get("url", "").strip() @@ -83,7 +84,7 @@ class BilibiliTool(BaseTool): return {"name": self.name, "content": content.strip()} except Exception as e: - error_msg = f"😅 看视频的时候出了点问题: {str(e)}" + error_msg = f"😅 看视频的时候出了点问题: {e!s}" logger.error(error_msg) return {"name": self.name, "content": error_msg} @@ -104,7 +105,7 @@ class BilibiliTool(BaseTool): return base_prompt - def _format_watch_experience(self, video_info: Dict, ai_analysis: str, interest_focus: str = None) -> str: + def _format_watch_experience(self, video_info: dict, ai_analysis: str, interest_focus: str = None) -> str: """格式化观看体验报告""" # 根据播放量生成热度评价 @@ -191,8 +192,8 @@ class BilibiliPlugin(BasePlugin): # 插件基本信息 plugin_name: str = "bilibili_video_watcher" enable_plugin: bool = True - dependencies: List[str] = [] - python_dependencies: List[str] = [] + dependencies: list[str] = [] + python_dependencies: list[str] = [] config_file_name: str = "config.toml" # 配置节描述 @@ -220,6 +221,6 @@ class BilibiliPlugin(BasePlugin): }, } - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: """返回插件包含的工具组件""" return [(BilibiliTool.get_tool_info(), BilibiliTool)] diff --git a/plugins/echo_example/plugin.py b/plugins/echo_example/plugin.py index 6f99cc901..e03429805 100644 --- a/plugins/echo_example/plugin.py +++ b/plugins/echo_example/plugin.py @@ -4,14 +4,15 @@ Echo 示例插件 展示增强命令系统的使用方法 """ -from typing import List, Tuple, Type, Optional, Union +from typing import Union + from src.plugin_system import ( BasePlugin, - PlusCommand, - CommandArgs, - PlusCommandInfo, - ConfigField, ChatType, + CommandArgs, + ConfigField, + PlusCommand, + PlusCommandInfo, register_plugin, ) from src.plugin_system.base.component_types import PythonDependency @@ -27,7 +28,7 @@ class EchoCommand(PlusCommand): chat_type_allow = ChatType.ALL intercept_message = True - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: """执行echo命令""" if args.is_empty(): await self.send_text("❓ 请提供要回显的内容\n用法: /echo <内容>") @@ -56,7 +57,7 @@ class HelloCommand(PlusCommand): chat_type_allow = ChatType.ALL intercept_message = True - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: """执行hello命令""" if args.is_empty(): 
await self.send_text("👋 Hello! 很高兴见到你!") @@ -77,7 +78,7 @@ class InfoCommand(PlusCommand): chat_type_allow = ChatType.ALL intercept_message = True - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: """执行info命令""" info_text = ( "📋 Echo 示例插件信息\n" @@ -105,7 +106,7 @@ class TestCommand(PlusCommand): chat_type_allow = ChatType.ALL intercept_message = True - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: """执行test命令""" if args.is_empty(): help_text = ( @@ -166,8 +167,8 @@ class EchoExamplePlugin(BasePlugin): plugin_name: str = "echo_example_plugin" enable_plugin: bool = True - dependencies: List[str] = [] - python_dependencies: List[Union[str, "PythonDependency"]] = [] + dependencies: list[str] = [] + python_dependencies: list[Union[str, "PythonDependency"]] = [] config_file_name: str = "config.toml" config_schema = { @@ -187,7 +188,7 @@ class EchoExamplePlugin(BasePlugin): "commands": "命令相关配置", } - def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type]]: + def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type]]: """获取插件组件""" components = [] diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index ca7a6a13a..2c71293a1 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -1,20 +1,20 @@ -from typing import List, Tuple, Type, Dict, Any, Optional import logging import random +from typing import Any from src.plugin_system import ( - BasePlugin, - register_plugin, - ComponentInfo, - BaseEventHandler, - EventType, - BaseTool, - PlusCommand, - CommandArgs, - ChatType, - BaseAction, ActionActivationType, + BaseAction, + BaseEventHandler, + BasePlugin, + BaseTool, + ChatType, + CommandArgs, + ComponentInfo, ConfigField, + EventType, + PlusCommand, + register_plugin, ) from src.plugin_system.base.base_event import HandlerResult @@ -39,7 +39,7 @@ class GetSystemInfoTool(BaseTool): available_for_llm = True parameters = [] - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: return {"name": self.name, "content": "系统版本: 1.0.1, 状态: 运行正常"} @@ -51,7 +51,7 @@ class HelloCommand(PlusCommand): command_aliases = ["hi", "你好"] chat_type_allow = ChatType.ALL - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: greeting = str(self.get_config("greeting.message", "Hello, World! 
我是一个由 MoFox_Bot 驱动的插件。")) await self.send_text(greeting) return True, "成功发送问候", True @@ -67,7 +67,7 @@ class RandomEmojiAction(BaseAction): action_require = ["当对话气氛轻松时", "可以用来回应简单的情感表达"] associated_types = ["text"] - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: emojis = ["😊", "😂", "👍", "🎉", "🤔", "🤖"] await self.send_text(random.choice(emojis)) return True, "成功发送了一个随机表情" @@ -99,9 +99,9 @@ class HelloWorldPlugin(BasePlugin): }, } - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: """根据配置文件动态注册插件的功能组件。""" - components: List[Tuple[ComponentInfo, Type]] = [] + components: list[tuple[ComponentInfo, type]] = [] components.append((StartupMessageHandler.get_handler_info(), StartupMessageHandler)) components.append((GetSystemInfoTool.get_tool_info(), GetSystemInfoTool)) diff --git a/pyproject.toml b/pyproject.toml index a67f28472..04bb07299 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ dependencies = [ "tqdm>=4.67.1", "urllib3>=2.5.0", "uvicorn>=0.35.0", + "watchdog>=6.0.0", "websockets>=15.0.1", "aiomysql>=0.2.0", "aiosqlite>=0.21.0", @@ -80,29 +81,41 @@ dependencies = [ url = "https://pypi.tuna.tsinghua.edu.cn/simple" default = true +[tool.uv.sources] +amrita = { workspace = true } + [tool.ruff] - -include = ["*.py"] - -# 行长度设置 line-length = 120 +target-version = "py310" [tool.ruff.lint] -fixable = ["ALL"] -unfixable = [] +select = [ + "F", # Pyflakes + "W", # pycodestyle warnings + "E", # pycodestyle errors + "UP", # pyupgrade + "ASYNC", # flake8-async + "C4", # flake8-comprehensions + "T10", # flake8-debugger + "PYI", # flake8-pyi + "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RUF", # Ruff-specific rules + "I", # isort + "PERF", # pylint-performance +] +ignore = [ + "E402", # module-import-not-at-top-of-file + "E501", # line-too-long + "UP037", # quoted-annotation + "RUF001", # ambiguous-unicode-character-string + "RUF002", # ambiguous-unicode-character-docstring + "RUF003", # ambiguous-unicode-character-comment +] + # 如果一个变量的名称以下划线开头,即使它未被使用,也不应该被视为错误或警告。 dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" - -# 启用的规则 -select = [ - "E", # pycodestyle 错误 - "F", # pyflakes - "B", # flake8-bugbear -] - -ignore = ["E711","E501"] - [tool.ruff.format] docstring-code-format = true indent-style = "space" @@ -124,6 +137,4 @@ skip-magic-trailing-comma = false line-ending = "auto" [dependency-groups] -lint = [ - "loguru>=0.7.3", -] +lint = ["loguru>=0.7.3"] diff --git a/scripts/expression_stats.py b/scripts/expression_stats.py index 133f3d73b..b79819493 100644 --- a/scripts/expression_stats.py +++ b/scripts/expression_stats.py @@ -1,10 +1,9 @@ -import time -import sys import os -from typing import Dict, List +import sys +import time # Add project root to Python path -from src.common.database.database_model import Expression, ChatStreams +from src.common.database.database_model import ChatStreams, Expression project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, project_root) @@ -30,7 +29,7 @@ def get_chat_name(chat_id: str) -> str: return f"查询失败 ({chat_id})" -def calculate_time_distribution(expressions) -> Dict[str, int]: +def calculate_time_distribution(expressions) -> dict[str, int]: """Calculate distribution of last active time in days""" now = time.time() distribution = { @@ -64,7 +63,7 @@ def calculate_time_distribution(expressions) -> Dict[str, int]: return distribution -def 
calculate_count_distribution(expressions) -> Dict[str, int]: +def calculate_count_distribution(expressions) -> dict[str, int]: """Calculate distribution of count values""" distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0} for expr in expressions: @@ -86,7 +85,7 @@ def calculate_count_distribution(expressions) -> Dict[str, int]: return distribution -def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]: +def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> list[Expression]: """Get top N most used expressions for a specific chat_id""" return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n) diff --git a/scripts/interest_value_analysis.py b/scripts/interest_value_analysis.py index bce37b4a2..e464c905c 100644 --- a/scripts/interest_value_analysis.py +++ b/scripts/interest_value_analysis.py @@ -1,7 +1,6 @@ -import time -import sys import os -from typing import Dict, List, Tuple, Optional +import sys +import time from datetime import datetime # Add project root to Python path @@ -35,7 +34,7 @@ def format_timestamp(timestamp: float) -> str: return "未知时间" -def calculate_interest_value_distribution(messages) -> Dict[str, int]: +def calculate_interest_value_distribution(messages) -> dict[str, int]: """Calculate distribution of interest_value""" distribution = { "0.000-0.010": 0, @@ -76,7 +75,7 @@ def calculate_interest_value_distribution(messages) -> Dict[str, int]: return distribution -def get_interest_value_stats(messages) -> Dict[str, float]: +def get_interest_value_stats(messages) -> dict[str, float]: """Calculate basic statistics for interest_value""" values = [ float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0 @@ -97,7 +96,7 @@ def get_interest_value_stats(messages) -> Dict[str, float]: } -def get_available_chats() -> List[Tuple[str, str, int]]: +def get_available_chats() -> list[tuple[str, str, int]]: """Get all available chats with message counts""" try: # 获取所有有消息的chat_id @@ -130,7 +129,7 @@ def get_available_chats() -> List[Tuple[str, str, int]]: return [] -def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: +def get_time_range_input() -> tuple[float | None, float | None]: """Get time range input from user""" print("\n时间范围选择:") print("1. 
最近1天") @@ -170,7 +169,7 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: def analyze_interest_values( - chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None + chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None ) -> None: """Analyze interest values with optional filters""" diff --git a/scripts/log_viewer_optimized.py b/scripts/log_viewer_optimized.py index 2103e5486..f38dafa64 100644 --- a/scripts/log_viewer_optimized.py +++ b/scripts/log_viewer_optimized.py @@ -1,13 +1,14 @@ -import tkinter as tk -from tkinter import ttk, messagebox, filedialog, colorchooser -import orjson -from pathlib import Path -import threading -import toml -from datetime import datetime -from collections import defaultdict import os +import threading import time +import tkinter as tk +from collections import defaultdict +from datetime import datetime +from pathlib import Path +from tkinter import colorchooser, filedialog, messagebox, ttk + +import orjson +import toml class LogIndex: @@ -409,7 +410,7 @@ class AsyncLogLoader: file_size = os.path.getsize(file_path) processed_size = 0 - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: line_count = 0 batch_size = 1000 # 批量处理 @@ -561,7 +562,7 @@ class LogViewer: try: if config_path.exists(): - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: bot_config = toml.load(f) if "log" in bot_config: self.log_config.update(bot_config["log"]) @@ -575,7 +576,7 @@ class LogViewer: try: if viewer_config_path.exists(): - with open(viewer_config_path, "r", encoding="utf-8") as f: + with open(viewer_config_path, encoding="utf-8") as f: viewer_config = toml.load(f) if "viewer" in viewer_config: self.viewer_config.update(viewer_config["viewer"]) @@ -843,7 +844,7 @@ class LogViewer: mapping_file = Path("config/module_mapping.json") if mapping_file.exists(): try: - with open(mapping_file, "r", encoding="utf-8") as f: + with open(mapping_file, encoding="utf-8") as f: custom_mapping = orjson.loads(f.read()) self.module_name_mapping.update(custom_mapping) except Exception as e: @@ -1172,7 +1173,7 @@ class LogViewer: """读取新的日志条目并返回它们""" new_entries = [] new_modules = set() # 收集新发现的模块 - with open(self.current_log_file, "r", encoding="utf-8") as f: + with open(self.current_log_file, encoding="utf-8") as f: f.seek(from_position) line_count = self.log_index.total_entries for line in f: diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index 9caafc7fd..58aa91c64 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -1,36 +1,37 @@ import asyncio +import datetime import os import shutil import sys -import orjson -import datetime -from pathlib import Path from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path from threading import Lock -from typing import Optional + +import orjson from json_repair import repair_json # 将项目根目录添加到 sys.path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from src.common.logger import get_logger -from src.chat.knowledge.utils.hash import get_sha256 -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config -from src.chat.knowledge.open_ie import OpenIE -from src.chat.knowledge.embedding_store import EmbeddingManager -from src.chat.knowledge.kg_manager import KGManager from rich.progress import ( - Progress, 
BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, TimeElapsedColumn, TimeRemainingColumn, - TaskProgressColumn, - MofNCompleteColumn, - SpinnerColumn, - TextColumn, ) +from src.chat.knowledge.embedding_store import EmbeddingManager +from src.chat.knowledge.kg_manager import KGManager +from src.chat.knowledge.open_ie import OpenIE +from src.chat.knowledge.utils.hash import get_sha256 +from src.common.logger import get_logger +from src.config.config import model_config +from src.llm_models.utils_model import LLMRequest + logger = get_logger("LPMM_LearningTool") ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data") @@ -59,7 +60,7 @@ def clear_cache(): def process_text_file(file_path): - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: raw = f.read() return [p.strip() for p in raw.split("\n\n") if p.strip()] @@ -86,7 +87,7 @@ def preprocess_raw_data(): # --- 模块二:信息提取 --- -def _parse_and_repair_json(json_string: str) -> Optional[dict]: +def _parse_and_repair_json(json_string: str) -> dict | None: """ 尝试解析JSON字符串,如果失败则尝试修复并重新解析。 @@ -249,7 +250,7 @@ def extract_information(paragraphs_dict, model_set): # --- 模块三:数据导入 --- -async def import_data(openie_obj: Optional[OpenIE] = None): +async def import_data(openie_obj: OpenIE | None = None): """ 将OpenIE数据导入知识库(Embedding Store 和 KG) diff --git a/scripts/manifest_tool.py b/scripts/manifest_tool.py index 6f9a3a6d0..c18b6a208 100644 --- a/scripts/manifest_tool.py +++ b/scripts/manifest_tool.py @@ -4,11 +4,13 @@ 提供插件manifest文件的创建、验证和管理功能 """ +import argparse import os import sys -import argparse -import orjson from pathlib import Path + +import orjson + from src.common.logger import get_logger from src.plugin_system.utils.manifest_utils import ( ManifestValidator, @@ -124,7 +126,7 @@ def validate_manifest_file(plugin_dir: str) -> bool: return False try: - with open(manifest_path, "r", encoding="utf-8") as f: + with open(manifest_path, encoding="utf-8") as f: manifest_data = orjson.loads(f.read()) validator = ManifestValidator() diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py index 789c5860a..36b7aa9ab 100644 --- a/scripts/mongodb_to_sqlite.py +++ b/scripts/mongodb_to_sqlite.py @@ -1,46 +1,48 @@ import os -import orjson -import sys # 新增系统模块导入 # import time import pickle +import sys # 新增系统模块导入 from pathlib import Path +import orjson + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from typing import Dict, Any, List, Optional, Type from dataclasses import dataclass, field from datetime import datetime +from typing import Any + +from peewee import Field, IntegrityError, Model from pymongo import MongoClient from pymongo.errors import ConnectionFailure -from peewee import Model, Field, IntegrityError # Rich 进度条和显示组件 from rich.console import Console +from rich.panel import Panel from rich.progress import ( - Progress, - TextColumn, BarColumn, - TaskProgressColumn, - TimeRemainingColumn, - TimeElapsedColumn, + Progress, SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, ) from rich.table import Table -from rich.panel import Panel -# from rich.text import Text +# from rich.text import Text from src.common.database.database import db from src.common.database.sqlalchemy_models import ( ChatStreams, Emoji, - Messages, - Images, - ImageDescriptions, - PersonInfo, - Knowledges, - 
ThinkingLog, - GraphNodes, GraphEdges, + GraphNodes, + ImageDescriptions, + Images, + Knowledges, + Messages, + PersonInfo, + ThinkingLog, ) from src.common.logger import get_logger @@ -54,12 +56,12 @@ class MigrationConfig: """迁移配置类""" mongo_collection: str - target_model: Type[Model] - field_mapping: Dict[str, str] + target_model: type[Model] + field_mapping: dict[str, str] batch_size: int = 500 enable_validation: bool = True skip_duplicates: bool = True - unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段 + unique_fields: list[str] = field(default_factory=list) # 用于重复检查的字段 # 数据验证相关类已移除 - 用户要求不要数据验证 @@ -73,7 +75,7 @@ class MigrationCheckpoint: processed_count: int last_processed_id: Any timestamp: datetime - batch_errors: List[Dict[str, Any]] = field(default_factory=list) + batch_errors: list[dict[str, Any]] = field(default_factory=list) @dataclass @@ -88,11 +90,11 @@ class MigrationStats: duplicate_count: int = 0 validation_errors: int = 0 batch_insert_count: int = 0 - errors: List[Dict[str, Any]] = field(default_factory=list) - start_time: Optional[datetime] = None - end_time: Optional[datetime] = None + errors: list[dict[str, Any]] = field(default_factory=list) + start_time: datetime | None = None + end_time: datetime | None = None - def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None): + def add_error(self, doc_id: Any, error: str, doc_data: dict | None = None): """添加错误记录""" self.errors.append( {"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data} @@ -108,10 +110,10 @@ class MigrationStats: class MongoToSQLiteMigrator: """MongoDB到SQLite数据迁移器 - 使用Peewee ORM""" - def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None): + def __init__(self, mongo_uri: str | None = None, database_name: str | None = None): self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot") self.mongo_uri = mongo_uri or self._build_mongo_uri() - self.mongo_client: Optional[MongoClient] = None + self.mongo_client: MongoClient | None = None self.mongo_db = None # 迁移配置 @@ -142,7 +144,7 @@ class MongoToSQLiteMigrator: else: return f"mongodb://{host}:{port}/{self.database_name}" - def _initialize_migration_configs(self) -> List[MigrationConfig]: + def _initialize_migration_configs(self) -> list[MigrationConfig]: """初始化迁移配置""" return [ # 表情包迁移配置 MigrationConfig( @@ -306,7 +308,7 @@ class MongoToSQLiteMigrator: ), ] - def _initialize_validation_rules(self) -> Dict[str, Any]: + def _initialize_validation_rules(self) -> dict[str, Any]: """数据验证已禁用 - 返回空字典""" return {} @@ -337,7 +339,7 @@ class MongoToSQLiteMigrator: self.mongo_client.close() logger.info("MongoDB连接已关闭") - def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any: + def _get_nested_value(self, document: dict[str, Any], field_path: str) -> Any: """获取嵌套字段的值""" if "." 
not in field_path: return document.get(field_path) @@ -434,7 +436,7 @@ class MongoToSQLiteMigrator: return None - def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool: + def _validate_data(self, collection_name: str, data: dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool: """数据验证已禁用 - 始终返回True""" return True @@ -454,7 +456,7 @@ class MongoToSQLiteMigrator: except Exception as e: logger.warning(f"保存断点失败: {e}") - def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]: + def _load_checkpoint(self, collection_name: str) -> MigrationCheckpoint | None: """加载迁移断点""" checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl" if not checkpoint_file.exists(): @@ -467,7 +469,7 @@ class MongoToSQLiteMigrator: logger.warning(f"加载断点失败: {e}") return None - def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int: + def _batch_insert(self, model: type[Model], data_list: list[dict[str, Any]]) -> int: """批量插入数据""" if not data_list: return 0 @@ -494,7 +496,7 @@ class MongoToSQLiteMigrator: return success_count def _check_duplicate_by_unique_fields( - self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str] + self, model: type[Model], data: dict[str, Any], unique_fields: list[str] ) -> bool: """根据唯一字段检查重复""" if not unique_fields: @@ -512,7 +514,7 @@ class MongoToSQLiteMigrator: logger.debug(f"重复检查失败: {e}") return False - def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]: + def _create_model_instance(self, model: type[Model], data: dict[str, Any]) -> Model | None: """使用ORM创建模型实例""" try: # 过滤掉不存在的字段 @@ -669,7 +671,7 @@ class MongoToSQLiteMigrator: return stats - def migrate_all(self) -> Dict[str, MigrationStats]: + def migrate_all(self) -> dict[str, MigrationStats]: """执行所有迁移任务""" logger.info("开始执行数据库迁移...") @@ -730,7 +732,7 @@ class MongoToSQLiteMigrator: self._print_migration_summary(all_stats) return all_stats - def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]): + def _print_migration_summary(self, all_stats: dict[str, MigrationStats]): """使用Rich打印美观的迁移汇总信息""" # 计算总体统计 total_processed = sum(stats.processed_count for stats in all_stats.values()) @@ -857,7 +859,7 @@ class MongoToSQLiteMigrator: """添加新的迁移配置""" self.migration_configs.append(config) - def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]: + def migrate_single_collection(self, collection_name: str) -> MigrationStats | None: """迁移单个指定的集合""" config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None) if not config: @@ -875,7 +877,7 @@ class MongoToSQLiteMigrator: finally: self.disconnect_mongodb() - def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str): + def export_error_report(self, all_stats: dict[str, MigrationStats], filepath: str): """导出错误报告""" error_report = { "timestamp": datetime.now().isoformat(), diff --git a/scripts/rebuild_metadata_index.py b/scripts/rebuild_metadata_index.py index d1990fecc..b4d786019 100644 --- a/scripts/rebuild_metadata_index.py +++ b/scripts/rebuild_metadata_index.py @@ -1,17 +1,16 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ 从现有ChromaDB数据重建JSON元数据索引 """ import asyncio -import sys import os +import sys sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from src.chat.memory_system.memory_system import MemorySystem from 
src.chat.memory_system.memory_metadata_index import MemoryMetadataIndexEntry +from src.chat.memory_system.memory_system import MemorySystem from src.common.logger import get_logger logger = get_logger(__name__) diff --git a/scripts/run_multi_stage_smoke.py b/scripts/run_multi_stage_smoke.py index 000336244..634f97210 100644 --- a/scripts/run_multi_stage_smoke.py +++ b/scripts/run_multi_stage_smoke.py @@ -1,12 +1,11 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ 轻量烟雾测试:初始化 MemorySystem 并运行一次检索,验证 MemoryMetadata.source 访问不再报错 """ import asyncio -import sys import os +import sys sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) diff --git a/scripts/text_length_analysis.py b/scripts/text_length_analysis.py index 5a329b93c..818b5f6e1 100644 --- a/scripts/text_length_analysis.py +++ b/scripts/text_length_analysis.py @@ -1,8 +1,7 @@ -import time -import sys import os import re -from typing import Dict, List, Tuple, Optional +import sys +import time from datetime import datetime # Add project root to Python path @@ -63,7 +62,7 @@ def format_timestamp(timestamp: float) -> str: return "未知时间" -def calculate_text_length_distribution(messages) -> Dict[str, int]: +def calculate_text_length_distribution(messages) -> dict[str, int]: """Calculate distribution of processed_plain_text length""" distribution = { "0": 0, # 空文本 @@ -126,7 +125,7 @@ def calculate_text_length_distribution(messages) -> Dict[str, int]: return distribution -def get_text_length_stats(messages) -> Dict[str, float]: +def get_text_length_stats(messages) -> dict[str, float]: """Calculate basic statistics for processed_plain_text length""" lengths = [] null_count = 0 @@ -168,7 +167,7 @@ def get_text_length_stats(messages) -> Dict[str, float]: } -def get_available_chats() -> List[Tuple[str, str, int]]: +def get_available_chats() -> list[tuple[str, str, int]]: """Get all available chats with message counts""" try: # 获取所有有消息的chat_id,排除特殊类型消息 @@ -202,7 +201,7 @@ def get_available_chats() -> List[Tuple[str, str, int]]: return [] -def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: +def get_time_range_input() -> tuple[float | None, float | None]: """Get time range input from user""" print("\n时间范围选择:") print("1. 
最近1天") @@ -241,7 +240,7 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: return None, None -def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]: +def get_top_longest_messages(messages, top_n: int = 10) -> list[tuple[str, int, str, str]]: """Get top N longest messages""" message_lengths = [] @@ -266,7 +265,7 @@ def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, def analyze_text_lengths( - chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None + chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None ) -> None: """Analyze processed_plain_text lengths with optional filters""" diff --git a/scripts/update_prompt_imports.py b/scripts/update_prompt_imports.py index 227491ec2..3917c9408 100644 --- a/scripts/update_prompt_imports.py +++ b/scripts/update_prompt_imports.py @@ -30,7 +30,7 @@ def update_prompt_imports(file_path): print(f"文件不存在: {file_path}") return False - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: content = f.read() # 替换导入语句 diff --git a/src/__init__.py b/src/__init__.py index bdb90be85..d23d01ddb 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,13 +1,15 @@ import random -from typing import List, Optional, Sequence -from colorama import init, Fore +from collections.abc import Sequence +from typing import List, Optional + +from colorama import Fore, init from src.common.logger import get_logger egg = get_logger("小彩蛋") -def weighted_choice(data: Sequence[str], weights: Optional[List[float]] = None) -> str: +def weighted_choice(data: Sequence[str], weights: list[float] | None = None) -> str: """ 从 data 中按权重随机返回一条。 若 weights 为 None,则所有元素权重默认为 1。 diff --git a/src/chat/__init__.py b/src/chat/__init__.py index a569c0226..2f7da45ce 100644 --- a/src/chat/__init__.py +++ b/src/chat/__init__.py @@ -3,8 +3,8 @@ MaiBot模块系统 包含聊天、情绪、记忆、日程等功能模块 """ -from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager +from src.chat.message_receive.chat_stream import get_chat_manager # 导出主要组件供外部使用 __all__ = [ diff --git a/src/chat/antipromptinjector/__init__.py b/src/chat/antipromptinjector/__init__.py index e5a672c86..fb45f006a 100644 --- a/src/chat/antipromptinjector/__init__.py +++ b/src/chat/antipromptinjector/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ MaiBot 反注入系统模块 @@ -14,25 +13,25 @@ MaiBot 反注入系统模块 """ from .anti_injector import AntiPromptInjector, get_anti_injector, initialize_anti_injector -from .types import DetectionResult, ProcessResult -from .core import PromptInjectionDetector, MessageShield -from .processors.message_processor import MessageProcessor -from .management import AntiInjectionStatistics, UserBanManager +from .core import MessageShield, PromptInjectionDetector from .decision import CounterAttackGenerator, ProcessingDecisionMaker +from .management import AntiInjectionStatistics, UserBanManager +from .processors.message_processor import MessageProcessor +from .types import DetectionResult, ProcessResult __all__ = [ + "AntiInjectionStatistics", "AntiPromptInjector", + "CounterAttackGenerator", + "DetectionResult", + "MessageProcessor", + "MessageShield", + "ProcessResult", + "ProcessingDecisionMaker", + "PromptInjectionDetector", + "UserBanManager", "get_anti_injector", "initialize_anti_injector", - "DetectionResult", - "ProcessResult", - "PromptInjectionDetector", - 
"MessageShield", - "MessageProcessor", - "AntiInjectionStatistics", - "UserBanManager", - "CounterAttackGenerator", - "ProcessingDecisionMaker", ] diff --git a/src/chat/antipromptinjector/anti_injector.py b/src/chat/antipromptinjector/anti_injector.py index b2c2e3232..23ff3a7ee 100644 --- a/src/chat/antipromptinjector/anti_injector.py +++ b/src/chat/antipromptinjector/anti_injector.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ LLM反注入系统主模块 @@ -12,15 +11,16 @@ LLM反注入系统主模块 """ import time -from typing import Optional, Tuple, Dict, Any +from typing import Any from src.common.logger import get_logger from src.config.config import global_config -from .types import ProcessResult -from .core import PromptInjectionDetector, MessageShield -from .processors.message_processor import MessageProcessor -from .management import AntiInjectionStatistics, UserBanManager + +from .core import MessageShield, PromptInjectionDetector from .decision import CounterAttackGenerator, ProcessingDecisionMaker +from .management import AntiInjectionStatistics, UserBanManager +from .processors.message_processor import MessageProcessor +from .types import ProcessResult logger = get_logger("anti_injector") @@ -43,7 +43,7 @@ class AntiPromptInjector: async def process_message( self, message_data: dict, chat_stream=None - ) -> Tuple[ProcessResult, Optional[str], Optional[str]]: + ) -> tuple[ProcessResult, str | None, str | None]: """处理字典格式的消息并返回结果 Args: @@ -102,7 +102,7 @@ class AntiPromptInjector: await self.statistics.update_stats(error_count=1) # 异常情况下直接阻止消息 - return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {str(e)}" + return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {e!s}" finally: # 更新处理时间统计 @@ -111,7 +111,7 @@ class AntiPromptInjector: async def _process_message_internal( self, text_to_detect: str, user_id: str, platform: str, processed_plain_text: str, start_time: float - ) -> Tuple[ProcessResult, Optional[str], Optional[str]]: + ) -> tuple[ProcessResult, str | None, str | None]: """内部消息处理逻辑(共用的检测核心)""" # 如果是纯引用消息,直接允许通过 @@ -218,7 +218,7 @@ class AntiPromptInjector: return ProcessResult.ALLOWED, None, "消息检查通过" async def handle_message_storage( - self, result: ProcessResult, modified_content: Optional[str], reason: str, message_data: dict + self, result: ProcessResult, modified_content: str | None, reason: str, message_data: dict ) -> None: """处理违禁消息的数据库存储,根据处理模式决定如何处理""" if result == ProcessResult.BLOCKED_INJECTION or result == ProcessResult.COUNTER_ATTACK: @@ -253,9 +253,10 @@ class AntiPromptInjector: async def _delete_message_from_storage(message_data: dict) -> None: """从数据库中删除违禁消息记录""" try: - from src.common.database.sqlalchemy_models import Messages, get_db_session from sqlalchemy import delete + from src.common.database.sqlalchemy_models import Messages, get_db_session + message_id = message_data.get("message_id") if not message_id: logger.warning("无法删除消息:缺少message_id") @@ -279,9 +280,10 @@ class AntiPromptInjector: async def _update_message_in_storage(message_data: dict, new_content: str) -> None: """更新数据库中的消息内容为加盾版本""" try: - from src.common.database.sqlalchemy_models import Messages, get_db_session from sqlalchemy import update + from src.common.database.sqlalchemy_models import Messages, get_db_session + message_id = message_data.get("message_id") if not message_id: logger.warning("无法更新消息:缺少message_id") @@ -305,7 +307,7 @@ class AntiPromptInjector: except Exception as e: logger.error(f"更新消息内容失败: {e}") - async def get_stats(self) -> Dict[str, Any]: + async def get_stats(self) -> 
dict[str, Any]: """获取统计信息""" return await self.statistics.get_stats() @@ -315,7 +317,7 @@ class AntiPromptInjector: # 全局反注入器实例 -_global_injector: Optional[AntiPromptInjector] = None +_global_injector: AntiPromptInjector | None = None def get_anti_injector() -> AntiPromptInjector: diff --git a/src/chat/antipromptinjector/core/__init__.py b/src/chat/antipromptinjector/core/__init__.py index f4087c4f3..5f751d823 100644 --- a/src/chat/antipromptinjector/core/__init__.py +++ b/src/chat/antipromptinjector/core/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 反注入系统核心检测模块 @@ -10,4 +9,4 @@ from .detector import PromptInjectionDetector from .shield import MessageShield -__all__ = ["PromptInjectionDetector", "MessageShield"] +__all__ = ["MessageShield", "PromptInjectionDetector"] diff --git a/src/chat/antipromptinjector/core/detector.py b/src/chat/antipromptinjector/core/detector.py index 39e65db8b..202c9bb5b 100644 --- a/src/chat/antipromptinjector/core/detector.py +++ b/src/chat/antipromptinjector/core/detector.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 提示词注入检测器模块 @@ -8,19 +7,19 @@ 3. 缓存机制优化性能 """ +import hashlib import re import time -import hashlib -from typing import Dict, List from dataclasses import asdict from src.common.logger import get_logger from src.config.config import global_config -from ..types import DetectionResult # 导入LLM API from src.plugin_system.apis import llm_api +from ..types import DetectionResult + logger = get_logger("anti_injector.detector") @@ -30,8 +29,8 @@ class PromptInjectionDetector: def __init__(self): """初始化检测器""" self.config = global_config.anti_prompt_injection - self._cache: Dict[str, DetectionResult] = {} - self._compiled_patterns: List[re.Pattern] = [] + self._cache: dict[str, DetectionResult] = {} + self._compiled_patterns: list[re.Pattern] = [] self._compile_patterns() def _compile_patterns(self): @@ -224,7 +223,7 @@ class PromptInjectionDetector: matched_patterns=[], processing_time=processing_time, detection_method="llm", - reason=f"LLM检测出错: {str(e)}", + reason=f"LLM检测出错: {e!s}", ) @staticmethod @@ -250,7 +249,7 @@ class PromptInjectionDetector: 请客观分析,避免误判正常对话。""" @staticmethod - def _parse_llm_response(response: str) -> Dict: + def _parse_llm_response(response: str) -> dict: """解析LLM响应""" try: lines = response.strip().split("\n") @@ -280,7 +279,7 @@ class PromptInjectionDetector: except Exception as e: logger.error(f"解析LLM响应失败: {e}") - return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"} + return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"} async def detect(self, message: str) -> DetectionResult: """执行检测""" @@ -331,7 +330,7 @@ class PromptInjectionDetector: return final_result - def _merge_results(self, results: List[DetectionResult]) -> DetectionResult: + def _merge_results(self, results: list[DetectionResult]) -> DetectionResult: """合并多个检测结果""" if not results: return DetectionResult(reason="无检测结果") @@ -384,7 +383,7 @@ class PromptInjectionDetector: if expired_keys: logger.debug(f"清理了{len(expired_keys)}个过期缓存项") - def get_cache_stats(self) -> Dict: + def get_cache_stats(self) -> dict: """获取缓存统计信息""" return { "cache_size": len(self._cache), diff --git a/src/chat/antipromptinjector/core/shield.py b/src/chat/antipromptinjector/core/shield.py index c7a2e78bc..399ec9025 100644 --- a/src/chat/antipromptinjector/core/shield.py +++ b/src/chat/antipromptinjector/core/shield.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 消息加盾模块 @@ -6,8 +5,6 @@ 主要通过注入系统提示词来指导AI安全响应。 """ -from typing import List - 
from src.common.logger import get_logger from src.config.config import global_config @@ -35,7 +32,7 @@ class MessageShield: return SAFETY_SYSTEM_PROMPT @staticmethod - def is_shield_needed(confidence: float, matched_patterns: List[str]) -> bool: + def is_shield_needed(confidence: float, matched_patterns: list[str]) -> bool: """判断是否需要加盾 Args: @@ -60,7 +57,7 @@ class MessageShield: return False @staticmethod - def create_safety_summary(confidence: float, matched_patterns: List[str]) -> str: + def create_safety_summary(confidence: float, matched_patterns: list[str]) -> str: """创建安全处理摘要 Args: diff --git a/src/chat/antipromptinjector/counter_attack.py b/src/chat/antipromptinjector/counter_attack.py index 7c2bd86c5..2a094e419 100644 --- a/src/chat/antipromptinjector/counter_attack.py +++ b/src/chat/antipromptinjector/counter_attack.py @@ -1,15 +1,13 @@ -# -*- coding: utf-8 -*- """ 反击消息生成模块 负责生成个性化的反击消息回应提示词注入攻击 """ -from typing import Optional - from src.common.logger import get_logger from src.config.config import global_config from src.plugin_system.apis import llm_api + from .types import DetectionResult logger = get_logger("anti_injector.counter_attack") @@ -55,7 +53,7 @@ class CounterAttackGenerator: async def generate_counter_attack_message( self, original_message: str, detection_result: DetectionResult - ) -> Optional[str]: + ) -> str | None: """生成反击消息 Args: diff --git a/src/chat/antipromptinjector/decision/__init__.py b/src/chat/antipromptinjector/decision/__init__.py index 5778ca4ed..358147066 100644 --- a/src/chat/antipromptinjector/decision/__init__.py +++ b/src/chat/antipromptinjector/decision/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 反注入系统决策模块 @@ -7,7 +6,7 @@ - counter_attack: 反击消息生成器 """ -from .decision_maker import ProcessingDecisionMaker from .counter_attack import CounterAttackGenerator +from .decision_maker import ProcessingDecisionMaker -__all__ = ["ProcessingDecisionMaker", "CounterAttackGenerator"] +__all__ = ["CounterAttackGenerator", "ProcessingDecisionMaker"] diff --git a/src/chat/antipromptinjector/decision/counter_attack.py b/src/chat/antipromptinjector/decision/counter_attack.py index 9d6aac2ff..ad305b9c4 100644 --- a/src/chat/antipromptinjector/decision/counter_attack.py +++ b/src/chat/antipromptinjector/decision/counter_attack.py @@ -1,15 +1,13 @@ -# -*- coding: utf-8 -*- """ 反击消息生成模块 负责生成个性化的反击消息回应提示词注入攻击 """ -from typing import Optional - from src.common.logger import get_logger from src.config.config import global_config from src.plugin_system.apis import llm_api + from ..types import DetectionResult logger = get_logger("anti_injector.counter_attack") @@ -55,7 +53,7 @@ class CounterAttackGenerator: async def generate_counter_attack_message( self, original_message: str, detection_result: DetectionResult - ) -> Optional[str]: + ) -> str | None: """生成反击消息 Args: diff --git a/src/chat/antipromptinjector/decision/decision_maker.py b/src/chat/antipromptinjector/decision/decision_maker.py index 12a2c95b5..be3d3ccfb 100644 --- a/src/chat/antipromptinjector/decision/decision_maker.py +++ b/src/chat/antipromptinjector/decision/decision_maker.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 处理决策器模块 @@ -6,6 +5,7 @@ """ from src.common.logger import get_logger + from ..types import DetectionResult logger = get_logger("anti_injector.decision_maker") diff --git a/src/chat/antipromptinjector/decision_maker.py b/src/chat/antipromptinjector/decision_maker.py index 972253fab..893da059f 100644 --- a/src/chat/antipromptinjector/decision_maker.py +++ 
b/src/chat/antipromptinjector/decision_maker.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 处理决策器模块 @@ -6,6 +5,7 @@ """ from src.common.logger import get_logger + from .types import DetectionResult logger = get_logger("anti_injector.decision_maker") diff --git a/src/chat/antipromptinjector/detector.py b/src/chat/antipromptinjector/detector.py index 6c1e3b4bd..59d1132b1 100644 --- a/src/chat/antipromptinjector/detector.py +++ b/src/chat/antipromptinjector/detector.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 提示词注入检测器模块 @@ -8,19 +7,19 @@ 3. 缓存机制优化性能 """ +import hashlib import re import time -import hashlib -from typing import Dict, List from dataclasses import asdict from src.common.logger import get_logger from src.config.config import global_config -from .types import DetectionResult # 导入LLM API from src.plugin_system.apis import llm_api +from .types import DetectionResult + logger = get_logger("anti_injector.detector") @@ -30,8 +29,8 @@ class PromptInjectionDetector: def __init__(self): """初始化检测器""" self.config = global_config.anti_prompt_injection - self._cache: Dict[str, DetectionResult] = {} - self._compiled_patterns: List[re.Pattern] = [] + self._cache: dict[str, DetectionResult] = {} + self._compiled_patterns: list[re.Pattern] = [] self._compile_patterns() def _compile_patterns(self): @@ -221,7 +220,7 @@ class PromptInjectionDetector: matched_patterns=[], processing_time=processing_time, detection_method="llm", - reason=f"LLM检测出错: {str(e)}", + reason=f"LLM检测出错: {e!s}", ) @staticmethod @@ -247,7 +246,7 @@ class PromptInjectionDetector: 请客观分析,避免误判正常对话。""" @staticmethod - def _parse_llm_response(response: str) -> Dict: + def _parse_llm_response(response: str) -> dict: """解析LLM响应""" try: lines = response.strip().split("\n") @@ -277,7 +276,7 @@ class PromptInjectionDetector: except Exception as e: logger.error(f"解析LLM响应失败: {e}") - return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"} + return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"} async def detect(self, message: str) -> DetectionResult: """执行检测""" @@ -328,7 +327,7 @@ class PromptInjectionDetector: return final_result - def _merge_results(self, results: List[DetectionResult]) -> DetectionResult: + def _merge_results(self, results: list[DetectionResult]) -> DetectionResult: """合并多个检测结果""" if not results: return DetectionResult(reason="无检测结果") @@ -381,7 +380,7 @@ class PromptInjectionDetector: if expired_keys: logger.debug(f"清理了{len(expired_keys)}个过期缓存项") - def get_cache_stats(self) -> Dict: + def get_cache_stats(self) -> dict: """获取缓存统计信息""" return { "cache_size": len(self._cache), diff --git a/src/chat/antipromptinjector/management/__init__.py b/src/chat/antipromptinjector/management/__init__.py index eaef392c4..28b1bcee2 100644 --- a/src/chat/antipromptinjector/management/__init__.py +++ b/src/chat/antipromptinjector/management/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 反注入系统管理模块 diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index 9d44faa78..0525754f1 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 反注入系统统计模块 @@ -6,12 +5,12 @@ """ import datetime -from typing import Dict, Any +from typing import Any from sqlalchemy import select -from src.common.logger import get_logger from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session +from 
src.common.logger import get_logger from src.config.config import global_config logger = get_logger("anti_injector.statistics") @@ -94,7 +93,7 @@ class AntiInjectionStatistics: except Exception as e: logger.error(f"更新统计数据失败: {e}") - async def get_stats(self) -> Dict[str, Any]: + async def get_stats(self) -> dict[str, Any]: """获取统计信息""" try: # 检查反注入系统是否启用 diff --git a/src/chat/antipromptinjector/management/user_ban.py b/src/chat/antipromptinjector/management/user_ban.py index b965a08af..f1b82a8dc 100644 --- a/src/chat/antipromptinjector/management/user_ban.py +++ b/src/chat/antipromptinjector/management/user_ban.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 用户封禁管理模块 @@ -6,12 +5,12 @@ """ import datetime -from typing import Optional, Tuple from sqlalchemy import select -from src.common.logger import get_logger from src.common.database.sqlalchemy_models import BanUser, get_db_session +from src.common.logger import get_logger + from ..types import DetectionResult logger = get_logger("anti_injector.user_ban") @@ -28,7 +27,7 @@ class UserBanManager: """ self.config = config - async def check_user_ban(self, user_id: str, platform: str) -> Optional[Tuple[bool, Optional[str], str]]: + async def check_user_ban(self, user_id: str, platform: str) -> tuple[bool, str | None, str] | None: """检查用户是否被封禁 Args: diff --git a/src/chat/antipromptinjector/processors/__init__.py b/src/chat/antipromptinjector/processors/__init__.py index 1db74557f..40de37df9 100644 --- a/src/chat/antipromptinjector/processors/__init__.py +++ b/src/chat/antipromptinjector/processors/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 反注入系统消息处理模块 diff --git a/src/chat/antipromptinjector/processors/message_processor.py b/src/chat/antipromptinjector/processors/message_processor.py index 935848c2d..0e37efc0d 100644 --- a/src/chat/antipromptinjector/processors/message_processor.py +++ b/src/chat/antipromptinjector/processors/message_processor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 消息内容处理模块 @@ -6,10 +5,9 @@ """ import re -from typing import Optional -from src.common.logger import get_logger from src.chat.message_receive.message import MessageRecv +from src.common.logger import get_logger logger = get_logger("anti_injector.message_processor") @@ -66,7 +64,7 @@ class MessageProcessor: return new_content @staticmethod - def check_whitelist(message: MessageRecv, whitelist: list) -> Optional[tuple]: + def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None: """检查用户白名单 Args: diff --git a/src/chat/antipromptinjector/types.py b/src/chat/antipromptinjector/types.py index 81d775ffc..ac436cc90 100644 --- a/src/chat/antipromptinjector/types.py +++ b/src/chat/antipromptinjector/types.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 反注入系统数据类型定义模块 @@ -10,7 +9,6 @@ """ import time -from typing import List, Optional from dataclasses import dataclass, field from enum import Enum @@ -31,8 +29,8 @@ class DetectionResult: is_injection: bool = False confidence: float = 0.0 - matched_patterns: List[str] = field(default_factory=list) - llm_analysis: Optional[str] = None + matched_patterns: list[str] = field(default_factory=list) + llm_analysis: str | None = None processing_time: float = 0.0 detection_method: str = "unknown" reason: str = "" diff --git a/src/chat/chatter_manager.py b/src/chat/chatter_manager.py index d22d39440..d8eda9baa 100644 --- a/src/chat/chatter_manager.py +++ b/src/chat/chatter_manager.py @@ -1,10 +1,11 @@ -from typing import Dict, List, Optional, Any import time -from src.plugin_system.base.base_chatter 
import BaseChatter -from src.common.data_models.message_manager_data_model import StreamContext +from typing import Any + from src.chat.planner_actions.action_manager import ChatterActionManager -from src.plugin_system.base.component_types import ChatType +from src.common.data_models.message_manager_data_model import StreamContext from src.common.logger import get_logger +from src.plugin_system.base.base_chatter import BaseChatter +from src.plugin_system.base.component_types import ChatType logger = get_logger("chatter_manager") @@ -12,8 +13,8 @@ logger = get_logger("chatter_manager") class ChatterManager: def __init__(self, action_manager: ChatterActionManager): self.action_manager = action_manager - self.chatter_classes: Dict[ChatType, List[type]] = {} - self.instances: Dict[str, BaseChatter] = {} + self.chatter_classes: dict[ChatType, list[type]] = {} + self.instances: dict[str, BaseChatter] = {} # 管理器统计 self.stats = { @@ -46,21 +47,21 @@ class ChatterManager: self.stats["chatters_registered"] += 1 - def get_chatter_class(self, chat_type: ChatType) -> Optional[type]: + def get_chatter_class(self, chat_type: ChatType) -> type | None: """获取指定聊天类型的聊天处理器类""" if chat_type in self.chatter_classes: return self.chatter_classes[chat_type][0] return None - def get_supported_chat_types(self) -> List[ChatType]: + def get_supported_chat_types(self) -> list[ChatType]: """获取支持的聊天类型列表""" return list(self.chatter_classes.keys()) - def get_registered_chatters(self) -> Dict[ChatType, List[type]]: + def get_registered_chatters(self) -> dict[ChatType, list[type]]: """获取已注册的聊天处理器""" return self.chatter_classes.copy() - def get_stream_instance(self, stream_id: str) -> Optional[BaseChatter]: + def get_stream_instance(self, stream_id: str) -> BaseChatter | None: """获取指定流的聊天处理器实例""" return self.instances.get(stream_id) @@ -139,7 +140,7 @@ class ChatterManager: logger.error(f"处理流 {stream_id} 时发生错误: {e}") raise - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: """获取管理器统计信息""" stats = self.stats.copy() stats["active_instances"] = len(self.instances) diff --git a/src/chat/emoji_system/emoji_history.py b/src/chat/emoji_system/emoji_history.py index dadd152a1..0e7d6a6e1 100644 --- a/src/chat/emoji_system/emoji_history.py +++ b/src/chat/emoji_system/emoji_history.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- """ 表情包发送历史记录模块 """ -from typing import List, Dict from collections import deque from src.common.logger import get_logger @@ -14,7 +12,7 @@ MAX_HISTORY_SIZE = 5 # 每个聊天会话最多保留最近5条表情历史 # 使用一个全局字典在内存中存储历史记录 # 键是 chat_id,值是一个 deque 对象 -_history_cache: Dict[str, deque] = {} +_history_cache: dict[str, deque] = {} def add_emoji_to_history(chat_id: str, emoji_description: str): @@ -38,7 +36,7 @@ def add_emoji_to_history(chat_id: str, emoji_description: str): logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中") -def get_recent_emojis(chat_id: str, limit: int = 5) -> List[str]: +def get_recent_emojis(chat_id: str, limit: int = 5) -> list[str]: """ 从内存中获取最近发送的表情包描述列表。 diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index cd472ec0c..62552a201 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -1,23 +1,24 @@ import asyncio import base64 +import binascii import hashlib +import io import os import random +import re import time import traceback -import io -import re -import binascii +from typing import Any, Optional -from typing import Optional, Tuple, List, Any from PIL import Image from 
rich.traceback import install from sqlalchemy import select + +from src.chat.utils.utils_image import get_image_manager, image_path_to_base64 from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import Emoji, Images from src.common.logger import get_logger from src.config.config import global_config, model_config -from src.chat.utils.utils_image import image_path_to_base64, get_image_manager from src.llm_models.utils_model import LLMRequest install(extra_lines=3) @@ -47,14 +48,14 @@ class MaiEmoji: self.embedding = [] self.hash = "" # 初始为空,在创建实例时会计算 self.description = "" - self.emotion: List[str] = [] + self.emotion: list[str] = [] self.usage_count = 0 self.last_used_time = time.time() self.register_time = time.time() self.is_deleted = False # 标记是否已被删除 self.format = "" - async def initialize_hash_format(self) -> Optional[bool]: + async def initialize_hash_format(self) -> bool | None: """从文件创建表情包实例, 计算哈希值和格式""" try: # 使用 full_path 检查文件是否存在 @@ -105,7 +106,7 @@ class MaiEmoji: self.is_deleted = True return None except Exception as e: - logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {str(e)}") + logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {e!s}") logger.error(traceback.format_exc()) self.is_deleted = True return None @@ -142,7 +143,7 @@ class MaiEmoji: self.path = EMOJI_REGISTERED_DIR # self.filename 保持不变 except Exception as move_error: - logger.error(f"[错误] 移动文件失败: {str(move_error)}") + logger.error(f"[错误] 移动文件失败: {move_error!s}") # 如果移动失败,尝试将实例状态恢复?暂时不处理,仅返回失败 return False @@ -174,11 +175,11 @@ class MaiEmoji: return True except Exception as db_error: - logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}") + logger.error(f"[错误] 保存数据库失败 ({self.filename}): {db_error!s}") return False except Exception as e: - logger.error(f"[错误] 注册表情包失败 ({self.filename}): {str(e)}") + logger.error(f"[错误] 注册表情包失败 ({self.filename}): {e!s}") logger.error(traceback.format_exc()) return False @@ -198,7 +199,7 @@ class MaiEmoji: os.remove(file_to_delete) logger.debug(f"[删除] 文件: {file_to_delete}") except Exception as e: - logger.error(f"[错误] 删除文件失败 {file_to_delete}: {str(e)}") + logger.error(f"[错误] 删除文件失败 {file_to_delete}: {e!s}") # 文件删除失败,但仍然尝试删除数据库记录 # 2. 
删除数据库记录 @@ -214,7 +215,7 @@ class MaiEmoji: result = 1 # Successfully deleted one record await session.commit() except Exception as e: - logger.error(f"[错误] 删除数据库记录时出错: {str(e)}") + logger.error(f"[错误] 删除数据库记录时出错: {e!s}") result = 0 if result > 0: @@ -233,11 +234,11 @@ class MaiEmoji: return False except Exception as e: - logger.error(f"[错误] 删除表情包失败 ({self.filename}): {str(e)}") + logger.error(f"[错误] 删除表情包失败 ({self.filename}): {e!s}") return False -def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]: +def _emoji_objects_to_readable_list(emoji_objects: list["MaiEmoji"]) -> list[str]: """将表情包对象列表转换为可读的字符串列表 参数: @@ -256,7 +257,7 @@ def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str return emoji_info_list -def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]: +def _to_emoji_objects(data: Any) -> tuple[list["MaiEmoji"], int]: emoji_objects = [] load_errors = 0 emoji_data_list = list(data) @@ -300,7 +301,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]: logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}") load_errors += 1 except Exception as e: - logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {str(e)}") + logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {e!s}") load_errors += 1 return emoji_objects, load_errors @@ -335,7 +336,7 @@ async def clear_temp_emoji() -> None: logger.debug(f"[清理] 删除: {filename}") -async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], removed_count: int) -> int: +async def clean_unused_emojis(emoji_dir: str, emoji_objects: list["MaiEmoji"], removed_count: int) -> int: """清理指定目录中未被 emoji_objects 追踪的表情包文件""" if not os.path.exists(emoji_dir): logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") @@ -361,7 +362,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r logger.info(f"[清理] 删除未追踪的表情包文件: {file_full_path}") cleaned_count += 1 except Exception as e: - logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {str(e)}") + logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {e!s}") if cleaned_count > 0: logger.info(f"[清理] 在目录 {emoji_dir} 中清理了 {cleaned_count} 个破损表情包。") @@ -369,7 +370,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。") except Exception as e: - logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}") + logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {e!s}") return removed_count + cleaned_count @@ -437,9 +438,9 @@ class EmojiManager: emoji_update.last_used_time = time.time() # Update last used time await session.commit() except Exception as e: - logger.error(f"记录表情使用失败: {str(e)}") + logger.error(f"记录表情使用失败: {e!s}") - async def get_emoji_for_text(self, text_emotion: str) -> Optional[Tuple[str, str, str]]: + async def get_emoji_for_text(self, text_emotion: str) -> tuple[str, str, str] | None: """ 根据文本内容,使用LLM选择一个合适的表情包。 @@ -531,7 +532,7 @@ class EmojiManager: return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion except Exception as e: - logger.error(f"使用LLM获取表情包时发生错误: {str(e)}") + logger.error(f"使用LLM获取表情包时发生错误: {e!s}") logger.error(traceback.format_exc()) return None @@ -578,7 +579,7 @@ class EmojiManager: continue except Exception as item_error: - logger.error(f"[错误] 处理表情包记录时出错 ({emoji.filename}): {str(item_error)}") + logger.error(f"[错误] 处理表情包记录时出错 ({emoji.filename}): {item_error!s}") # 即使出错,也尝试继续检查下一个 continue @@ -597,7 +598,7 @@ class EmojiManager: logger.info(f"[检查] 已检查 {total_count} 个表情包记录,全部完好") except 
Exception as e: - logger.error(f"[错误] 检查表情包完整性失败: {str(e)}") + logger.error(f"[错误] 检查表情包完整性失败: {e!s}") logger.error(traceback.format_exc()) async def start_periodic_check_register(self) -> None: @@ -651,7 +652,7 @@ class EmojiManager: os.remove(file_path) logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}") except Exception as e: - logger.error(f"[错误] 扫描表情包目录失败: {str(e)}") + logger.error(f"[错误] 扫描表情包目录失败: {e!s}") await asyncio.sleep(global_config.emoji.check_interval * 60) @@ -674,11 +675,11 @@ class EmojiManager: logger.warning(f"[数据库] 加载过程中出现 {load_errors} 个错误。") except Exception as e: - logger.error(f"[错误] 从数据库加载所有表情包对象失败: {str(e)}") + logger.error(f"[错误] 从数据库加载所有表情包对象失败: {e!s}") self.emoji_objects = [] # 加载失败则清空列表 self.emoji_num = 0 - async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]: + async def get_emoji_from_db(self, emoji_hash: str | None = None) -> list["MaiEmoji"]: """获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找) 参数: @@ -707,7 +708,7 @@ class EmojiManager: return emoji_objects except Exception as e: - logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}") + logger.error(f"[错误] 从数据库获取表情包对象失败: {e!s}") return [] async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]: @@ -725,7 +726,7 @@ class EmojiManager: return emoji return None # 如果循环结束还没找到,则返回 None - async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]: + async def get_emoji_tag_by_hash(self, emoji_hash: str) -> str | None: """根据哈希值获取已注册表情包的描述 Args: @@ -753,10 +754,10 @@ class EmojiManager: return None except Exception as e: - logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}") + logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {e!s}") return None - async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]: + async def get_emoji_description_by_hash(self, emoji_hash: str) -> str | None: """根据哈希值获取已注册表情包的描述 Args: @@ -787,7 +788,7 @@ class EmojiManager: return None except Exception as e: - logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}") + logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {e!s}") return None async def delete_emoji(self, emoji_hash: str) -> bool: @@ -823,7 +824,7 @@ class EmojiManager: return False except Exception as e: - logger.error(f"[错误] 删除表情包失败: {str(e)}") + logger.error(f"[错误] 删除表情包失败: {e!s}") logger.error(traceback.format_exc()) return False @@ -909,11 +910,11 @@ class EmojiManager: return False except Exception as e: - logger.error(f"[错误] 替换表情包失败: {str(e)}") + logger.error(f"[错误] 替换表情包失败: {e!s}") logger.error(traceback.format_exc()) return False - async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]: + async def build_emoji_description(self, image_base64: str) -> tuple[str, list[str]]: """ 获取表情包的详细描述和情感关键词列表。 @@ -976,14 +977,14 @@ class EmojiManager: # 4. 内容审核,确保表情包符合规定 if global_config.emoji.content_filtration: - prompt = f''' + prompt = f""" 请根据以下标准审核这个表情包: 1. 主题必须符合:"{global_config.emoji.filtration_prompt}"。 2. 内容健康,不含色情、暴力、政治敏感等元素。 3. 必须是表情包,而不是普通的聊天截图或视频截图。 4. 
表情包中的文字数量(如果有)不能超过5个。 这个表情包是否完全满足以上所有要求?请只回答“是”或“否”。 - ''' + """ content, _ = await self.vlm.generate_response_for_image( prompt, image_base64, image_format, temperature=0.1, max_tokens=10 ) @@ -1023,7 +1024,7 @@ class EmojiManager: return final_description, emotions except Exception as e: - logger.error(f"构建表情包描述时发生严重错误: {str(e)}") + logger.error(f"构建表情包描述时发生严重错误: {e!s}") logger.error(traceback.format_exc()) return "", [] @@ -1058,7 +1059,7 @@ class EmojiManager: os.remove(file_full_path) logger.info(f"[清理] 删除重复的待注册文件: {filename}") except Exception as e: - logger.error(f"[错误] 删除重复文件失败: {str(e)}") + logger.error(f"[错误] 删除重复文件失败: {e!s}") return False # 返回 False 表示未注册新表情 # 3. 构建描述和情感 @@ -1075,7 +1076,7 @@ class EmojiManager: os.remove(file_full_path) logger.info(f"[清理] 删除描述生成失败的文件: {filename}") except Exception as e: - logger.error(f"[错误] 删除描述生成失败文件时出错: {str(e)}") + logger.error(f"[错误] 删除描述生成失败文件时出错: {e!s}") return False new_emoji.description = description new_emoji.emotion = emotions @@ -1086,7 +1087,7 @@ class EmojiManager: os.remove(file_full_path) logger.info(f"[清理] 删除描述生成异常的文件: {filename}") except Exception as e: - logger.error(f"[错误] 删除描述生成异常文件时出错: {str(e)}") + logger.error(f"[错误] 删除描述生成异常文件时出错: {e!s}") return False # 4. 检查容量并决定是否替换或直接注册 @@ -1100,7 +1101,7 @@ class EmojiManager: os.remove(file_full_path) # new_emoji 的 full_path 此时还是源路径 logger.info(f"[清理] 删除替换失败的新表情文件: {filename}") except Exception as e: - logger.error(f"[错误] 删除替换失败文件时出错: {str(e)}") + logger.error(f"[错误] 删除替换失败文件时出错: {e!s}") return False # 替换成功时,replace_a_emoji 内部已处理 new_emoji 的注册和添加到列表 return True @@ -1122,11 +1123,11 @@ class EmojiManager: os.remove(file_full_path) logger.info(f"[清理] 删除注册失败的源文件: {filename}") except Exception as e: - logger.error(f"[错误] 删除注册失败源文件时出错: {str(e)}") + logger.error(f"[错误] 删除注册失败源文件时出错: {e!s}") return False except Exception as e: - logger.error(f"[错误] 注册表情包时发生未预期错误 ({filename}): {str(e)}") + logger.error(f"[错误] 注册表情包时发生未预期错误 ({filename}): {e!s}") logger.error(traceback.format_exc()) # 尝试删除源文件以避免循环处理 if os.path.exists(file_full_path): diff --git a/src/chat/energy_system/__init__.py b/src/chat/energy_system/__init__.py index 6cdf96da5..570e183e6 100644 --- a/src/chat/energy_system/__init__.py +++ b/src/chat/energy_system/__init__.py @@ -4,24 +4,24 @@ """ from .energy_manager import ( - EnergyManager, - EnergyLevel, - EnergyComponent, - EnergyCalculator, - InterestEnergyCalculator, ActivityEnergyCalculator, + EnergyCalculator, + EnergyComponent, + EnergyLevel, + EnergyManager, + InterestEnergyCalculator, RecencyEnergyCalculator, RelationshipEnergyCalculator, energy_manager, ) __all__ = [ - "EnergyManager", - "EnergyLevel", - "EnergyComponent", - "EnergyCalculator", - "InterestEnergyCalculator", "ActivityEnergyCalculator", + "EnergyCalculator", + "EnergyComponent", + "EnergyLevel", + "EnergyManager", + "InterestEnergyCalculator", "RecencyEnergyCalculator", "RelationshipEnergyCalculator", "energy_manager", diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py index 4a92349bf..0bfb6fc4f 100644 --- a/src/chat/energy_system/energy_manager.py +++ b/src/chat/energy_system/energy_manager.py @@ -4,10 +4,10 @@ """ import time -from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict +from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum -from abc import ABC, abstractmethod +from typing import Any, TypedDict from src.common.logger import get_logger from src.config.config import global_config @@ -51,8 +51,8 
@@ class EnergyContext(TypedDict): """能量计算上下文""" stream_id: str - messages: List[Any] - user_id: Optional[str] + messages: list[Any] + user_id: str | None class EnergyResult(TypedDict): @@ -61,7 +61,7 @@ class EnergyResult(TypedDict): energy: float level: EnergyLevel distribution_interval: float - component_scores: Dict[str, float] + component_scores: dict[str, float] cached: bool @@ -69,7 +69,7 @@ class EnergyCalculator(ABC): """能量计算器抽象基类""" @abstractmethod - def calculate(self, context: Dict[str, Any]) -> float: + def calculate(self, context: dict[str, Any]) -> float: """计算能量值""" pass @@ -82,7 +82,7 @@ class EnergyCalculator(ABC): class InterestEnergyCalculator(EnergyCalculator): """兴趣度能量计算器""" - def calculate(self, context: Dict[str, Any]) -> float: + def calculate(self, context: dict[str, Any]) -> float: """基于消息兴趣度计算能量""" messages = context.get("messages", []) if not messages: @@ -120,7 +120,7 @@ class ActivityEnergyCalculator(EnergyCalculator): def __init__(self): self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1} - def calculate(self, context: Dict[str, Any]) -> float: + def calculate(self, context: dict[str, Any]) -> float: """基于活跃度计算能量""" messages = context.get("messages", []) if not messages: @@ -150,7 +150,7 @@ class ActivityEnergyCalculator(EnergyCalculator): class RecencyEnergyCalculator(EnergyCalculator): """最近性能量计算器""" - def calculate(self, context: Dict[str, Any]) -> float: + def calculate(self, context: dict[str, Any]) -> float: """基于最近性计算能量""" messages = context.get("messages", []) if not messages: @@ -197,7 +197,7 @@ class RecencyEnergyCalculator(EnergyCalculator): class RelationshipEnergyCalculator(EnergyCalculator): """关系能量计算器""" - async def calculate(self, context: Dict[str, Any]) -> float: + async def calculate(self, context: dict[str, Any]) -> float: """基于关系计算能量""" user_id = context.get("user_id") if not user_id: @@ -223,7 +223,7 @@ class EnergyManager: """能量管理器 - 统一管理所有能量计算""" def __init__(self) -> None: - self.calculators: List[EnergyCalculator] = [ + self.calculators: list[EnergyCalculator] = [ InterestEnergyCalculator(), ActivityEnergyCalculator(), RecencyEnergyCalculator(), @@ -231,14 +231,14 @@ class EnergyManager: ] # 能量缓存 - self.energy_cache: Dict[str, Tuple[float, float]] = {} # stream_id -> (energy, timestamp) + self.energy_cache: dict[str, tuple[float, float]] = {} # stream_id -> (energy, timestamp) self.cache_ttl: int = 60 # 1分钟缓存 # AFC阈值配置 - self.thresholds: Dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2} + self.thresholds: dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2} # 统计信息 - self.stats: Dict[str, Union[int, float, str]] = { + self.stats: dict[str, int | float | str] = { "total_calculations": 0, "cache_hits": 0, "cache_misses": 0, @@ -272,7 +272,7 @@ class EnergyManager: except Exception as e: logger.warning(f"加载AFC阈值失败,使用默认值: {e}") - async def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float: + async def calculate_focus_energy(self, stream_id: str, messages: list[Any], user_id: str | None = None) -> float: """计算聊天流的focus_energy""" start_time = time.time() @@ -297,7 +297,7 @@ class EnergyManager: } # 计算各组件能量 - component_scores: Dict[str, float] = {} + component_scores: dict[str, float] = {} total_weight = 0.0 for calculator in self.calculators: @@ -437,7 +437,7 @@ class EnergyManager: if expired_keys: logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存") - def get_statistics(self) -> Dict[str, Any]: + def 
get_statistics(self) -> dict[str, Any]: """获取统计信息""" return { "cache_size": len(self.energy_cache), @@ -446,7 +446,7 @@ class EnergyManager: "performance_stats": self.stats.copy(), } - def update_thresholds(self, new_thresholds: Dict[str, float]) -> None: + def update_thresholds(self, new_thresholds: dict[str, float]) -> None: """更新阈值""" self.thresholds.update(new_thresholds) diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 596322ebd..f9e0e68af 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -1,21 +1,20 @@ -import time -import random -import orjson import os +import random +import time from datetime import datetime +from typing import Any -from typing import List, Dict, Optional, Any, Tuple - -from src.common.logger import get_logger -from src.common.database.sqlalchemy_database_api import get_db_session +import orjson from sqlalchemy import select -from src.common.database.sqlalchemy_models import Expression -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config, global_config -from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import Expression +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest MAX_EXPRESSION_COUNT = 300 DECAY_DAYS = 30 # 30天衰减到0.01 @@ -193,7 +192,7 @@ class ExpressionLearner: logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}") return False - async def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]: + async def get_expression_by_chat_id(self) -> tuple[list[dict[str, float]], list[dict[str, float]]]: """ 获取指定chat_id的style和grammar表达方式 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作 @@ -341,7 +340,7 @@ class ExpressionLearner: return [] # 按chat_id分组 - chat_dict: Dict[str, List[Dict[str, Any]]] = {} + chat_dict: dict[str, list[dict[str, Any]]] = {} for chat_id, situation, style in learnt_expressions: if chat_id not in chat_dict: chat_dict[chat_id] = [] @@ -398,7 +397,7 @@ class ExpressionLearner: return learnt_expressions return None - async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]: + async def learn_expression(self, type: str, num: int = 10) -> tuple[list[tuple[str, str, str]], str] | None: """从指定聊天流学习表达方式 Args: @@ -416,7 +415,7 @@ class ExpressionLearner: current_time = time.time() # 获取上次学习时间 - random_msg: Optional[List[Dict[str, Any]]] = await get_raw_msg_by_timestamp_with_chat_inclusive( + random_msg: list[dict[str, Any]] | None = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_learning_time, timestamp_end=current_time, @@ -447,16 +446,16 @@ class ExpressionLearner: logger.debug(f"学习{type_str}的response: {response}") - expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id) + expressions: 
list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id) return expressions, chat_id @staticmethod - def parse_expression_response(response: str, chat_id: str) -> List[Tuple[str, str, str]]: + def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]: """ 解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组 """ - expressions: List[Tuple[str, str, str]] = [] + expressions: list[tuple[str, str, str]] = [] for line in response.splitlines(): line = line.strip() if not line: @@ -562,7 +561,7 @@ class ExpressionLearnerManager: if not os.path.exists(expr_file): continue try: - with open(expr_file, "r", encoding="utf-8") as f: + with open(expr_file, encoding="utf-8") as f: expressions = orjson.loads(f.read()) if not isinstance(expressions, list): diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index ff4083a3b..431d55b46 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -1,18 +1,18 @@ -import orjson -import time -import random import hashlib +import random +import time +from typing import Any -from typing import List, Dict, Tuple, Optional, Any +import orjson from json_repair import repair_json - -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger from sqlalchemy import select -from src.common.database.sqlalchemy_models import Expression + from src.chat.utils.prompt import Prompt, global_prompt_manager from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import Expression +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("expression_selector") @@ -45,7 +45,7 @@ def init_prompt(): Prompt(expression_evaluation_prompt, "expression_evaluation_prompt") -def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]: +def weighted_sample(population: list[dict], weights: list[float], k: int) -> list[dict]: """按权重随机抽样""" if not population or not weights or k <= 0: return [] @@ -95,7 +95,7 @@ class ExpressionSelector: return False @staticmethod - def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]: + def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None: """解析'platform:id:type'为chat_id(与get_stream_id一致)""" try: parts = stream_config_str.split(":") @@ -114,7 +114,7 @@ class ExpressionSelector: except Exception: return None - def get_related_chat_ids(self, chat_id: str) -> List[str]: + def get_related_chat_ids(self, chat_id: str) -> list[str]: """根据expression.rules配置,获取与当前chat_id相关的所有chat_id(包括自身)""" rules = global_config.expression.rules current_group = None @@ -139,7 +139,7 @@ class ExpressionSelector: async def get_random_expressions( self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float - ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: # sourcery skip: extract-duplicate-method, move-assign # 支持多chat_id合并抽选 related_chat_ids = self.get_related_chat_ids(chat_id) @@ -195,7 +195,7 @@ class ExpressionSelector: return selected_style, selected_grammar @staticmethod - async def update_expressions_count_batch(expressions_to_update: List[Dict[str, Any]], increment: float = 0.1): + async def 
update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1): """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库""" if not expressions_to_update: return @@ -240,8 +240,8 @@ class ExpressionSelector: chat_info: str, max_num: int = 10, min_num: int = 5, - target_message: Optional[str] = None, - ) -> List[Dict[str, Any]]: + target_message: str | None = None, + ) -> list[dict[str, Any]]: # sourcery skip: inline-variable, list-comprehension """使用LLM选择适合的表达方式""" diff --git a/src/chat/frequency_analyzer/analyzer.py b/src/chat/frequency_analyzer/analyzer.py index 1493c47ea..a3e6addea 100644 --- a/src/chat/frequency_analyzer/analyzer.py +++ b/src/chat/frequency_analyzer/analyzer.py @@ -16,8 +16,7 @@ Chat Frequency Analyzer """ import time as time_module -from datetime import datetime, timedelta, time -from typing import List, Tuple, Optional +from datetime import datetime, time, timedelta from .tracker import chat_frequency_tracker @@ -42,7 +41,7 @@ class ChatFrequencyAnalyzer: self._cache_ttl_seconds = 60 * 30 # 缓存30分钟 @staticmethod - def _find_peak_windows(timestamps: List[float]) -> List[Tuple[datetime, datetime]]: + def _find_peak_windows(timestamps: list[float]) -> list[tuple[datetime, datetime]]: """ 使用滑动窗口算法来识别时间戳列表中的高峰时段。 @@ -59,7 +58,7 @@ class ChatFrequencyAnalyzer: datetimes = [datetime.fromtimestamp(ts) for ts in timestamps] datetimes.sort() - peak_windows: List[Tuple[datetime, datetime]] = [] + peak_windows: list[tuple[datetime, datetime]] = [] window_start_idx = 0 for i in range(len(datetimes)): @@ -83,7 +82,7 @@ class ChatFrequencyAnalyzer: return peak_windows - def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]: + def get_peak_chat_times(self, chat_id: str) -> list[tuple[time, time]]: """ 获取指定用户的高峰聊天时间段。 @@ -116,7 +115,7 @@ class ChatFrequencyAnalyzer: return peak_time_windows - def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool: + def is_in_peak_time(self, chat_id: str, now: datetime | None = None) -> bool: """ 检查当前时间是否处于用户的高峰聊天时段内。 diff --git a/src/chat/frequency_analyzer/tracker.py b/src/chat/frequency_analyzer/tracker.py index 3621cb5b4..371fc6351 100644 --- a/src/chat/frequency_analyzer/tracker.py +++ b/src/chat/frequency_analyzer/tracker.py @@ -1,8 +1,8 @@ -import orjson import time -from typing import Dict, List, Optional from pathlib import Path +import orjson + from src.common.logger import get_logger # 数据存储路径 @@ -19,10 +19,10 @@ class ChatFrequencyTracker: """ def __init__(self): - self._timestamps: Dict[str, List[float]] = self._load_timestamps() + self._timestamps: dict[str, list[float]] = self._load_timestamps() @staticmethod - def _load_timestamps() -> Dict[str, List[float]]: + def _load_timestamps() -> dict[str, list[float]]: """从本地文件加载时间戳数据。""" if not TRACKER_FILE.exists(): return {} @@ -61,7 +61,7 @@ class ChatFrequencyTracker: logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}") self._save_timestamps() - def get_timestamps_for_chat(self, chat_id: str) -> Optional[List[float]]: + def get_timestamps_for_chat(self, chat_id: str) -> list[float] | None: """ 获取指定聊天的所有时间戳记录。 diff --git a/src/chat/frequency_analyzer/trigger.py b/src/chat/frequency_analyzer/trigger.py index 2d8e8b56f..9d8a4fea0 100644 --- a/src/chat/frequency_analyzer/trigger.py +++ b/src/chat/frequency_analyzer/trigger.py @@ -18,11 +18,10 @@ Frequency-Based Proactive Trigger import asyncio import time from datetime import datetime -from typing import Dict, Optional from src.common.logger import get_logger -# AFC manager has 
been moved to chatter plugin +# AFC manager has been moved to chatter plugin # TODO: 需要重新实现主动思考和睡眠管理功能 from .analyzer import chat_frequency_analyzer @@ -42,10 +41,10 @@ class FrequencyBasedTrigger: def __init__(self): # TODO: 需要重新实现睡眠管理器 - self._task: Optional[asyncio.Task] = None + self._task: asyncio.Task | None = None # 记录上次为用户触发的时间,用于冷却控制 # 格式: { "chat_id": timestamp } - self._last_triggered: Dict[str, float] = {} + self._last_triggered: dict[str, float] = {} async def _run_trigger_cycle(self): """触发器的主要循环逻辑。""" diff --git a/src/chat/interest_system/__init__.py b/src/chat/interest_system/__init__.py index e05cbeebf..0d1a9bbe8 100644 --- a/src/chat/interest_system/__init__.py +++ b/src/chat/interest_system/__init__.py @@ -3,13 +3,14 @@ 提供机器人兴趣标签和智能匹配功能 """ -from .bot_interest_manager import BotInterestManager, bot_interest_manager from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult +from .bot_interest_manager import BotInterestManager, bot_interest_manager + __all__ = [ "BotInterestManager", - "bot_interest_manager", "BotInterestTag", "BotPersonalityInterests", "InterestMatchResult", + "bot_interest_manager", ] diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index 8fee48d1c..b26095f4c 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -3,17 +3,18 @@ 基于人设生成兴趣标签,并使用embedding计算匹配度 """ -import orjson import traceback -from typing import List, Dict, Optional, Any from datetime import datetime +from typing import Any + import numpy as np +import orjson from sqlalchemy import select +from src.common.config_helpers import resolve_embedding_dimension +from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult from src.common.logger import get_logger from src.config.config import global_config -from src.common.config_helpers import resolve_embedding_dimension -from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult logger = get_logger("bot_interest_manager") @@ -22,8 +23,8 @@ class BotInterestManager: """机器人兴趣标签管理器""" def __init__(self): - self.current_interests: Optional[BotPersonalityInterests] = None - self.embedding_cache: Dict[str, List[float]] = {} # embedding缓存 + self.current_interests: BotPersonalityInterests | None = None + self.embedding_cache: dict[str, list[float]] = {} # embedding缓存 self._initialized = False # Embedding客户端配置 @@ -31,7 +32,7 @@ class BotInterestManager: self.embedding_config = None configured_dim = resolve_embedding_dimension() self.embedding_dimension = int(configured_dim) if configured_dim else 0 - self._detected_embedding_dimension: Optional[int] = None + self._detected_embedding_dimension: int | None = None @property def is_initialized(self) -> bool: @@ -145,7 +146,7 @@ class BotInterestManager: async def _generate_interests_from_personality( self, personality_description: str, personality_id: str - ) -> Optional[BotPersonalityInterests]: + ) -> BotPersonalityInterests | None: """根据人设生成兴趣标签""" try: logger.info("🎨 开始根据人设生成兴趣标签...") @@ -226,14 +227,14 @@ class BotInterestManager: traceback.print_exc() raise - async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]: + async def _call_llm_for_interest_generation(self, prompt: str) -> str | None: """调用LLM生成兴趣标签""" try: logger.info("🔧 配置LLM客户端...") # 使用llm_api来处理请求 - from 
src.plugin_system.apis import llm_api from src.config.config import model_config + from src.plugin_system.apis import llm_api # 构建完整的提示词,明确要求只返回纯JSON full_prompt = f"""你是一个专业的机器人人设分析师,擅长根据人设描述生成合适的兴趣标签。 @@ -342,7 +343,7 @@ class BotInterestManager: logger.info(f"🗃️ 总缓存大小: {len(self.embedding_cache)}") logger.info("=" * 50) - async def _get_embedding(self, text: str) -> List[float]: + async def _get_embedding(self, text: str) -> list[float]: """获取文本的embedding向量""" if not hasattr(self, "embedding_request"): raise RuntimeError("❌ Embedding请求客户端未初始化") @@ -383,7 +384,7 @@ class BotInterestManager: else: raise RuntimeError(f"❌ 返回的embedding为空: {embedding}") - async def _generate_message_embedding(self, message_text: str, keywords: List[str]) -> List[float]: + async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> list[float]: """为消息生成embedding向量""" # 组合消息文本和关键词作为embedding输入 if keywords: @@ -399,7 +400,7 @@ class BotInterestManager: return embedding async def _calculate_similarity_scores( - self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str] + self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str] ): """计算消息与兴趣标签的相似度分数""" try: @@ -428,7 +429,7 @@ class BotInterestManager: except Exception as e: logger.error(f"❌ 计算相似度分数失败: {e}") - async def calculate_interest_match(self, message_text: str, keywords: List[str] = None) -> InterestMatchResult: + async def calculate_interest_match(self, message_text: str, keywords: list[str] = None) -> InterestMatchResult: """计算消息与机器人兴趣的匹配度""" if not self.current_interests or not self._initialized: raise RuntimeError("❌ 兴趣标签系统未初始化") @@ -528,7 +529,7 @@ class BotInterestManager: ) return result - def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]: + def _calculate_keyword_match_bonus(self, keywords: list[str], matched_tags: list[str]) -> dict[str, float]: """计算关键词直接匹配奖励""" if not keywords or not matched_tags: return {} @@ -610,7 +611,7 @@ class BotInterestManager: return previous_row[-1] - def _calculate_cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float: + def _calculate_cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float: """计算余弦相似度""" try: vec1 = np.array(vec1) @@ -629,16 +630,17 @@ class BotInterestManager: logger.error(f"计算余弦相似度失败: {e}") return 0.0 - async def _load_interests_from_database(self, personality_id: str) -> Optional[BotPersonalityInterests]: + async def _load_interests_from_database(self, personality_id: str) -> BotPersonalityInterests | None: """从数据库加载兴趣标签""" try: logger.debug(f"从数据库加载兴趣标签, personality_id: {personality_id}") # 导入SQLAlchemy相关模块 - from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests - from src.common.database.sqlalchemy_database_api import get_db_session import orjson + from src.common.database.sqlalchemy_database_api import get_db_session + from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests + async with get_db_session() as session: # 查询最新的兴趣标签配置 db_interests = ( @@ -716,10 +718,11 @@ class BotInterestManager: logger.info(f"🔄 版本: {interests.version}") # 导入SQLAlchemy相关模块 - from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests - from src.common.database.sqlalchemy_database_api import get_db_session import orjson + from src.common.database.sqlalchemy_database_api import get_db_session + from 
src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests + # 将兴趣标签转换为JSON格式 tags_data = [] for tag in interests.interest_tags: @@ -803,11 +806,11 @@ class BotInterestManager: logger.error("🔍 错误详情:") traceback.print_exc() - def get_current_interests(self) -> Optional[BotPersonalityInterests]: + def get_current_interests(self) -> BotPersonalityInterests | None: """获取当前的兴趣标签配置""" return self.current_interests - def get_interest_stats(self) -> Dict[str, Any]: + def get_interest_stats(self) -> dict[str, Any]: """获取兴趣系统统计信息""" if not self.current_interests: return {"initialized": False} diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index f6fae8d6c..7ef04f985 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -1,33 +1,31 @@ -from dataclasses import dataclass -import orjson -import os -import math import asyncio +import math +import os from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Dict, List, Tuple - -import numpy as np -import pandas as pd +from dataclasses import dataclass # import tqdm import faiss - -from .utils.hash import get_sha256 -from .global_logger import logger -from rich.traceback import install +import numpy as np +import orjson +import pandas as pd from rich.progress import ( - Progress, BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, TimeElapsedColumn, TimeRemainingColumn, - TaskProgressColumn, - MofNCompleteColumn, - SpinnerColumn, - TextColumn, ) -from src.config.config import global_config -from src.common.config_helpers import resolve_embedding_dimension +from rich.traceback import install +from src.common.config_helpers import resolve_embedding_dimension +from src.config.config import global_config + +from .global_logger import logger +from .utils.hash import get_sha256 install(extra_lines=3) @@ -79,7 +77,7 @@ def cosine_similarity(a, b): class EmbeddingStoreItem: """嵌入库中的项""" - def __init__(self, item_hash: str, embedding: List[float], content: str): + def __init__(self, item_hash: str, embedding: list[float], content: str): self.hash = item_hash self.embedding = embedding self.str = content @@ -127,7 +125,7 @@ class EmbeddingStore: self.idx2hash = None @staticmethod - def _get_embedding(s: str) -> List[float]: + def _get_embedding(s: str) -> list[float]: """获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题""" # 创建新的事件循环并在完成后立即关闭 loop = asyncio.new_event_loop() @@ -135,8 +133,8 @@ class EmbeddingStore: try: # 创建新的LLMRequest实例 - from src.llm_models.utils_model import LLMRequest from src.config.config import model_config + from src.llm_models.utils_model import LLMRequest llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") @@ -161,8 +159,8 @@ class EmbeddingStore: @staticmethod def _get_embeddings_batch_threaded( - strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None - ) -> List[Tuple[str, List[float]]]: + strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None + ) -> list[tuple[str, list[float]]]: """使用多线程批量获取嵌入向量 Args: @@ -192,8 +190,8 @@ class EmbeddingStore: chunk_results = [] # 为每个线程创建独立的LLMRequest实例 - from src.llm_models.utils_model import LLMRequest from src.config.config import model_config + from src.llm_models.utils_model import LLMRequest try: # 创建线程专用的LLM实例 @@ -303,7 +301,7 @@ class EmbeddingStore: path = self.get_test_file_path() if not os.path.exists(path): return 
None - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return orjson.loads(f.read()) def check_embedding_model_consistency(self): @@ -345,7 +343,7 @@ class EmbeddingStore: logger.info("嵌入模型一致性校验通过。") return True - def batch_insert_strs(self, strs: List[str], times: int) -> None: + def batch_insert_strs(self, strs: list[str], times: int) -> None: """向库中存入字符串(使用多线程优化)""" if not strs: return @@ -481,7 +479,7 @@ class EmbeddingStore: if os.path.exists(self.idx2hash_file_path): logger.info(f"正在加载{self.namespace}嵌入库的idx2hash映射...") logger.debug(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射") - with open(self.idx2hash_file_path, "r") as f: + with open(self.idx2hash_file_path) as f: self.idx2hash = orjson.loads(f.read()) logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功") else: @@ -511,7 +509,7 @@ class EmbeddingStore: self.faiss_index = faiss.IndexFlatIP(embedding_dim) self.faiss_index.add(embeddings) - def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]: + def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]: """搜索最相似的k个项,以余弦相似度为度量 Args: query: 查询的embedding @@ -575,11 +573,11 @@ class EmbeddingManager: """对所有嵌入库做模型一致性校验""" return self.paragraphs_embedding_store.check_embedding_model_consistency() - def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): + def _store_pg_into_embedding(self, raw_paragraphs: dict[str, str]): """将段落编码存入Embedding库""" self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1) - def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): + def _store_ent_into_embedding(self, triple_list_data: dict[str, list[list[str]]]): """将实体编码存入Embedding库""" entities = set() for triple_list in triple_list_data.values(): @@ -588,7 +586,7 @@ class EmbeddingManager: entities.add(triple[2]) self.entities_embedding_store.batch_insert_strs(list(entities), times=2) - def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): + def _store_rel_into_embedding(self, triple_list_data: dict[str, list[list[str]]]): """将关系编码存入Embedding库""" graph_triples = [] # a list of unique relation triple (in tuple) from all chunks for triples in triple_list_data.values(): @@ -606,8 +604,8 @@ class EmbeddingManager: def store_new_data_set( self, - raw_paragraphs: Dict[str, str], - triple_list_data: Dict[str, List[List[str]]], + raw_paragraphs: dict[str, str], + triple_list_data: dict[str, list[list[str]]], ): if not self.check_all_embedding_model_consistency(): raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。") diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py index 457396d0a..e74b7d127 100644 --- a/src/chat/knowledge/ie_process.py +++ b/src/chat/knowledge/ie_process.py @@ -1,14 +1,15 @@ import asyncio -import orjson import time -from typing import List, Union -from .global_logger import logger -from . import prompt_template -from .knowledge_lib import INVALID_ENTITY -from src.llm_models.utils_model import LLMRequest +import orjson from json_repair import repair_json +from src.llm_models.utils_model import LLMRequest + +from . 
import prompt_template +from .global_logger import logger +from .knowledge_lib import INVALID_ENTITY + def _extract_json_from_text(text: str): # sourcery skip: assign-if-exp, extract-method @@ -46,7 +47,7 @@ def _extract_json_from_text(text: str): return [] -def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: +def _entity_extract(llm_req: LLMRequest, paragraph: str) -> list[str]: # sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" entity_extract_context = prompt_template.build_entity_extract_context(paragraph) @@ -92,7 +93,7 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: return entity_extract_result -def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]: +def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> list[list[str]]: """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" rdf_extract_context = prompt_template.build_rdf_triple_extract_context( paragraph, entities=orjson.dumps(entities).decode("utf-8") @@ -141,7 +142,7 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> def info_extract_from_str( llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str -) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]: +) -> tuple[None, None] | tuple[list[str], list[list[str]]]: try_count = 0 while True: try: diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py index 6d0585226..f590fad7d 100644 --- a/src/chat/knowledge/kg_manager.py +++ b/src/chat/knowledge/kg_manager.py @@ -1,28 +1,26 @@ -import orjson import os import time -from typing import Dict, List, Tuple import numpy as np +import orjson import pandas as pd +from quick_algo import di_graph, pagerank from rich.progress import ( - Progress, BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, TimeElapsedColumn, TimeRemainingColumn, - TaskProgressColumn, - MofNCompleteColumn, - SpinnerColumn, - TextColumn, ) -from quick_algo import di_graph, pagerank - -from .utils.hash import get_sha256 -from .embedding_store import EmbeddingManager, EmbeddingStoreItem from src.config.config import global_config +from .embedding_store import EmbeddingManager, EmbeddingStoreItem from .global_logger import logger +from .utils.hash import get_sha256 def _get_kg_dir(): @@ -87,7 +85,7 @@ class KGManager: raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在") # 加载段落hash - with open(self.pg_hash_file_path, "r", encoding="utf-8") as f: + with open(self.pg_hash_file_path, encoding="utf-8") as f: data = orjson.loads(f.read()) self.stored_paragraph_hashes = set(data["stored_paragraph_hashes"]) @@ -100,8 +98,8 @@ class KGManager: def _build_edges_between_ent( self, - node_to_node: Dict[Tuple[str, str], float], - triple_list_data: Dict[str, List[List[str]]], + node_to_node: dict[tuple[str, str], float], + triple_list_data: dict[str, list[list[str]]], ): """构建实体节点之间的关系,同时统计实体出现次数""" for triple_list in triple_list_data.values(): @@ -124,8 +122,8 @@ class KGManager: @staticmethod def _build_edges_between_ent_pg( - node_to_node: Dict[Tuple[str, str], float], - triple_list_data: Dict[str, List[List[str]]], + node_to_node: dict[tuple[str, str], float], + triple_list_data: dict[str, list[list[str]]], ): """构建实体节点与文段节点之间的关系""" for idx in triple_list_data: @@ -136,8 +134,8 @@ class KGManager: @staticmethod def _synonym_connect( - node_to_node: Dict[Tuple[str, str], float], - 
triple_list_data: Dict[str, List[List[str]]], + node_to_node: dict[tuple[str, str], float], + triple_list_data: dict[str, list[list[str]]], embedding_manager: EmbeddingManager, ) -> int: """同义词连接""" @@ -208,7 +206,7 @@ class KGManager: def _update_graph( self, - node_to_node: Dict[Tuple[str, str], float], + node_to_node: dict[tuple[str, str], float], embedding_manager: EmbeddingManager, ): """更新KG图结构 @@ -280,7 +278,7 @@ class KGManager: def build_kg( self, - triple_list_data: Dict[str, List[List[str]]], + triple_list_data: dict[str, list[list[str]]], embedding_manager: EmbeddingManager, ): """增量式构建KG @@ -317,8 +315,8 @@ class KGManager: def kg_search( self, - relation_search_result: List[Tuple[Tuple[str, str, str], float]], - paragraph_search_result: List[Tuple[str, float]], + relation_search_result: list[tuple[tuple[str, str, str], float]], + paragraph_search_result: list[tuple[str, float]], embed_manager: EmbeddingManager, ): """RAG搜索与PageRank diff --git a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py index ccc3cd090..a1f49f314 100644 --- a/src/chat/knowledge/knowledge_lib.py +++ b/src/chat/knowledge/knowledge_lib.py @@ -1,10 +1,11 @@ -from src.chat.knowledge.embedding_store import EmbeddingManager -from src.chat.knowledge.qa_manager import QAManager -from src.chat.knowledge.kg_manager import KGManager -from src.chat.knowledge.global_logger import logger -from src.config.config import global_config import os +from src.chat.knowledge.embedding_store import EmbeddingManager +from src.chat.knowledge.global_logger import logger +from src.chat.knowledge.kg_manager import KGManager +from src.chat.knowledge.qa_manager import QAManager +from src.config.config import global_config + INVALID_ENTITY = [ "", "你", diff --git a/src/chat/knowledge/open_ie.py b/src/chat/knowledge/open_ie.py index 23b3032d5..aa01c6c2f 100644 --- a/src/chat/knowledge/open_ie.py +++ b/src/chat/knowledge/open_ie.py @@ -1,14 +1,15 @@ -import orjson -import os import glob -from typing import Any, Dict, List +import os +from typing import Any +import orjson + +from .knowledge_lib import DATA_PATH, INVALID_ENTITY, ROOT_PATH -from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH # from src.manager.local_store_manager import local_storage -def _filter_invalid_entities(entities: List[str]) -> List[str]: +def _filter_invalid_entities(entities: list[str]) -> list[str]: """过滤无效的实体""" valid_entities = set() for entity in entities: @@ -20,7 +21,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]: return list(valid_entities) -def _filter_invalid_triples(triples: List[List[str]]) -> List[List[str]]: +def _filter_invalid_triples(triples: list[list[str]]) -> list[list[str]]: """过滤无效的三元组""" unique_triples = set() valid_triples = [] @@ -62,7 +63,7 @@ class OpenIE: def __init__( self, - docs: List[Dict[str, Any]], + docs: list[dict[str, Any]], avg_ent_chars, avg_ent_words, ): @@ -112,7 +113,7 @@ class OpenIE: json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json"))) data_list = [] for file in json_files: - with open(file, "r", encoding="utf-8") as f: + with open(file, encoding="utf-8") as f: data = orjson.loads(f.read()) data_list.append(data) if not data_list: diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index c340fc30e..b08fb24e0 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -1,15 +1,16 @@ import time -from typing import Tuple, List, Dict, Optional, Any +from typing import Any + +from 
src.chat.utils.utils import get_embedding +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest -from .global_logger import logger from .embedding_store import EmbeddingManager +from .global_logger import logger from .kg_manager import KGManager # from .lpmmconfig import global_config from .utils.dyn_topk import dyn_select_top_k -from src.llm_models.utils_model import LLMRequest -from src.chat.utils.utils import get_embedding -from src.config.config import global_config, model_config MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度 @@ -26,7 +27,7 @@ class QAManager: async def process_query( self, question: str - ) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]: + ) -> tuple[list[tuple[str, float, float]], dict[str, float] | None] | None: """处理查询""" # 生成问题的Embedding @@ -98,7 +99,7 @@ class QAManager: return result, ppr_node_weights - async def get_knowledge(self, question: str) -> Optional[Dict[str, Any]]: + async def get_knowledge(self, question: str) -> dict[str, Any] | None: """ 获取知识,返回结构化字典 diff --git a/src/chat/knowledge/utils/dyn_topk.py b/src/chat/knowledge/utils/dyn_topk.py index df9e470dc..106a68da4 100644 --- a/src/chat/knowledge/utils/dyn_topk.py +++ b/src/chat/knowledge/utils/dyn_topk.py @@ -1,9 +1,9 @@ -from typing import List, Any, Tuple +from typing import Any def dyn_select_top_k( - score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float -) -> List[Tuple[Any, float, float]]: + score: list[tuple[Any, float]], jmp_factor: float, var_factor: float +) -> list[tuple[Any, float, float]]: """动态TopK选择""" # 检查输入列表是否为空 if not score: diff --git a/src/chat/memory_system/__init__.py b/src/chat/memory_system/__init__.py index a1c176a10..d3c5feea4 100644 --- a/src/chat/memory_system/__init__.py +++ b/src/chat/memory_system/__init__.py @@ -1,37 +1,35 @@ -# -*- coding: utf-8 -*- """ 简化记忆系统模块 移除即时记忆和长期记忆分类,实现统一记忆架构和智能遗忘机制 """ # 核心数据结构 +# 激活器 +from .enhanced_memory_activator import MemoryActivator, enhanced_memory_activator, memory_activator from .memory_chunk import ( + ConfidenceLevel, + ContentStructure, + ImportanceLevel, MemoryChunk, MemoryMetadata, - ContentStructure, MemoryType, - ImportanceLevel, - ConfidenceLevel, create_memory_chunk, ) +# 兼容性别名 +from .memory_chunk import MemoryChunk as Memory + # 遗忘引擎 -from .memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig, get_memory_forgetting_engine - -# Vector DB存储系统 -from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage - -# 记忆核心系统 -from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system +from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine # 记忆管理器 from .memory_manager import MemoryManager, MemoryResult, memory_manager -# 激活器 -from .enhanced_memory_activator import MemoryActivator, memory_activator, enhanced_memory_activator +# 记忆核心系统 +from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system -# 兼容性别名 -from .memory_chunk import MemoryChunk as Memory +# Vector DB存储系统 +from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage __all__ = [ # 核心数据结构 diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py index aae09c08b..cf93ceaf0 100644 --- 
a/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py @@ -1,17 +1,17 @@ -# -*- coding: utf-8 -*- """ 增强记忆系统适配器 将增强记忆系统集成到现有MoFox Bot架构中 """ import time -from typing import Dict, List, Optional, Any from dataclasses import dataclass +from typing import Any -from src.common.logger import get_logger -from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm + +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) @@ -47,10 +47,10 @@ class AdapterConfig: class EnhancedMemoryAdapter: """增强记忆系统适配器""" - def __init__(self, llm_model: LLMRequest, config: Optional[AdapterConfig] = None): + def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None): self.llm_model = llm_model self.config = config or AdapterConfig() - self.integration_layer: Optional[MemoryIntegrationLayer] = None + self.integration_layer: MemoryIntegrationLayer | None = None self._initialized = False # 统计信息 @@ -96,7 +96,7 @@ class EnhancedMemoryAdapter: # 如果初始化失败,禁用增强记忆功能 self.config.enable_enhanced_memory = False - async def process_conversation_memory(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]: """处理对话记忆,以上下文为唯一输入""" if not self._initialized or not self.config.enable_enhanced_memory: return {"success": False, "error": "Enhanced memory not available"} @@ -105,7 +105,7 @@ class EnhancedMemoryAdapter: self.adapter_stats["total_processed"] += 1 try: - payload_context: Dict[str, Any] = dict(context or {}) + payload_context: dict[str, Any] = dict(context or {}) conversation_text = payload_context.get("conversation_text") if not conversation_text: @@ -146,8 +146,8 @@ class EnhancedMemoryAdapter: return {"success": False, "error": str(e)} async def retrieve_relevant_memories( - self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: Optional[int] = None - ) -> List[MemoryChunk]: + self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None + ) -> list[MemoryChunk]: """检索相关记忆""" if not self._initialized or not self.config.enable_enhanced_memory: return [] @@ -166,7 +166,7 @@ class EnhancedMemoryAdapter: return [] async def get_memory_context_for_prompt( - self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5 + self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5 ) -> str: """获取用于提示词的记忆上下文""" memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories) @@ -186,7 +186,7 @@ class EnhancedMemoryAdapter: return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config) - async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]: + async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]: """获取增强记忆系统摘要""" if not self._initialized or not self.config.enable_enhanced_memory: return {"available": False, "reason": "Not 
initialized or disabled"} @@ -222,7 +222,7 @@ class EnhancedMemoryAdapter: new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed self.adapter_stats["average_processing_time"] = new_avg - def get_adapter_stats(self) -> Dict[str, Any]: + def get_adapter_stats(self) -> dict[str, Any]: """获取适配器统计信息""" return self.adapter_stats.copy() @@ -253,7 +253,7 @@ class EnhancedMemoryAdapter: # 全局适配器实例 -_enhanced_memory_adapter: Optional[EnhancedMemoryAdapter] = None +_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter: @@ -292,8 +292,8 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest): async def process_conversation_with_enhanced_memory( - context: Dict[str, Any], llm_model: Optional[LLMRequest] = None -) -> Dict[str, Any]: + context: dict[str, Any], llm_model: LLMRequest | None = None +) -> dict[str, Any]: """使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息""" if not llm_model: # 获取默认的LLM模型 @@ -323,10 +323,10 @@ async def process_conversation_with_enhanced_memory( async def retrieve_memories_with_enhanced_system( query: str, user_id: str, - context: Optional[Dict[str, Any]] = None, + context: dict[str, Any] | None = None, limit: int = 10, - llm_model: Optional[LLMRequest] = None, -) -> List[MemoryChunk]: + llm_model: LLMRequest | None = None, +) -> list[MemoryChunk]: """使用增强记忆系统检索记忆""" if not llm_model: # 获取默认的LLM模型 @@ -345,9 +345,9 @@ async def retrieve_memories_with_enhanced_system( async def get_memory_context_for_prompt( query: str, user_id: str, - context: Optional[Dict[str, Any]] = None, + context: dict[str, Any] | None = None, max_memories: int = 5, - llm_model: Optional[LLMRequest] = None, + llm_model: LLMRequest | None = None, ) -> str: """获取用于提示词的记忆上下文""" if not llm_model: diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py index a1b374510..2794332cf 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py @@ -1,15 +1,15 @@ -# -*- coding: utf-8 -*- """ 增强记忆系统钩子 用于在消息处理过程中自动构建和检索记忆 """ -from typing import Dict, List, Any, Optional from datetime import datetime +from typing import Any + +from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager from src.common.logger import get_logger from src.config.config import global_config -from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager logger = get_logger(__name__) @@ -27,7 +27,7 @@ class EnhancedMemoryHooks: user_id: str, chat_id: str, message_id: str, - context: Optional[Dict[str, Any]] = None, + context: dict[str, Any] | None = None, ) -> bool: """ 处理消息并构建记忆 @@ -106,8 +106,8 @@ class EnhancedMemoryHooks: user_id: str, chat_id: str, limit: int = 5, - extra_context: Optional[Dict[str, Any]] = None, - ) -> List[Dict[str, Any]]: + extra_context: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: """ 为回复获取相关记忆 diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py index 913c2aed0..8583f7dd2 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py @@ -1,19 +1,19 @@ -# -*- coding: utf-8 -*- """ 增强记忆系统集成脚本 用于在现有系统中无缝集成增强记忆功能 """ -from typing import Dict, Any, 
Optional +from typing import Any + +from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks from src.common.logger import get_logger -from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks logger = get_logger(__name__) async def process_user_message_memory( - message_content: str, user_id: str, chat_id: str, message_id: str, context: Optional[Dict[str, Any]] = None + message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None ) -> bool: """ 处理用户消息并构建记忆 @@ -44,8 +44,8 @@ async def process_user_message_memory( async def get_relevant_memories_for_response( - query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: + query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None +) -> dict[str, Any]: """ 为回复获取相关记忆 @@ -74,7 +74,7 @@ async def get_relevant_memories_for_response( return {"has_memories": False, "memories": [], "memory_count": 0} -def format_memories_for_prompt(memories: Dict[str, Any]) -> str: +def format_memories_for_prompt(memories: dict[str, Any]) -> str: """ 格式化记忆信息用于Prompt @@ -114,7 +114,7 @@ async def cleanup_memory_system(): logger.error(f"记忆系统清理失败: {e}") -def get_memory_system_status() -> Dict[str, Any]: +def get_memory_system_status() -> dict[str, Any]: """ 获取记忆系统状态 @@ -133,7 +133,7 @@ def get_memory_system_status() -> Dict[str, Any]: # 便捷函数 async def remember_message( - message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: Optional[Dict[str, Any]] = None + message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None ) -> bool: """ 便捷的记忆构建函数 @@ -159,8 +159,8 @@ async def recall_memories( user_id: str = "default_user", chat_id: str = "default_chat", limit: int = 5, - context: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: + context: dict[str, Any] | None = None, +) -> dict[str, Any]: """ 便捷的记忆检索函数 diff --git a/src/chat/memory_system/deprecated_backup/enhanced_reranker.py b/src/chat/memory_system/deprecated_backup/enhanced_reranker.py index e5b368460..c35b9de53 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_reranker.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_reranker.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 增强重排序器 实现文档设计的多维度评分模型 @@ -6,12 +5,12 @@ import math import time -from typing import Dict, List, Optional, Any, Tuple from dataclasses import dataclass from enum import Enum +from typing import Any -from src.common.logger import get_logger from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +from src.common.logger import get_logger logger = get_logger(__name__) @@ -44,7 +43,7 @@ class ReRankingConfig: freq_max_score: float = 5.0 # 最大频率得分 # 类型匹配权重映射 - type_match_weights: Dict[str, Dict[str, float]] = None + type_match_weights: dict[str, dict[str, float]] = None def __post_init__(self): """初始化类型匹配权重""" @@ -157,7 +156,7 @@ class IntentClassifier: ], } - def classify_intent(self, query: str, context: Dict[str, Any]) -> IntentType: + def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType: """识别对话意图""" if not query: return IntentType.UNKNOWN @@ -165,7 +164,7 @@ class IntentClassifier: query_lower = query.lower() # 统计各意图的匹配分数 - intent_scores = {intent: 0 for intent in IntentType} + intent_scores = dict.fromkeys(IntentType, 0) for intent, patterns in self.patterns.items(): for pattern in patterns: @@ -187,7 
+186,7 @@ class IntentClassifier: class EnhancedReRanker: """增强重排序器 - 实现文档设计的多维度评分模型""" - def __init__(self, config: Optional[ReRankingConfig] = None): + def __init__(self, config: ReRankingConfig | None = None): self.config = config or ReRankingConfig() self.intent_classifier = IntentClassifier() @@ -210,10 +209,10 @@ class EnhancedReRanker: def rerank_memories( self, query: str, - candidate_memories: List[Tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity) - context: Dict[str, Any], + candidate_memories: list[tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity) + context: dict[str, Any], limit: int = 10, - ) -> List[Tuple[str, MemoryChunk, float]]: + ) -> list[tuple[str, MemoryChunk, float]]: """ 对候选记忆进行重排序 @@ -341,11 +340,11 @@ default_reranker = EnhancedReRanker() def rerank_candidate_memories( query: str, - candidate_memories: List[Tuple[str, MemoryChunk, float]], - context: Dict[str, Any], + candidate_memories: list[tuple[str, MemoryChunk, float]], + context: dict[str, Any], limit: int = 10, - config: Optional[ReRankingConfig] = None, -) -> List[Tuple[str, MemoryChunk, float]]: + config: ReRankingConfig | None = None, +) -> list[tuple[str, MemoryChunk, float]]: """ 便捷函数:对候选记忆进行重排序 """ diff --git a/src/chat/memory_system/deprecated_backup/integration_layer.py b/src/chat/memory_system/deprecated_backup/integration_layer.py index 5b9282a84..c7a27b8cb 100644 --- a/src/chat/memory_system/deprecated_backup/integration_layer.py +++ b/src/chat/memory_system/deprecated_backup/integration_layer.py @@ -1,18 +1,18 @@ -# -*- coding: utf-8 -*- """ 增强记忆系统集成层 现在只管理新的增强记忆系统,旧系统已被完全移除 """ -import time import asyncio -from typing import Dict, List, Optional, Any +import time from dataclasses import dataclass from enum import Enum +from typing import Any -from src.common.logger import get_logger from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem + from src.chat.memory_system.memory_chunk import MemoryChunk +from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) @@ -40,12 +40,12 @@ class IntegrationConfig: class MemoryIntegrationLayer: """记忆系统集成层 - 现在只管理增强记忆系统""" - def __init__(self, llm_model: LLMRequest, config: Optional[IntegrationConfig] = None): + def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None): self.llm_model = llm_model self.config = config or IntegrationConfig() # 只初始化增强记忆系统 - self.enhanced_memory: Optional[EnhancedMemorySystem] = None + self.enhanced_memory: EnhancedMemorySystem | None = None # 集成统计 self.integration_stats = { @@ -113,7 +113,7 @@ class MemoryIntegrationLayer: logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True) raise - async def process_conversation(self, context: Dict[str, Any]) -> Dict[str, Any]: + async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]: """处理对话记忆,仅使用上下文信息""" if not self._initialized or not self.enhanced_memory: return {"success": False, "error": "Memory system not available"} @@ -150,10 +150,10 @@ class MemoryIntegrationLayer: async def retrieve_relevant_memories( self, query: str, - user_id: Optional[str] = None, - context: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None, - ) -> List[MemoryChunk]: + user_id: str | None = None, + context: dict[str, Any] | None = None, + limit: int | None = None, + ) -> list[MemoryChunk]: """检索相关记忆""" if not self._initialized or not self.enhanced_memory: return [] @@ -172,7 +172,7 @@ class MemoryIntegrationLayer: 
logger.error(f"检索相关记忆失败: {e}", exc_info=True) return [] - async def get_system_status(self) -> Dict[str, Any]: + async def get_system_status(self) -> dict[str, Any]: """获取系统状态""" if not self._initialized: return {"status": "not_initialized"} @@ -193,7 +193,7 @@ class MemoryIntegrationLayer: logger.error(f"获取系统状态失败: {e}", exc_info=True) return {"status": "error", "error": str(e)} - def get_integration_stats(self) -> Dict[str, Any]: + def get_integration_stats(self) -> dict[str, Any]: """获取集成统计信息""" return self.integration_stats.copy() diff --git a/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py b/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py index 4659389cb..a37e4c548 100644 --- a/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py +++ b/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py @@ -1,20 +1,20 @@ -# -*- coding: utf-8 -*- """ 记忆系统集成钩子 提供与现有MoFox Bot系统的无缝集成点 """ import time -from typing import Dict, Optional, Any from dataclasses import dataclass +from typing import Any -from src.common.logger import get_logger from src.chat.memory_system.enhanced_memory_adapter import ( + get_memory_context_for_prompt, process_conversation_with_enhanced_memory, retrieve_memories_with_enhanced_system, - get_memory_context_for_prompt, ) +from src.common.logger import get_logger + logger = get_logger(__name__) @@ -24,7 +24,7 @@ class HookResult: success: bool data: Any = None - error: Optional[str] = None + error: str | None = None processing_time: float = 0.0 @@ -125,8 +125,8 @@ class MemoryIntegrationHooks: # 尝试注册到事件系统 try: - from src.plugin_system.core.event_manager import event_manager from src.plugin_system.base.component_types import EventType + from src.plugin_system.core.event_manager import event_manager # 注册消息后处理事件 event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler) @@ -238,11 +238,11 @@ class MemoryIntegrationHooks: # 钩子处理器方法 - async def _on_message_processed_handler(self, event_data: Dict[str, Any]) -> HookResult: + async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult: """事件系统的消息处理处理器""" return await self._on_message_processed_hook(event_data) - async def _on_message_processed_hook(self, message_data: Dict[str, Any]) -> HookResult: + async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult: """消息后处理钩子""" start_time = time.time() @@ -289,7 +289,7 @@ class MemoryIntegrationHooks: logger.error(f"消息处理钩子执行异常: {e}", exc_info=True) return HookResult(success=False, error=str(e), processing_time=processing_time) - async def _on_chat_stream_save_hook(self, chat_stream_data: Dict[str, Any]) -> HookResult: + async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult: """聊天流保存钩子""" start_time = time.time() @@ -345,7 +345,7 @@ class MemoryIntegrationHooks: logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True) return HookResult(success=False, error=str(e), processing_time=processing_time) - async def _on_pre_response_hook(self, response_data: Dict[str, Any]) -> HookResult: + async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult: """回复前钩子""" start_time = time.time() @@ -380,7 +380,7 @@ class MemoryIntegrationHooks: logger.error(f"回复前钩子执行异常: {e}", exc_info=True) return HookResult(success=False, error=str(e), processing_time=processing_time) - async def _on_knowledge_query_hook(self, query_data: Dict[str, Any]) -> HookResult: + async def _on_knowledge_query_hook(self, 
query_data: dict[str, Any]) -> HookResult: """知识库查询钩子""" start_time = time.time() @@ -411,7 +411,7 @@ class MemoryIntegrationHooks: logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True) return HookResult(success=False, error=str(e), processing_time=processing_time) - async def _on_prompt_building_hook(self, prompt_data: Dict[str, Any]) -> HookResult: + async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult: """提示词构建钩子""" start_time = time.time() @@ -459,7 +459,7 @@ class MemoryIntegrationHooks: new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions self.hook_stats["average_hook_time"] = new_avg - def get_hook_stats(self) -> Dict[str, Any]: + def get_hook_stats(self) -> dict[str, Any]: """获取钩子统计信息""" return self.hook_stats.copy() @@ -501,7 +501,7 @@ class MemoryMaintenanceTask: # 全局钩子实例 -_memory_hooks: Optional[MemoryIntegrationHooks] = None +_memory_hooks: MemoryIntegrationHooks | None = None async def get_memory_integration_hooks() -> MemoryIntegrationHooks: diff --git a/src/chat/memory_system/deprecated_backup/metadata_index.py b/src/chat/memory_system/deprecated_backup/metadata_index.py index f7ab8ecda..8c89e5c34 100644 --- a/src/chat/memory_system/deprecated_backup/metadata_index.py +++ b/src/chat/memory_system/deprecated_backup/metadata_index.py @@ -1,20 +1,20 @@ -# -*- coding: utf-8 -*- """ 元数据索引系统 为记忆系统提供多维度的精准过滤和查询能力 """ +import threading import time -import orjson -from typing import Dict, List, Optional, Tuple, Set, Any, Union +from collections import defaultdict from dataclasses import dataclass from enum import Enum -import threading -from collections import defaultdict from pathlib import Path +from typing import Any +import orjson + +from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk, MemoryType from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel logger = get_logger(__name__) @@ -40,21 +40,21 @@ class IndexType(Enum): class IndexQuery: """索引查询条件""" - user_ids: Optional[List[str]] = None - memory_types: Optional[List[MemoryType]] = None - subjects: Optional[List[str]] = None - keywords: Optional[List[str]] = None - tags: Optional[List[str]] = None - categories: Optional[List[str]] = None - time_range: Optional[Tuple[float, float]] = None - confidence_levels: Optional[List[ConfidenceLevel]] = None - importance_levels: Optional[List[ImportanceLevel]] = None - min_relationship_score: Optional[float] = None - max_relationship_score: Optional[float] = None - min_access_count: Optional[int] = None - semantic_hashes: Optional[List[str]] = None - limit: Optional[int] = None - sort_by: Optional[str] = None # "created_at", "access_count", "relevance_score" + user_ids: list[str] | None = None + memory_types: list[MemoryType] | None = None + subjects: list[str] | None = None + keywords: list[str] | None = None + tags: list[str] | None = None + categories: list[str] | None = None + time_range: tuple[float, float] | None = None + confidence_levels: list[ConfidenceLevel] | None = None + importance_levels: list[ImportanceLevel] | None = None + min_relationship_score: float | None = None + max_relationship_score: float | None = None + min_access_count: int | None = None + semantic_hashes: list[str] | None = None + limit: int | None = None + sort_by: str | None = None # "created_at", "access_count", "relevance_score" sort_order: str = "desc" # "asc", "desc" @@ -62,10 +62,10 @@ class IndexQuery: class 
IndexResult: """索引结果""" - memory_ids: List[str] + memory_ids: list[str] total_count: int query_time: float - filtered_by: List[str] + filtered_by: list[str] class MetadataIndexManager: @@ -94,7 +94,7 @@ class MetadataIndexManager: self.access_frequency_index = [] # [(access_count, memory_id), ...] # 内存缓存 - self.memory_metadata_cache: Dict[str, Dict[str, Any]] = {} + self.memory_metadata_cache: dict[str, dict[str, Any]] = {} # 统计信息 self.index_stats = { @@ -140,7 +140,7 @@ class MetadataIndexManager: return key @staticmethod - def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]: + def _serialize_metadata_entry(metadata: dict[str, Any]) -> dict[str, Any]: serialized = {} for field_name, value in metadata.items(): if isinstance(value, Enum): @@ -149,7 +149,7 @@ class MetadataIndexManager: serialized[field_name] = value return serialized - async def index_memories(self, memories: List[MemoryChunk]): + async def index_memories(self, memories: list[MemoryChunk]): """为记忆建立索引""" if not memories: return @@ -375,7 +375,7 @@ class MetadataIndexManager: logger.error(f"❌ 元数据查询失败: {e}", exc_info=True) return IndexResult(memory_ids=[], total_count=0, query_time=0.0, filtered_by=[]) - def _get_candidate_memories(self, query: IndexQuery) -> Set[str]: + def _get_candidate_memories(self, query: IndexQuery) -> set[str]: """获取候选记忆ID集合""" candidate_ids = set() @@ -444,7 +444,7 @@ class MetadataIndexManager: return candidate_ids - def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]: + def _collect_index_matches(self, index_type: IndexType, token: str | Enum | None) -> set[str]: """根据给定token收集索引匹配,支持部分匹配""" mapping = self.indices.get(index_type) if mapping is None: @@ -461,7 +461,7 @@ class MetadataIndexManager: if not key: return set() - matches: Set[str] = set(mapping.get(key, set())) + matches: set[str] = set(mapping.get(key, set())) if matches: return set(matches) @@ -477,7 +477,7 @@ class MetadataIndexManager: return matches - def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]: + def _apply_filters(self, candidate_ids: set[str], query: IndexQuery) -> list[str]: """应用过滤条件""" filtered_ids = list(candidate_ids) @@ -545,7 +545,7 @@ class MetadataIndexManager: created_at = self.memory_metadata_cache[memory_id]["created_at"] return start_time <= created_at <= end_time - def _sort_memories(self, memory_ids: List[str], sort_by: str, sort_order: str) -> List[str]: + def _sort_memories(self, memory_ids: list[str], sort_by: str, sort_order: str) -> list[str]: """对记忆进行排序""" if sort_by == "created_at": # 使用时间索引(已经有序) @@ -582,7 +582,7 @@ class MetadataIndexManager: return memory_ids - def _get_applied_filters(self, query: IndexQuery) -> List[str]: + def _get_applied_filters(self, query: IndexQuery) -> list[str]: """获取应用的过滤器列表""" filters = [] if query.memory_types: @@ -686,11 +686,11 @@ class MetadataIndexManager: except Exception as e: logger.error(f"❌ 移除记忆索引失败: {e}") - async def get_memory_metadata(self, memory_id: str) -> Optional[Dict[str, Any]]: + async def get_memory_metadata(self, memory_id: str) -> dict[str, Any] | None: """获取记忆元数据""" return self.memory_metadata_cache.get(memory_id) - async def get_user_memory_ids(self, user_id: str, limit: Optional[int] = None) -> List[str]: + async def get_user_memory_ids(self, user_id: str, limit: int | None = None) -> list[str]: """获取用户的所有记忆ID""" user_memory_ids = list(self.indices[IndexType.USER_ID].get(user_id, set())) @@ -699,7 +699,7 @@ class MetadataIndexManager: return 
user_memory_ids - async def get_memory_statistics(self, user_id: Optional[str] = None) -> Dict[str, Any]: + async def get_memory_statistics(self, user_id: str | None = None) -> dict[str, Any]: """获取记忆统计信息""" stats = { "total_memories": self.index_stats["total_memories"], @@ -784,7 +784,7 @@ class MetadataIndexManager: logger.info("正在保存元数据索引...") # 保存各类索引 - indices_data: Dict[str, Dict[str, List[str]]] = {} + indices_data: dict[str, dict[str, list[str]]] = {} for index_type, index_data in self.indices.items(): serialized_index = {} for key, values in index_data.items(): @@ -839,7 +839,7 @@ class MetadataIndexManager: # 加载各类索引 indices_file = self.index_path / "indices.json" if indices_file.exists(): - with open(indices_file, "r", encoding="utf-8") as f: + with open(indices_file, encoding="utf-8") as f: indices_data = orjson.loads(f.read()) for index_type_value, index_data in indices_data.items(): @@ -853,25 +853,25 @@ class MetadataIndexManager: # 加载时间索引 time_index_file = self.index_path / "time_index.json" if time_index_file.exists(): - with open(time_index_file, "r", encoding="utf-8") as f: + with open(time_index_file, encoding="utf-8") as f: self.time_index = orjson.loads(f.read()) # 加载关系分索引 relationship_index_file = self.index_path / "relationship_index.json" if relationship_index_file.exists(): - with open(relationship_index_file, "r", encoding="utf-8") as f: + with open(relationship_index_file, encoding="utf-8") as f: self.relationship_index = orjson.loads(f.read()) # 加载访问频率索引 access_frequency_index_file = self.index_path / "access_frequency_index.json" if access_frequency_index_file.exists(): - with open(access_frequency_index_file, "r", encoding="utf-8") as f: + with open(access_frequency_index_file, encoding="utf-8") as f: self.access_frequency_index = orjson.loads(f.read()) # 加载元数据缓存 metadata_cache_file = self.index_path / "metadata_cache.json" if metadata_cache_file.exists(): - with open(metadata_cache_file, "r", encoding="utf-8") as f: + with open(metadata_cache_file, encoding="utf-8") as f: cache_data = orjson.loads(f.read()) # 转换置信度和重要性为枚举类型 @@ -914,7 +914,7 @@ class MetadataIndexManager: # 加载统计信息 stats_file = self.index_path / "index_stats.json" if stats_file.exists(): - with open(stats_file, "r", encoding="utf-8") as f: + with open(stats_file, encoding="utf-8") as f: self.index_stats = orjson.loads(f.read()) # 更新记忆计数 @@ -1004,7 +1004,7 @@ class MetadataIndexManager: if len(self.indices[IndexType.CATEGORY][category]) < min_frequency: del self.indices[IndexType.CATEGORY][category] - def get_index_stats(self) -> Dict[str, Any]: + def get_index_stats(self) -> dict[str, Any]: """获取索引统计信息""" stats = self.index_stats.copy() if stats["total_queries"] > 0: diff --git a/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py b/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py index bc0a1a0f4..f13792603 100644 --- a/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py +++ b/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py @@ -1,19 +1,19 @@ -# -*- coding: utf-8 -*- """ 多阶段召回机制 实现粗粒度到细粒度的记忆检索优化 """ import time -from typing import Dict, List, Optional, Set, Any from dataclasses import dataclass, field from enum import Enum -import orjson +from typing import Any -from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +import orjson from src.chat.memory_system.enhanced_reranker import EnhancedReRanker, ReRankingConfig +from src.chat.memory_system.memory_chunk import MemoryChunk, 
MemoryType +from src.common.logger import get_logger + logger = get_logger(__name__) @@ -73,11 +73,11 @@ class StageResult: """阶段结果""" stage: RetrievalStage - memory_ids: List[str] + memory_ids: list[str] processing_time: float filtered_count: int score_threshold: float - details: List[Dict[str, Any]] = field(default_factory=list) + details: list[dict[str, Any]] = field(default_factory=list) @dataclass @@ -86,17 +86,17 @@ class RetrievalResult: query: str user_id: str - final_memories: List[MemoryChunk] - stage_results: List[StageResult] + final_memories: list[MemoryChunk] + stage_results: list[StageResult] total_processing_time: float total_filtered: int - retrieval_stats: Dict[str, Any] + retrieval_stats: dict[str, Any] class MultiStageRetrieval: """多阶段召回系统""" - def __init__(self, config: Optional[RetrievalConfig] = None): + def __init__(self, config: RetrievalConfig | None = None): self.config = config or RetrievalConfig.from_global_config() # 初始化增强重排序器 @@ -124,11 +124,11 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], + context: dict[str, Any], metadata_index, vector_storage, - all_memories_cache: Dict[str, MemoryChunk], - limit: Optional[int] = None, + all_memories_cache: dict[str, MemoryChunk], + limit: int | None = None, ) -> RetrievalResult: """多阶段记忆检索""" start_time = time.time() @@ -136,7 +136,7 @@ class MultiStageRetrieval: stage_results = [] current_memory_ids = set() - memory_debug_info: Dict[str, Dict[str, Any]] = {} + memory_debug_info: dict[str, dict[str, Any]] = {} try: logger.debug(f"开始多阶段检索:query='{query}', user_id='{user_id}'") @@ -311,11 +311,11 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], + context: dict[str, Any], metadata_index, - all_memories_cache: Dict[str, MemoryChunk], + all_memories_cache: dict[str, MemoryChunk], *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None, + debug_log: dict[str, dict[str, Any]] | None = None, ) -> StageResult: """阶段1:元数据过滤""" start_time = time.time() @@ -345,7 +345,7 @@ class MultiStageRetrieval: result = await metadata_index.query_memories(index_query) result_ids = list(result.memory_ids) filtered_count = max(0, len(all_memories_cache) - len(result_ids)) - details: List[Dict[str, Any]] = [] + details: list[dict[str, Any]] = [] # 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆 if not result_ids: @@ -440,12 +440,12 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], + context: dict[str, Any], vector_storage, - candidate_ids: Set[str], - all_memories_cache: Dict[str, MemoryChunk], + candidate_ids: set[str], + all_memories_cache: dict[str, MemoryChunk], *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None, + debug_log: dict[str, dict[str, Any]] | None = None, ) -> StageResult: """阶段2:向量搜索""" start_time = time.time() @@ -479,8 +479,8 @@ class MultiStageRetrieval: # 过滤候选记忆 filtered_memories = [] - details: List[Dict[str, Any]] = [] - raw_details: List[Dict[str, Any]] = [] + details: list[dict[str, Any]] = [] + raw_details: list[dict[str, Any]] = [] threshold = self.config.vector_similarity_threshold for memory_id, similarity in search_result: @@ -561,7 +561,7 @@ class MultiStageRetrieval: ) def _create_text_search_fallback( - self, candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], query_text: str, start_time: float + self, candidate_ids: set[str], all_memories_cache: dict[str, MemoryChunk], query_text: str, start_time: float ) -> StageResult: """当向量搜索失败时,使用文本搜索作为回退策略""" try: @@ -618,18 +618,18 @@ class 
MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], - candidate_ids: Set[str], - all_memories_cache: Dict[str, MemoryChunk], + context: dict[str, Any], + candidate_ids: set[str], + all_memories_cache: dict[str, MemoryChunk], *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None, + debug_log: dict[str, dict[str, Any]] | None = None, ) -> StageResult: """阶段3:语义重排序""" start_time = time.time() try: reranked_memories = [] - details: List[Dict[str, Any]] = [] + details: list[dict[str, Any]] = [] threshold = self.config.semantic_similarity_threshold for memory_id in candidate_ids: @@ -704,19 +704,19 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], - candidate_ids: List[str], - all_memories_cache: Dict[str, MemoryChunk], + context: dict[str, Any], + candidate_ids: list[str], + all_memories_cache: dict[str, MemoryChunk], limit: int, *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None, + debug_log: dict[str, dict[str, Any]] | None = None, ) -> StageResult: """阶段4:上下文过滤""" start_time = time.time() try: final_memories = [] - details: List[Dict[str, Any]] = [] + details: list[dict[str, Any]] = [] for memory_id in candidate_ids: if memory_id not in all_memories_cache: @@ -793,12 +793,12 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], - all_memories_cache: Dict[str, MemoryChunk], + context: dict[str, Any], + all_memories_cache: dict[str, MemoryChunk], limit: int, *, - excluded_ids: Optional[Set[str]] = None, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None, + excluded_ids: set[str] | None = None, + debug_log: dict[str, dict[str, Any]] | None = None, ) -> StageResult: """回退检索阶段 - 当主检索失败时使用更宽松的策略""" start_time = time.time() @@ -881,8 +881,8 @@ class MultiStageRetrieval: ) async def _generate_query_embedding( - self, query: str, context: Dict[str, Any], vector_storage - ) -> Optional[List[float]]: + self, query: str, context: dict[str, Any], vector_storage + ) -> list[float] | None: """生成查询向量""" try: query_plan = context.get("query_plan") @@ -916,7 +916,7 @@ class MultiStageRetrieval: logger.error(f"生成查询向量时发生异常: {e}", exc_info=True) return None - async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float: + async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float: """计算语义相似度 - 简化优化版本,提升召回率""" try: query_plan = context.get("query_plan") @@ -947,9 +947,10 @@ class MultiStageRetrieval: # 核心匹配策略2:词汇匹配 word_score = 0.0 try: - import jieba import re + import jieba + # 分词处理 query_words = list(jieba.cut(query_text)) + re.findall(r"[a-zA-Z]+", query_text) memory_words = list(jieba.cut(memory_text)) + re.findall(r"[a-zA-Z]+", memory_text) @@ -1059,7 +1060,7 @@ class MultiStageRetrieval: logger.warning(f"计算语义相似度失败: {e}") return 0.0 - async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float: + async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float: """计算上下文相关度""" try: score = 0.0 @@ -1132,7 +1133,7 @@ class MultiStageRetrieval: return 0.0 async def _calculate_final_score( - self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float + self, query: str, memory: MemoryChunk, context: dict[str, Any], context_score: float ) -> float: """计算最终评分""" try: @@ -1184,7 +1185,7 @@ class MultiStageRetrieval: logger.warning(f"计算最终评分失败: {e}") return 0.0 - def 
_calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float: + def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: list[str] | None) -> float: if not required_subjects: return 0.0 @@ -1229,7 +1230,7 @@ class MultiStageRetrieval: except Exception: return 0.5 - def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]: + def _extract_memory_types_from_context(self, context: dict[str, Any]) -> list[MemoryType]: """从上下文中提取记忆类型""" try: query_plan = context.get("query_plan") @@ -1256,10 +1257,10 @@ class MultiStageRetrieval: except Exception: return [] - def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]: + def _extract_keywords_from_query(self, query: str, query_plan: Any | None = None) -> list[str]: """从查询中提取关键词""" try: - extracted: List[str] = [] + extracted: list[str] = [] if query_plan and getattr(query_plan, "required_keywords", None): extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)]) @@ -1283,7 +1284,7 @@ class MultiStageRetrieval: except Exception: return [] - def _update_retrieval_stats(self, total_time: float, stage_results: List[StageResult]): + def _update_retrieval_stats(self, total_time: float, stage_results: list[StageResult]): """更新检索统计""" self.retrieval_stats["total_queries"] += 1 @@ -1306,7 +1307,7 @@ class MultiStageRetrieval: ] stage_stat["avg_time"] = new_stage_avg - def get_retrieval_stats(self) -> Dict[str, Any]: + def get_retrieval_stats(self) -> dict[str, Any]: """获取检索统计信息""" return self.retrieval_stats.copy() @@ -1328,12 +1329,12 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], - candidate_ids: List[str], - all_memories_cache: Dict[str, MemoryChunk], + context: dict[str, Any], + candidate_ids: list[str], + all_memories_cache: dict[str, MemoryChunk], limit: int, *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None, + debug_log: dict[str, dict[str, Any]] | None = None, ) -> StageResult: """阶段5:增强重排序 - 使用多维度评分模型""" start_time = time.time() diff --git a/src/chat/memory_system/deprecated_backup/vector_storage.py b/src/chat/memory_system/deprecated_backup/vector_storage.py index 5d2e4fb91..d5d974486 100644 --- a/src/chat/memory_system/deprecated_backup/vector_storage.py +++ b/src/chat/memory_system/deprecated_backup/vector_storage.py @@ -1,24 +1,23 @@ -# -*- coding: utf-8 -*- """ 向量数据库存储接口 为记忆系统提供高效的向量存储和语义搜索能力 """ -import time -import orjson import asyncio -from typing import Dict, List, Optional, Tuple, Any -from dataclasses import dataclass import threading +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any import numpy as np -from pathlib import Path +import orjson -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config -from src.common.config_helpers import resolve_embedding_dimension from src.chat.memory_system.memory_chunk import MemoryChunk +from src.common.config_helpers import resolve_embedding_dimension +from src.common.logger import get_logger +from src.config.config import model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) @@ -48,7 +47,7 @@ class VectorStorageConfig: class VectorStorageManager: """向量存储管理器""" - def __init__(self, config: Optional[VectorStorageConfig] = None): + def __init__(self, config: VectorStorageConfig | None = None): self.config = config or 
VectorStorageConfig() resolved_dimension = resolve_embedding_dimension(self.config.dimension) @@ -68,8 +67,8 @@ class VectorStorageManager: self.index_to_memory_id = {} # vector index -> memory_id # 内存缓存 - self.memory_cache: Dict[str, MemoryChunk] = {} - self.vector_cache: Dict[str, List[float]] = {} + self.memory_cache: dict[str, MemoryChunk] = {} + self.vector_cache: dict[str, list[float]] = {} # 统计信息 self.storage_stats = { @@ -125,7 +124,7 @@ class VectorStorageManager: ) logger.info("✅ 嵌入模型初始化完成") - async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]: + async def generate_query_embedding(self, query_text: str) -> list[float] | None: """生成查询向量,用于记忆召回""" if not query_text: logger.warning("查询文本为空,无法生成向量") @@ -155,7 +154,7 @@ class VectorStorageManager: logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True) return None - async def store_memories(self, memories: List[MemoryChunk]): + async def store_memories(self, memories: list[MemoryChunk]): """存储记忆向量""" if not memories: return @@ -231,7 +230,7 @@ class VectorStorageManager: logger.debug("记忆 %s 缺少可用展示文本,使用占位符生成嵌入输入", memory.memory_id) return memory.memory_id - async def _batch_generate_and_store_embeddings(self, memory_texts: List[Tuple[str, str]]): + async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]): """批量生成和存储嵌入向量""" if not memory_texts: return @@ -253,12 +252,12 @@ class VectorStorageManager: except Exception as e: logger.error(f"❌ 批量生成嵌入向量失败: {e}") - async def _batch_generate_embeddings(self, memory_ids: List[str], texts: List[str]) -> Dict[str, List[float]]: + async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]: """批量生成嵌入向量""" if not texts: return {} - results: Dict[str, List[float]] = {} + results: dict[str, list[float]] = {} try: semaphore = asyncio.Semaphore(min(4, max(1, len(texts)))) @@ -281,7 +280,9 @@ class VectorStorageManager: logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc) results[memory_id] = [] - tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)] + tasks = [ + asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False) + ] await asyncio.gather(*tasks, return_exceptions=True) except Exception as e: @@ -291,7 +292,7 @@ class VectorStorageManager: return results - async def _add_single_memory(self, memory: MemoryChunk, embedding: List[float]): + async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]): """添加单个记忆到向量存储""" with self._lock: try: @@ -337,7 +338,7 @@ class VectorStorageManager: except Exception as e: logger.error(f"❌ 添加记忆到向量存储失败: {e}") - def _normalize_vector(self, vector: List[float]) -> List[float]: + def _normalize_vector(self, vector: list[float]) -> list[float]: """L2归一化向量""" if not vector: return vector @@ -357,12 +358,12 @@ class VectorStorageManager: async def search_similar_memories( self, - query_vector: Optional[List[float]] = None, + query_vector: list[float] | None = None, *, - query_text: Optional[str] = None, + query_text: str | None = None, limit: int = 10, - scope_id: Optional[str] = None, - ) -> List[Tuple[str, float]]: + scope_id: str | None = None, + ) -> list[tuple[str, float]]: """搜索相似记忆""" start_time = time.time() @@ -379,7 +380,7 @@ class VectorStorageManager: logger.warning("查询向量生成失败") return [] - scope_filter: Optional[str] = None + scope_filter: str | None = None if isinstance(scope_id, str): normalized_scope = 
scope_id.strip().lower() if normalized_scope and normalized_scope not in {"global", "global_memory"}: @@ -491,7 +492,7 @@ class VectorStorageManager: logger.error(f"❌ 向量搜索失败: {e}", exc_info=True) return [] - async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]: + async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None: """根据ID获取记忆""" # 先检查缓存 if memory_id in self.memory_cache: @@ -501,7 +502,7 @@ class VectorStorageManager: self.storage_stats["total_searches"] += 1 return None - async def update_memory_embedding(self, memory_id: str, new_embedding: List[float]): + async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]): """更新记忆的嵌入向量""" with self._lock: try: @@ -636,7 +637,7 @@ class VectorStorageManager: # 加载记忆缓存 cache_file = self.storage_path / "memory_cache.json" if cache_file.exists(): - with open(cache_file, "r", encoding="utf-8") as f: + with open(cache_file, encoding="utf-8") as f: cache_data = orjson.loads(f.read()) self.memory_cache = { @@ -646,13 +647,13 @@ class VectorStorageManager: # 加载向量缓存 vector_cache_file = self.storage_path / "vector_cache.json" if vector_cache_file.exists(): - with open(vector_cache_file, "r", encoding="utf-8") as f: + with open(vector_cache_file, encoding="utf-8") as f: self.vector_cache = orjson.loads(f.read()) # 加载映射关系 mapping_file = self.storage_path / "id_mapping.json" if mapping_file.exists(): - with open(mapping_file, "r", encoding="utf-8") as f: + with open(mapping_file, encoding="utf-8") as f: mapping_data = orjson.loads(f.read()) raw_memory_to_index = mapping_data.get("memory_id_to_index", {}) self.memory_id_to_index = { @@ -689,7 +690,7 @@ class VectorStorageManager: # 加载统计信息 stats_file = self.storage_path / "storage_stats.json" if stats_file.exists(): - with open(stats_file, "r", encoding="utf-8") as f: + with open(stats_file, encoding="utf-8") as f: self.storage_stats = orjson.loads(f.read()) # 更新向量计数 @@ -806,7 +807,7 @@ class VectorStorageManager: if invalid_memory_ids: logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用") - def get_storage_stats(self) -> Dict[str, Any]: + def get_storage_stats(self) -> dict[str, Any]: """获取存储统计信息""" stats = self.storage_stats.copy() if stats["total_searches"] > 0: @@ -821,11 +822,11 @@ class SimpleVectorIndex: def __init__(self, dimension: int): self.dimension = dimension - self.vectors: List[List[float]] = [] - self.vector_ids: List[int] = [] + self.vectors: list[list[float]] = [] + self.vector_ids: list[int] = [] self.next_id = 0 - def add_vector(self, vector: List[float]) -> int: + def add_vector(self, vector: list[float]) -> int: """添加向量""" if len(vector) != self.dimension: raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}") @@ -837,7 +838,7 @@ class SimpleVectorIndex: return vector_id - def search(self, query_vector: List[float], limit: int) -> List[Tuple[int, float]]: + def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]: """搜索相似向量""" if len(query_vector) != self.dimension: raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}") @@ -853,7 +854,7 @@ class SimpleVectorIndex: return results[:limit] - def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float: + def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float: """计算余弦相似度""" try: dot_product = sum(x * y for x, y in zip(v1, v2, strict=False)) diff --git a/src/chat/memory_system/enhanced_memory_activator.py b/src/chat/memory_system/enhanced_memory_activator.py index 
7570715ee..22b44c7a1 100644 --- a/src/chat/memory_system/enhanced_memory_activator.py +++ b/src/chat/memory_system/enhanced_memory_activator.py @@ -1,25 +1,24 @@ -# -*- coding: utf-8 -*- """ 记忆激活器 记忆系统的激活器组件 """ import difflib -import orjson -from typing import List, Dict, Optional from datetime import datetime +import orjson from json_repair import repair_json -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger -from src.chat.utils.prompt import Prompt, global_prompt_manager + from src.chat.memory_system.memory_manager import MemoryResult +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("memory_activator") -def get_keywords_from_json(json_str) -> List: +def get_keywords_from_json(json_str) -> list: """ 从JSON字符串中提取关键词列表 @@ -81,7 +80,7 @@ class MemoryActivator: self.cached_keywords = set() # 用于缓存历史关键词 self.last_memory_query_time = 0 # 上次查询记忆的时间 - async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]: + async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]: """ 激活记忆 """ @@ -155,7 +154,7 @@ class MemoryActivator: return self.running_memory - async def _query_unified_memory(self, keywords: List[str], query_text: str) -> List[MemoryResult]: + async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]: """查询统一记忆系统""" try: # 使用记忆系统 @@ -198,7 +197,7 @@ class MemoryActivator: logger.error(f"查询统一记忆失败: {e}") return [] - async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]: + async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None: """ 获取即时记忆 - 兼容原有接口(使用统一存储) """ diff --git a/src/chat/memory_system/memory_activator_new.py b/src/chat/memory_system/memory_activator_new.py index 491034de4..0b4e9a938 100644 --- a/src/chat/memory_system/memory_activator_new.py +++ b/src/chat/memory_system/memory_activator_new.py @@ -1,25 +1,24 @@ -# -*- coding: utf-8 -*- """ 记忆激活器 记忆系统的激活器组件 """ import difflib -import orjson -from typing import List, Dict, Optional from datetime import datetime +import orjson from json_repair import repair_json -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger -from src.chat.utils.prompt import Prompt, global_prompt_manager + from src.chat.memory_system.memory_manager import MemoryResult +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("memory_activator") -def get_keywords_from_json(json_str) -> List: +def get_keywords_from_json(json_str) -> list: """ 从JSON字符串中提取关键词列表 @@ -81,7 +80,7 @@ class MemoryActivator: self.cached_keywords = set() # 用于缓存历史关键词 self.last_memory_query_time = 0 # 上次查询记忆的时间 - async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]: + async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]: """ 激活记忆 """ @@ -155,7 +154,7 @@ class MemoryActivator: return self.running_memory - async def _query_unified_memory(self, keywords: 
List[str], query_text: str) -> List[MemoryResult]: + async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]: """查询统一记忆系统""" try: # 使用记忆系统 @@ -198,7 +197,7 @@ class MemoryActivator: logger.error(f"查询统一记忆失败: {e}") return [] - async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]: + async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None: """ 获取即时记忆 - 兼容原有接口(使用统一存储) """ diff --git a/src/chat/memory_system/memory_builder.py b/src/chat/memory_system/memory_builder.py index 0c3f47043..a2f936028 100644 --- a/src/chat/memory_system/memory_builder.py +++ b/src/chat/memory_system/memory_builder.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 记忆构建模块 从对话流中提取高质量、结构化记忆单元 @@ -33,19 +32,19 @@ import time from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Union, Type +from typing import Any import orjson -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest from src.chat.memory_system.memory_chunk import ( - MemoryChunk, - MemoryType, ConfidenceLevel, ImportanceLevel, + MemoryChunk, + MemoryType, create_memory_chunk, ) +from src.common.logger import get_logger +from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) @@ -62,8 +61,8 @@ class ExtractionStrategy(Enum): class ExtractionResult: """提取结果""" - memories: List[MemoryChunk] - confidence_scores: List[float] + memories: list[MemoryChunk] + confidence_scores: list[float] extraction_time: float strategy_used: ExtractionStrategy @@ -85,8 +84,8 @@ class MemoryBuilder: } async def build_memories( - self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: float - ) -> List[MemoryChunk]: + self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float + ) -> list[MemoryChunk]: """从对话中构建记忆""" start_time = time.time() @@ -116,8 +115,8 @@ class MemoryBuilder: raise async def _extract_with_llm( - self, text: str, context: Dict[str, Any], user_id: str, timestamp: float - ) -> List[MemoryChunk]: + self, text: str, context: dict[str, Any], user_id: str, timestamp: float + ) -> list[MemoryChunk]: """使用LLM提取记忆""" try: prompt = self._build_llm_extraction_prompt(text, context) @@ -135,7 +134,7 @@ class MemoryBuilder: logger.error(f"LLM提取失败: {e}") raise MemoryExtractionError(str(e)) from e - def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str: + def _build_llm_extraction_prompt(self, text: str, context: dict[str, Any]) -> str: """构建LLM提取提示""" current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") message_type = context.get("message_type", "normal") @@ -315,7 +314,7 @@ class MemoryBuilder: return prompt - def _extract_json_payload(self, response: str) -> Optional[str]: + def _extract_json_payload(self, response: str) -> str | None: """从模型响应中提取JSON部分,兼容Markdown代码块等格式""" if not response: return None @@ -338,8 +337,8 @@ class MemoryBuilder: return stripped if stripped.startswith("{") and stripped.endswith("}") else None def _parse_llm_response( - self, response: str, user_id: str, timestamp: float, context: Dict[str, Any] - ) -> List[MemoryChunk]: + self, response: str, user_id: str, timestamp: float, context: dict[str, Any] + ) -> list[MemoryChunk]: """解析LLM响应""" if not response: raise MemoryExtractionError("LLM未返回任何响应") @@ -385,7 +384,7 @@ class MemoryBuilder: bot_display = self._clean_subject_text(bot_display) - memories: List[MemoryChunk] = [] + 
memories: list[MemoryChunk] = [] for mem_data in memory_list: try: @@ -460,7 +459,7 @@ class MemoryBuilder: return memories - def _parse_enum_value(self, enum_cls: Type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum: + def _parse_enum_value(self, enum_cls: type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum: """解析枚举值,兼容数字/字符串表示""" if isinstance(raw_value, enum_cls): return raw_value @@ -514,7 +513,7 @@ class MemoryBuilder: ) return default - def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]: + def _collect_bot_identifiers(self, context: dict[str, Any] | None) -> set[str]: identifiers: set[str] = {"bot", "机器人", "ai助手"} if not context: return identifiers @@ -540,7 +539,7 @@ class MemoryBuilder: return identifiers - def _collect_system_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]: + def _collect_system_identifiers(self, context: dict[str, Any] | None) -> set[str]: identifiers: set[str] = set() if not context: return identifiers @@ -568,8 +567,8 @@ class MemoryBuilder: return identifiers - def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]: - participants: List[str] = [] + def _resolve_conversation_participants(self, context: dict[str, Any] | None, user_id: str) -> list[str]: + participants: list[str] = [] if context: candidate_keys = [ @@ -609,7 +608,7 @@ class MemoryBuilder: if not participants: participants = ["对话参与者"] - deduplicated: List[str] = [] + deduplicated: list[str] = [] seen = set() for name in participants: key = name.lower() @@ -620,7 +619,7 @@ class MemoryBuilder: return deduplicated - def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str: + def _resolve_user_display(self, context: dict[str, Any] | None, user_id: str) -> str: candidate_keys = [ "user_display_name", "user_name", @@ -683,7 +682,7 @@ class MemoryBuilder: return False - def _split_subject_string(self, value: str) -> List[str]: + def _split_subject_string(self, value: str) -> list[str]: if not value: return [] @@ -699,12 +698,12 @@ class MemoryBuilder: subject: Any, bot_identifiers: set[str], system_identifiers: set[str], - default_subjects: List[str], - bot_display: Optional[str] = None, - ) -> List[str]: + default_subjects: list[str], + bot_display: str | None = None, + ) -> list[str]: defaults = default_subjects or ["对话参与者"] - raw_candidates: List[str] = [] + raw_candidates: list[str] = [] if isinstance(subject, list): for item in subject: if isinstance(item, str): @@ -716,7 +715,7 @@ class MemoryBuilder: elif subject is not None: raw_candidates.extend(self._split_subject_string(str(subject))) - normalized: List[str] = [] + normalized: list[str] = [] bot_primary = self._clean_subject_text(bot_display or "") for candidate in raw_candidates: @@ -741,7 +740,7 @@ class MemoryBuilder: if not normalized: normalized = list(defaults) - deduplicated: List[str] = [] + deduplicated: list[str] = [] seen = set() for name in normalized: key = name.lower() @@ -752,7 +751,7 @@ class MemoryBuilder: return deduplicated - def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]: + def _extract_value_from_object(self, obj: str | dict[str, Any] | list[Any], keys: list[str]) -> str | None: if isinstance(obj, dict): for key in keys: value = obj.get(key) @@ -773,9 +772,7 @@ class MemoryBuilder: return obj.strip() or None return None - def _compose_display_text( - self, subjects: List[str], predicate: str, 
obj: Union[str, Dict[str, Any], List[Any]] - ) -> str: + def _compose_display_text(self, subjects: list[str], predicate: str, obj: str | dict[str, Any] | list[Any]) -> str: subject_phrase = "、".join(subjects) if subjects else "对话参与者" predicate = (predicate or "").strip() @@ -841,7 +838,7 @@ class MemoryBuilder: return f"{subject_phrase}{predicate}".strip() return subject_phrase - def _validate_and_enhance_memories(self, memories: List[MemoryChunk], context: Dict[str, Any]) -> List[MemoryChunk]: + def _validate_and_enhance_memories(self, memories: list[MemoryChunk], context: dict[str, Any]) -> list[MemoryChunk]: """验证和增强记忆""" validated_memories = [] @@ -876,7 +873,7 @@ class MemoryBuilder: return True - def _enhance_memory(self, memory: MemoryChunk, context: Dict[str, Any]) -> MemoryChunk: + def _enhance_memory(self, memory: MemoryChunk, context: dict[str, Any]) -> MemoryChunk: """增强记忆块""" # 时间规范化处理 self._normalize_time_in_memory(memory) @@ -985,7 +982,7 @@ class MemoryBuilder: total_confidence / self.extraction_stats["successful_extractions"] ) - def get_extraction_stats(self) -> Dict[str, Any]: + def get_extraction_stats(self) -> dict[str, Any]: """获取提取统计信息""" return self.extraction_stats.copy() diff --git a/src/chat/memory_system/memory_chunk.py b/src/chat/memory_system/memory_chunk.py index b5b609af6..dcce6eb64 100644 --- a/src/chat/memory_system/memory_chunk.py +++ b/src/chat/memory_system/memory_chunk.py @@ -1,18 +1,19 @@ -# -*- coding: utf-8 -*- """ 结构化记忆单元设计 实现高质量、结构化的记忆单元,符合文档设计规范 """ +import hashlib import time import uuid -import orjson -from typing import Dict, List, Optional, Any, Union, Iterable +from collections.abc import Iterable from dataclasses import dataclass, field from enum import Enum -import hashlib +from typing import Any import numpy as np +import orjson + from src.common.logger import get_logger logger = get_logger(__name__) @@ -56,17 +57,17 @@ class ImportanceLevel(Enum): class ContentStructure: """主谓宾结构,包含自然语言描述""" - subject: Union[str, List[str]] + subject: str | list[str] predicate: str - object: Union[str, Dict] + object: str | dict display: str = "" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典格式""" return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display} @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure": + def from_dict(cls, data: dict[str, Any]) -> "ContentStructure": """从字典创建实例""" return cls( subject=data.get("subject", ""), @@ -75,7 +76,7 @@ class ContentStructure: display=data.get("display", ""), ) - def to_subject_list(self) -> List[str]: + def to_subject_list(self) -> list[str]: """将主语转换为列表形式""" if isinstance(self.subject, list): return [s for s in self.subject if isinstance(s, str) and s.strip()] @@ -99,7 +100,7 @@ class MemoryMetadata: # 基础信息 memory_id: str # 唯一标识符 user_id: str # 用户ID - chat_id: Optional[str] = None # 聊天ID(群聊或私聊) + chat_id: str | None = None # 聊天ID(群聊或私聊) # 时间信息 created_at: float = 0.0 # 创建时间戳 @@ -124,9 +125,9 @@ class MemoryMetadata: last_forgetting_check: float = 0.0 # 上次遗忘检查时间 # 来源信息 - source_context: Optional[str] = None # 来源上下文片段 + source_context: str | None = None # 来源上下文片段 # 兼容旧字段: 一些代码或旧版本可能直接访问 metadata.source - source: Optional[str] = None + source: str | None = None def __post_init__(self): """后初始化处理""" @@ -209,7 +210,7 @@ class MemoryMetadata: # 设置最小和最大阈值 return max(7.0, min(threshold, 365.0)) # 7天到1年之间 - def should_forget(self, current_time: Optional[float] = None) -> bool: + def should_forget(self, 
current_time: float | None = None) -> bool: """判断是否应该遗忘""" if current_time is None: current_time = time.time() @@ -222,7 +223,7 @@ class MemoryMetadata: return days_since_activation > self.forgetting_threshold - def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool: + def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool: """判断是否处于休眠状态(长期未激活)""" if current_time is None: current_time = time.time() @@ -230,7 +231,7 @@ class MemoryMetadata: days_since_last_access = (current_time - self.last_accessed) / 86400 return days_since_last_access > inactive_days - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典格式""" return { "memory_id": self.memory_id, @@ -252,7 +253,7 @@ class MemoryMetadata: } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "MemoryMetadata": + def from_dict(cls, data: dict[str, Any]) -> "MemoryMetadata": """从字典创建实例""" return cls( memory_id=data.get("memory_id", ""), @@ -286,17 +287,17 @@ class MemoryChunk: memory_type: MemoryType # 记忆类型 # 扩展信息 - keywords: List[str] = field(default_factory=list) # 关键词列表 - tags: List[str] = field(default_factory=list) # 标签列表 - categories: List[str] = field(default_factory=list) # 分类列表 + keywords: list[str] = field(default_factory=list) # 关键词列表 + tags: list[str] = field(default_factory=list) # 标签列表 + categories: list[str] = field(default_factory=list) # 分类列表 # 语义信息 - embedding: Optional[List[float]] = None # 语义向量 - semantic_hash: Optional[str] = None # 语义哈希值 + embedding: list[float] | None = None # 语义向量 + semantic_hash: str | None = None # 语义哈希值 # 关联信息 - related_memories: List[str] = field(default_factory=list) # 关联记忆ID列表 - temporal_context: Optional[Dict[str, Any]] = None # 时间上下文 + related_memories: list[str] = field(default_factory=list) # 关联记忆ID列表 + temporal_context: dict[str, Any] | None = None # 时间上下文 def __post_init__(self): """后初始化处理""" @@ -310,7 +311,7 @@ class MemoryChunk: try: # 使用向量和内容生成稳定的哈希 - content_str = f"{self.content.subject}:{self.content.predicate}:{str(self.content.object)}" + content_str = f"{self.content.subject}:{self.content.predicate}:{self.content.object!s}" embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding])) hash_input = f"{content_str}|{embedding_str}" @@ -342,7 +343,7 @@ class MemoryChunk: return self.content.display or str(self.content) @property - def subjects(self) -> List[str]: + def subjects(self) -> list[str]: """获取主语列表""" return self.content.to_subject_list() @@ -354,11 +355,11 @@ class MemoryChunk: """更新相关度评分""" self.metadata.update_relevance(new_score) - def should_forget(self, current_time: Optional[float] = None) -> bool: + def should_forget(self, current_time: float | None = None) -> bool: """判断是否应该遗忘""" return self.metadata.should_forget(current_time) - def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool: + def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool: """判断是否处于休眠状态(长期未激活)""" return self.metadata.is_dormant(current_time, inactive_days) @@ -386,7 +387,7 @@ class MemoryChunk: if memory_id and memory_id not in self.related_memories: self.related_memories.append(memory_id) - def set_embedding(self, embedding: List[float]): + def set_embedding(self, embedding: list[float]): """设置语义向量""" self.embedding = embedding self._generate_semantic_hash() @@ -415,7 +416,7 @@ class MemoryChunk: logger.warning(f"计算记忆相似度失败: {e}") return 0.0 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> 
dict[str, Any]: """转换为完整的字典格式""" return { "metadata": self.metadata.to_dict(), @@ -431,7 +432,7 @@ class MemoryChunk: } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "MemoryChunk": + def from_dict(cls, data: dict[str, Any]) -> "MemoryChunk": """从字典创建实例""" metadata = MemoryMetadata.from_dict(data.get("metadata", {})) content = ContentStructure.from_dict(data.get("content", {})) @@ -541,7 +542,7 @@ class MemoryChunk: return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})" -def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str: +def _build_display_text(subjects: Iterable[str], predicate: str, obj: str | dict) -> str: """根据主谓宾生成自然语言描述""" subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)] subject_part = "、".join(subjects_clean) if subjects_clean else "对话参与者" @@ -569,15 +570,15 @@ def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, def create_memory_chunk( user_id: str, - subject: Union[str, List[str]], + subject: str | list[str], predicate: str, - obj: Union[str, Dict], + obj: str | dict, memory_type: MemoryType, - chat_id: Optional[str] = None, - source_context: Optional[str] = None, + chat_id: str | None = None, + source_context: str | None = None, importance: ImportanceLevel = ImportanceLevel.NORMAL, confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM, - display: Optional[str] = None, + display: str | None = None, **kwargs, ) -> MemoryChunk: """便捷的内存块创建函数""" @@ -593,10 +594,10 @@ def create_memory_chunk( source_context=source_context, ) - subjects: List[str] + subjects: list[str] if isinstance(subject, list): subjects = [s for s in subject if isinstance(s, str) and s.strip()] - subject_payload: Union[str, List[str]] = subjects + subject_payload: str | list[str] = subjects else: cleaned = subject.strip() if isinstance(subject, str) else "" subjects = [cleaned] if cleaned else [] diff --git a/src/chat/memory_system/memory_forgetting_engine.py b/src/chat/memory_system/memory_forgetting_engine.py index 3e243e433..e41d1149c 100644 --- a/src/chat/memory_system/memory_forgetting_engine.py +++ b/src/chat/memory_system/memory_forgetting_engine.py @@ -1,17 +1,15 @@ -# -*- coding: utf-8 -*- """ 智能记忆遗忘引擎 基于重要程度、置信度和激活频率的智能遗忘机制 """ -import time import asyncio -from typing import List, Dict, Optional, Tuple -from datetime import datetime +import time from dataclasses import dataclass +from datetime import datetime +from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import MemoryChunk, ImportanceLevel, ConfidenceLevel logger = get_logger(__name__) @@ -65,7 +63,7 @@ class ForgettingConfig: class MemoryForgettingEngine: """智能记忆遗忘引擎""" - def __init__(self, config: Optional[ForgettingConfig] = None): + def __init__(self, config: ForgettingConfig | None = None): self.config = config or ForgettingConfig() self.stats = ForgettingStats() self._last_forgetting_check = 0.0 @@ -116,7 +114,7 @@ class MemoryForgettingEngine: # 确保在合理范围内 return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days)) - def should_forget_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool: + def should_forget_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool: """ 判断记忆是否应该被遗忘 @@ -155,7 +153,7 @@ class MemoryForgettingEngine: return should_forget - def is_dormant_memory(self, memory: 
MemoryChunk, current_time: Optional[float] = None) -> bool: + def is_dormant_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool: """ 判断记忆是否处于休眠状态 @@ -168,7 +166,7 @@ class MemoryForgettingEngine: """ return memory.is_dormant(current_time, self.config.dormant_threshold_days) - def should_force_forget_dormant(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool: + def should_force_forget_dormant(self, memory: MemoryChunk, current_time: float | None = None) -> bool: """ 判断是否应该强制遗忘休眠记忆 @@ -189,7 +187,7 @@ class MemoryForgettingEngine: days_since_last_access = (current_time - memory.metadata.last_accessed) / 86400 return days_since_last_access > self.config.force_forget_dormant_days - async def check_memories_for_forgetting(self, memories: List[MemoryChunk]) -> Tuple[List[str], List[str]]: + async def check_memories_for_forgetting(self, memories: list[MemoryChunk]) -> tuple[list[str], list[str]]: """ 检查记忆列表,识别需要遗忘的记忆 @@ -241,7 +239,7 @@ class MemoryForgettingEngine: return normal_forgetting_ids, force_forgetting_ids - async def perform_forgetting_check(self, memories: List[MemoryChunk]) -> Dict[str, any]: + async def perform_forgetting_check(self, memories: list[MemoryChunk]) -> dict[str, any]: """ 执行完整的遗忘检查流程 @@ -314,7 +312,7 @@ class MemoryForgettingEngine: except Exception as e: logger.error(f"定期遗忘检查失败: {e}", exc_info=True) - def get_forgetting_stats(self) -> Dict[str, any]: + def get_forgetting_stats(self) -> dict[str, any]: """获取遗忘统计信息""" return { "total_checked": self.stats.total_checked, diff --git a/src/chat/memory_system/memory_fusion.py b/src/chat/memory_system/memory_fusion.py index 3ecc4cd71..59f36ed93 100644 --- a/src/chat/memory_system/memory_fusion.py +++ b/src/chat/memory_system/memory_fusion.py @@ -1,16 +1,14 @@ -# -*- coding: utf-8 -*- """ 记忆融合与去重机制 避免记忆碎片化,确保长期记忆库的高质量 """ import time -from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass +from typing import Any - +from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel logger = get_logger(__name__) @@ -22,9 +20,9 @@ class FusionResult: original_count: int fused_count: int removed_duplicates: int - merged_memories: List[MemoryChunk] + merged_memories: list[MemoryChunk] fusion_time: float - details: List[str] + details: list[str] @dataclass @@ -32,9 +30,9 @@ class DuplicateGroup: """重复记忆组""" group_id: str - memories: List[MemoryChunk] - similarity_matrix: List[List[float]] - representative_memory: Optional[MemoryChunk] = None + memories: list[MemoryChunk] + similarity_matrix: list[list[float]] + representative_memory: MemoryChunk | None = None class MemoryFusionEngine: @@ -59,8 +57,8 @@ class MemoryFusionEngine: } async def fuse_memories( - self, new_memories: List[MemoryChunk], existing_memories: Optional[List[MemoryChunk]] = None - ) -> List[MemoryChunk]: + self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk] | None = None + ) -> list[MemoryChunk]: """融合记忆列表""" start_time = time.time() @@ -106,8 +104,8 @@ class MemoryFusionEngine: return new_memories # 失败时返回原始记忆 async def _detect_duplicate_groups( - self, new_memories: List[MemoryChunk], existing_memories: List[MemoryChunk] - ) -> List[DuplicateGroup]: + self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk] + ) -> list[DuplicateGroup]: """检测重复记忆组""" all_memories = new_memories + 
existing_memories new_memory_ids = {memory.memory_id for memory in new_memories} @@ -212,7 +210,7 @@ class MemoryFusionEngine: jaccard_similarity = len(intersection) / len(union) return jaccard_similarity - def _calculate_keyword_similarity(self, keywords1: List[str], keywords2: List[str]) -> float: + def _calculate_keyword_similarity(self, keywords1: list[str], keywords2: list[str]) -> float: """计算关键词相似度""" if not keywords1 or not keywords2: return 0.0 @@ -302,7 +300,7 @@ class MemoryFusionEngine: return best_memory - async def _fuse_memory_group(self, group: DuplicateGroup) -> Optional[MemoryChunk]: + async def _fuse_memory_group(self, group: DuplicateGroup) -> MemoryChunk | None: """融合记忆组""" if not group.memories: return None @@ -328,7 +326,7 @@ class MemoryFusionEngine: # 返回置信度最高的记忆 return max(group.memories, key=lambda m: m.metadata.confidence.value) - async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: List[MemoryChunk]) -> MemoryChunk: + async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: list[MemoryChunk]) -> MemoryChunk: """合并记忆属性""" # 创建基础记忆的深拷贝 fused_memory = MemoryChunk.from_dict(base_memory.to_dict()) @@ -395,7 +393,7 @@ class MemoryFusionEngine: source_ids = [m.memory_id[:8] for m in group.memories] fused_memory.metadata.source_context = f"Fused from {len(group.memories)} memories: {', '.join(source_ids)}" - def _merge_temporal_context(self, memories: List[MemoryChunk]) -> Dict[str, Any]: + def _merge_temporal_context(self, memories: list[MemoryChunk]) -> dict[str, Any]: """合并时间上下文""" contexts = [m.temporal_context for m in memories if m.temporal_context] @@ -426,8 +424,8 @@ class MemoryFusionEngine: return merged_context async def incremental_fusion( - self, new_memory: MemoryChunk, existing_memories: List[MemoryChunk] - ) -> Tuple[MemoryChunk, List[MemoryChunk]]: + self, new_memory: MemoryChunk, existing_memories: list[MemoryChunk] + ) -> tuple[MemoryChunk, list[MemoryChunk]]: """增量融合(单个新记忆与现有记忆融合)""" # 寻找相似记忆 similar_memories = [] @@ -493,7 +491,7 @@ class MemoryFusionEngine: except Exception as e: logger.error(f"❌ 记忆融合引擎维护失败: {e}", exc_info=True) - def get_fusion_stats(self) -> Dict[str, Any]: + def get_fusion_stats(self) -> dict[str, Any]: """获取融合统计信息""" return self.fusion_stats.copy() diff --git a/src/chat/memory_system/memory_manager.py b/src/chat/memory_system/memory_manager.py index 4c6b2696e..1ba79fe59 100644 --- a/src/chat/memory_system/memory_manager.py +++ b/src/chat/memory_system/memory_manager.py @@ -1,17 +1,15 @@ -# -*- coding: utf-8 -*- """ 记忆系统管理器 替代原有的 Hippocampus 和 instant_memory 系统 """ import re -from typing import Dict, List, Optional, Any, Tuple from dataclasses import dataclass +from typing import Any -from src.common.logger import get_logger -from src.chat.memory_system.memory_system import MemorySystem from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType -from src.chat.memory_system.memory_system import initialize_memory_system +from src.chat.memory_system.memory_system import MemorySystem, initialize_memory_system +from src.common.logger import get_logger logger = get_logger(__name__) @@ -27,14 +25,14 @@ class MemoryResult: timestamp: float source: str = "memory" relevance_score: float = 0.0 - structure: Dict[str, Any] | None = None + structure: dict[str, Any] | None = None class MemoryManager: """记忆系统管理器 - 替代原有的 HippocampusManager""" def __init__(self): - self.memory_system: Optional[MemorySystem] = None + self.memory_system: MemorySystem | None = None self.is_initialized = False 
self.user_cache = {} # 用户记忆缓存 @@ -63,8 +61,8 @@ class MemoryManager: logger.info("正在初始化记忆系统...") # 获取LLM模型 - from src.llm_models.utils_model import LLMRequest from src.config.config import model_config + from src.llm_models.utils_model import LLMRequest llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory") @@ -121,7 +119,7 @@ class MemoryManager: max_memory_length: int = 2, time_weight: float = 1.0, keyword_weight: float = 1.0, - ) -> List[Tuple[str, str]]: + ) -> list[tuple[str, str]]: """从文本获取相关记忆 - 兼容原有接口""" if not self.is_initialized or not self.memory_system: return [] @@ -152,8 +150,8 @@ class MemoryManager: return [] async def get_memory_from_topic( - self, valid_keywords: List[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 - ) -> List[Tuple[str, str]]: + self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 + ) -> list[tuple[str, str]]: """从关键词获取记忆 - 兼容原有接口""" if not self.is_initialized or not self.memory_system: return [] @@ -208,8 +206,8 @@ class MemoryManager: return [] async def process_conversation( - self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: Optional[float] = None - ) -> List[MemoryChunk]: + self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float | None = None + ) -> list[MemoryChunk]: """处理对话并构建记忆 - 新增功能""" if not self.is_initialized or not self.memory_system: return [] @@ -235,8 +233,8 @@ class MemoryManager: return [] async def get_enhanced_memory_context( - self, query_text: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 5 - ) -> List[MemoryResult]: + self, query_text: str, user_id: str, context: dict[str, Any] | None = None, limit: int = 5 + ) -> list[MemoryResult]: """获取增强记忆上下文 - 新增功能""" if not self.is_initialized or not self.memory_system: return [] @@ -267,7 +265,7 @@ class MemoryManager: logger.error(f"get_enhanced_memory_context 失败: {e}") return [] - def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]: + def _format_memory_chunk(self, memory: MemoryChunk) -> tuple[str, dict[str, Any]]: """将记忆块转换为更易读的文本描述""" structure = memory.content.to_dict() if memory.display: @@ -289,7 +287,7 @@ class MemoryManager: return formatted, structure - def _format_subject(self, subject: Optional[str], memory: MemoryChunk) -> str: + def _format_subject(self, subject: str | None, memory: MemoryChunk) -> str: if not subject: return "该用户" @@ -299,7 +297,7 @@ class MemoryManager: return "该聊天" return self._clean_text(subject) - def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> Optional[str]: + def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> str | None: predicate = (predicate or "").strip() obj_value = obj @@ -446,10 +444,10 @@ class MemoryManager: text = self._truncate(str(obj).strip()) return self._clean_text(text) - def _extract_from_object(self, obj: Any, keys: List[str]) -> Optional[str]: + def _extract_from_object(self, obj: Any, keys: list[str]) -> str | None: if isinstance(obj, dict): for key in keys: - if key in obj and obj[key]: + if obj.get(key): value = obj[key] if isinstance(value, (dict, list)): return self._clean_text(self._format_object(value)) diff --git a/src/chat/memory_system/memory_metadata_index.py b/src/chat/memory_system/memory_metadata_index.py index ad27971a6..4b405aad6 100644 --- a/src/chat/memory_system/memory_metadata_index.py +++ 
b/src/chat/memory_system/memory_metadata_index.py @@ -1,15 +1,15 @@ -# -*- coding: utf-8 -*- """ 记忆元数据索引管理器 使用JSON文件存储记忆元数据,支持快速模糊搜索和过滤 """ -import orjson import threading -from pathlib import Path -from typing import Dict, List, Optional, Set, Any -from dataclasses import dataclass, asdict +from dataclasses import asdict, dataclass from datetime import datetime +from pathlib import Path +from typing import Any + +import orjson from src.common.logger import get_logger @@ -25,10 +25,10 @@ class MemoryMetadataIndexEntry: # 分类信息 memory_type: str # MemoryType.value - subjects: List[str] # 主语列表 - objects: List[str] # 宾语列表 - keywords: List[str] # 关键词列表 - tags: List[str] # 标签列表 + subjects: list[str] # 主语列表 + objects: list[str] # 宾语列表 + keywords: list[str] # 关键词列表 + tags: list[str] # 标签列表 # 数值字段(用于范围过滤) importance: int # ImportanceLevel.value (1-4) @@ -37,8 +37,8 @@ class MemoryMetadataIndexEntry: access_count: int # 访问次数 # 可选字段 - chat_id: Optional[str] = None - content_preview: Optional[str] = None # 内容预览(前100字符) + chat_id: str | None = None + content_preview: str | None = None # 内容预览(前100字符) class MemoryMetadataIndex: @@ -46,13 +46,13 @@ class MemoryMetadataIndex: def __init__(self, index_file: str = "data/memory_metadata_index.json"): self.index_file = Path(index_file) - self.index: Dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry + self.index: dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry # 倒排索引(用于快速查找) - self.type_index: Dict[str, Set[str]] = {} # type -> {memory_ids} - self.subject_index: Dict[str, Set[str]] = {} # subject -> {memory_ids} - self.keyword_index: Dict[str, Set[str]] = {} # keyword -> {memory_ids} - self.tag_index: Dict[str, Set[str]] = {} # tag -> {memory_ids} + self.type_index: dict[str, set[str]] = {} # type -> {memory_ids} + self.subject_index: dict[str, set[str]] = {} # subject -> {memory_ids} + self.keyword_index: dict[str, set[str]] = {} # keyword -> {memory_ids} + self.tag_index: dict[str, set[str]] = {} # tag -> {memory_ids} self.lock = threading.RLock() @@ -178,7 +178,7 @@ class MemoryMetadataIndex: self._remove_from_inverted_indices(memory_id) del self.index[memory_id] - def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]): + def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]): """批量添加或更新""" with self.lock: for entry in entries: @@ -191,18 +191,18 @@ class MemoryMetadataIndex: def search( self, - memory_types: Optional[List[str]] = None, - subjects: Optional[List[str]] = None, - keywords: Optional[List[str]] = None, - tags: Optional[List[str]] = None, - importance_min: Optional[int] = None, - importance_max: Optional[int] = None, - created_after: Optional[float] = None, - created_before: Optional[float] = None, - user_id: Optional[str] = None, - limit: Optional[int] = None, + memory_types: list[str] | None = None, + subjects: list[str] | None = None, + keywords: list[str] | None = None, + tags: list[str] | None = None, + importance_min: int | None = None, + importance_max: int | None = None, + created_after: float | None = None, + created_before: float | None = None, + user_id: str | None = None, + limit: int | None = None, flexible_mode: bool = True, # 新增:灵活匹配模式 - ) -> List[str]: + ) -> list[str]: """ 搜索符合条件的记忆ID列表(支持模糊匹配) @@ -237,14 +237,14 @@ class MemoryMetadataIndex: def _search_flexible( self, - memory_types: Optional[List[str]] = None, - subjects: Optional[List[str]] = None, - created_after: Optional[float] = None, - created_before: Optional[float] = None, - user_id: Optional[str] = None, - 
limit: Optional[int] = None, + memory_types: list[str] | None = None, + subjects: list[str] | None = None, + created_after: float | None = None, + created_before: float | None = None, + user_id: str | None = None, + limit: int | None = None, **kwargs, # 接受但不使用的参数 - ) -> List[str]: + ) -> list[str]: """ 灵活搜索模式:2/4项匹配即可,支持部分匹配 @@ -374,20 +374,20 @@ class MemoryMetadataIndex: def _search_strict( self, - memory_types: Optional[List[str]] = None, - subjects: Optional[List[str]] = None, - keywords: Optional[List[str]] = None, - tags: Optional[List[str]] = None, - importance_min: Optional[int] = None, - importance_max: Optional[int] = None, - created_after: Optional[float] = None, - created_before: Optional[float] = None, - user_id: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[str]: + memory_types: list[str] | None = None, + subjects: list[str] | None = None, + keywords: list[str] | None = None, + tags: list[str] | None = None, + importance_min: int | None = None, + importance_max: int | None = None, + created_after: float | None = None, + created_before: float | None = None, + user_id: str | None = None, + limit: int | None = None, + ) -> list[str]: """严格搜索模式(原有逻辑)""" # 初始候选集(所有记忆) - candidate_ids: Optional[Set[str]] = None + candidate_ids: set[str] | None = None # 用户过滤(必选) if user_id: @@ -471,11 +471,11 @@ class MemoryMetadataIndex: return result_ids - def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]: + def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None: """获取单个索引条目""" return self.index.get(memory_id) - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: """获取索引统计信息""" with self.lock: return { diff --git a/src/chat/memory_system/memory_query_planner.py b/src/chat/memory_system/memory_query_planner.py index a8be9d951..bbedf766c 100644 --- a/src/chat/memory_system/memory_query_planner.py +++ b/src/chat/memory_system/memory_query_planner.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 -*- """记忆检索查询规划器""" from __future__ import annotations import re from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any import orjson @@ -21,16 +20,16 @@ class MemoryQueryPlan: """查询规划结果""" semantic_query: str - memory_types: List[MemoryType] = field(default_factory=list) - subject_includes: List[str] = field(default_factory=list) - object_includes: List[str] = field(default_factory=list) - required_keywords: List[str] = field(default_factory=list) - optional_keywords: List[str] = field(default_factory=list) - owner_filters: List[str] = field(default_factory=list) + memory_types: list[MemoryType] = field(default_factory=list) + subject_includes: list[str] = field(default_factory=list) + object_includes: list[str] = field(default_factory=list) + required_keywords: list[str] = field(default_factory=list) + optional_keywords: list[str] = field(default_factory=list) + owner_filters: list[str] = field(default_factory=list) recency_preference: str = "any" limit: int = 10 - emphasis: Optional[str] = None - raw_plan: Dict[str, Any] = field(default_factory=dict) + emphasis: str | None = None + raw_plan: dict[str, Any] = field(default_factory=dict) def ensure_defaults(self, fallback_query: str, default_limit: int) -> None: if not self.semantic_query: @@ -46,11 +45,11 @@ class MemoryQueryPlan: class MemoryQueryPlanner: """基于小模型的记忆检索查询规划器""" - def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10): + def __init__(self, planner_model: LLMRequest | None, 
default_limit: int = 10): self.model = planner_model self.default_limit = default_limit - async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan: + async def plan_query(self, query_text: str, context: dict[str, Any]) -> MemoryQueryPlan: if not self.model: logger.debug("未提供查询规划模型,使用默认规划") return self._default_plan(query_text) @@ -82,10 +81,10 @@ class MemoryQueryPlanner: def _default_plan(self, query_text: str) -> MemoryQueryPlan: return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit) - def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan: + def _parse_plan_dict(self, data: dict[str, Any], fallback_query: str) -> MemoryQueryPlan: semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query - def _collect_list(key: str) -> List[str]: + def _collect_list(key: str) -> list[str]: value = data.get(key) if isinstance(value, str): return [value] @@ -94,7 +93,7 @@ class MemoryQueryPlanner: return [] memory_type_values = _collect_list("memory_types") - memory_types: List[MemoryType] = [] + memory_types: list[MemoryType] = [] for item in memory_type_values: if not item: continue @@ -123,7 +122,7 @@ class MemoryQueryPlanner: ) return plan - def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str: + def _build_prompt(self, query_text: str, context: dict[str, Any]) -> str: participants = context.get("participants") or context.get("speaker_names") or [] if isinstance(participants, str): participants = [participants] @@ -206,7 +205,7 @@ class MemoryQueryPlanner: 请直接输出符合要求的 JSON 对象,禁止添加额外文本或 Markdown 代码块。 """ - def _extract_json_payload(self, response: str) -> Optional[str]: + def _extract_json_payload(self, response: str) -> str | None: if not response: return None diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index 4a275babd..5236da62a 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 精准记忆系统核心模块 1. 
基于文档设计的高效记忆构建、存储与召回优化系统,覆盖构建、向量化与多阶段检索全流程。 @@ -6,26 +5,27 @@ """ import asyncio -import time -import orjson -import re import hashlib -from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING +import re +import time +from dataclasses import asdict, dataclass from datetime import datetime, timedelta -from dataclasses import dataclass, asdict from enum import Enum +from typing import TYPE_CHECKING, Any + +import orjson -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config, global_config -from src.chat.memory_system.memory_chunk import MemoryChunk from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError +from src.chat.memory_system.memory_chunk import MemoryChunk from src.chat.memory_system.memory_fusion import MemoryFusionEngine from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest if TYPE_CHECKING: - from src.common.data_models.database_data_model import DatabaseMessages from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine + from src.common.data_models.database_data_model import DatabaseMessages logger = get_logger(__name__) @@ -121,7 +121,7 @@ class MemorySystemConfig: class MemorySystem: """精准记忆系统核心类""" - def __init__(self, llm_model: Optional[LLMRequest] = None, config: Optional[MemorySystemConfig] = None): + def __init__(self, llm_model: LLMRequest | None = None, config: MemorySystemConfig | None = None): self.config = config or MemorySystemConfig.from_global_config() self.llm_model = llm_model self.status = MemorySystemStatus.INITIALIZING @@ -131,7 +131,7 @@ class MemorySystem: self.fusion_engine: MemoryFusionEngine = None self.unified_storage = None # 统一存储系统 self.query_planner: MemoryQueryPlanner = None - self.forgetting_engine: Optional[MemoryForgettingEngine] = None + self.forgetting_engine: MemoryForgettingEngine | None = None # LLM模型 self.value_assessment_model: LLMRequest = None @@ -143,10 +143,10 @@ class MemorySystem: self.last_retrieval_time = None # 构建节流记录 - self._last_memory_build_times: Dict[str, float] = {} + self._last_memory_build_times: dict[str, float] = {} # 记忆指纹缓存,用于快速检测重复记忆 - self._memory_fingerprints: Dict[str, str] = {} + self._memory_fingerprints: dict[str, str] = {} logger.info("MemorySystem 初始化开始") @@ -210,7 +210,7 @@ class MemorySystem: raise # 初始化遗忘引擎 - from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig + from src.chat.memory_system.memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine # 从全局配置创建遗忘引擎配置 forgetting_config = ForgettingConfig( @@ -241,7 +241,7 @@ class MemorySystem: self.forgetting_engine = MemoryForgettingEngine(forgetting_config) planner_task_config = getattr(model_config.model_task_config, "utils_small", None) - planner_model: Optional[LLMRequest] = None + planner_model: LLMRequest | None = None try: planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner") except Exception as planner_exc: @@ -261,8 +261,8 @@ class MemorySystem: raise async def retrieve_memories_for_building( - self, query_text: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5 - ) -> List[MemoryChunk]: + self, query_text: str, user_id: str | None = None, context: dict[str, Any] | None = None, limit: int = 5 + ) 
-> list[MemoryChunk]: """在构建记忆时检索相关记忆,使用统一存储系统 Args: @@ -302,8 +302,8 @@ class MemorySystem: return [] async def build_memory_from_conversation( - self, conversation_text: str, context: Dict[str, Any], timestamp: Optional[float] = None - ) -> List[MemoryChunk]: + self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None + ) -> list[MemoryChunk]: """从对话中构建记忆 Args: @@ -318,8 +318,8 @@ class MemorySystem: self.status = MemorySystemStatus.BUILDING start_time = time.time() - build_scope_key: Optional[str] = None - build_marker_time: Optional[float] = None + build_scope_key: str | None = None + build_marker_time: float | None = None try: normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp) @@ -408,7 +408,7 @@ class MemorySystem: logger.error(f"❌ 记忆构建失败: {e}", exc_info=True) raise - def _log_memory_preview(self, memories: List[MemoryChunk]) -> None: + def _log_memory_preview(self, memories: list[MemoryChunk]) -> None: """在控制台输出记忆预览,便于人工检查""" if not memories: logger.info("📝 本次未生成新的记忆") @@ -425,12 +425,12 @@ class MemorySystem: f"置信度={memory.metadata.confidence.name} | 内容={text}" ) - async def _collect_fusion_candidates(self, new_memories: List[MemoryChunk]) -> List[MemoryChunk]: + async def _collect_fusion_candidates(self, new_memories: list[MemoryChunk]) -> list[MemoryChunk]: """收集与新记忆相似的现有记忆,便于融合去重""" if not new_memories: return [] - candidate_ids: Set[str] = set() + candidate_ids: set[str] = set() new_memory_ids = {memory.memory_id for memory in new_memories if memory and getattr(memory, "memory_id", None)} # 基于指纹的直接匹配 @@ -493,7 +493,7 @@ class MemorySystem: continue candidate_ids.add(memory_id) - existing_candidates: List[MemoryChunk] = [] + existing_candidates: list[MemoryChunk] = [] cache = self.unified_storage.memory_cache if self.unified_storage else {} for candidate_id in candidate_ids: if candidate_id in new_memory_ids: @@ -511,7 +511,7 @@ class MemorySystem: return existing_candidates - async def process_conversation_memory(self, context: Dict[str, Any]) -> Dict[str, Any]: + async def process_conversation_memory(self, context: dict[str, Any]) -> dict[str, Any]: """对外暴露的对话记忆处理接口,仅依赖上下文信息""" start_time = time.time() @@ -559,12 +559,12 @@ class MemorySystem: async def retrieve_relevant_memories( self, - query_text: Optional[str] = None, - user_id: Optional[str] = None, - context: Optional[Dict[str, Any]] = None, + query_text: str | None = None, + user_id: str | None = None, + context: dict[str, Any] | None = None, limit: int = 5, **kwargs, - ) -> List[MemoryChunk]: + ) -> list[MemoryChunk]: """检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)""" raw_query = query_text or kwargs.get("query") if not raw_query: @@ -750,7 +750,7 @@ class MemorySystem: raise @staticmethod - def _extract_json_payload(response: str) -> Optional[str]: + def _extract_json_payload(response: str) -> str | None: """从模型响应中提取JSON部分,兼容Markdown代码块等格式""" if not response: return None @@ -773,10 +773,10 @@ class MemorySystem: return stripped if stripped.startswith("{") and stripped.endswith("}") else None def _normalize_context( - self, raw_context: Optional[Dict[str, Any]], user_id: Optional[str], timestamp: Optional[float] - ) -> Dict[str, Any]: + self, raw_context: dict[str, Any] | None, user_id: str | None, timestamp: float | None + ) -> dict[str, Any]: """标准化上下文,确保必备字段存在且格式正确""" - context: Dict[str, Any] = {} + context: dict[str, Any] = {} if raw_context: try: context = dict(raw_context) @@ -822,7 +822,7 @@ class MemorySystem: return context - async def 
_build_enhanced_query_context(self, raw_query: str, normalized_context: Dict[str, Any]) -> Dict[str, Any]: + async def _build_enhanced_query_context(self, raw_query: str, normalized_context: dict[str, Any]) -> dict[str, Any]: """构建包含未读消息综合上下文的增强查询上下文 Args: @@ -861,7 +861,7 @@ class MemorySystem: return enhanced_context - async def _collect_unread_messages_context(self, stream_id: str) -> Optional[Dict[str, Any]]: + async def _collect_unread_messages_context(self, stream_id: str) -> dict[str, Any] | None: """收集未读消息的综合上下文信息 Args: @@ -953,7 +953,7 @@ class MemorySystem: logger.warning(f"收集未读消息上下文失败: {e}", exc_info=True) return None - def _build_unread_context_summary(self, messages_summary: List[Dict[str, Any]]) -> str: + def _build_unread_context_summary(self, messages_summary: list[dict[str, Any]]) -> str: """构建未读消息的文本摘要 Args: @@ -974,7 +974,7 @@ class MemorySystem: return " | ".join(summary_parts) - async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str: + async def _resolve_conversation_context(self, fallback_text: str, context: dict[str, Any] | None) -> str: """使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本""" if not context: return fallback_text @@ -1043,11 +1043,11 @@ class MemorySystem: # 回退到传入文本 return fallback_text - def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]: + def _get_build_scope_key(self, context: dict[str, Any], user_id: str | None) -> str | None: """确定用于节流控制的记忆构建作用域""" return "global_scope" - def _determine_history_limit(self, context: Dict[str, Any]) -> int: + def _determine_history_limit(self, context: dict[str, Any]) -> int: """确定历史消息获取数量,限制在30-50之间""" default_limit = 40 candidate = context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit") @@ -1065,12 +1065,12 @@ class MemorySystem: return history_limit - def _format_history_messages(self, messages: List["DatabaseMessages"]) -> Optional[str]: + def _format_history_messages(self, messages: list["DatabaseMessages"]) -> str | None: """将历史消息格式化为可供LLM处理的多轮对话文本""" if not messages: return None - lines: List[str] = [] + lines: list[str] = [] for msg in messages: try: content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None) @@ -1105,7 +1105,7 @@ class MemorySystem: return "\n".join(lines) if lines else None - async def _assess_information_value(self, text: str, context: Dict[str, Any]) -> float: + async def _assess_information_value(self, text: str, context: dict[str, Any]) -> float: """评估信息价值 Args: @@ -1201,7 +1201,7 @@ class MemorySystem: logger.error(f"信息价值评估失败: {e}", exc_info=True) return 0.5 # 默认中等价值 - async def _store_memories_unified(self, memory_chunks: List[MemoryChunk]) -> int: + async def _store_memories_unified(self, memory_chunks: list[MemoryChunk]) -> int: """使用统一存储系统存储记忆块""" if not memory_chunks or not self.unified_storage: return 0 @@ -1222,7 +1222,7 @@ class MemorySystem: return 0 # 保留原有方法以兼容旧代码 - async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int: + async def _store_memories(self, memory_chunks: list[MemoryChunk]) -> int: """兼容性方法:重定向到统一存储""" return await self._store_memories_unified(memory_chunks) @@ -1271,7 +1271,7 @@ class MemorySystem: key = self._fingerprint_key(memory.user_id, fingerprint) self._memory_fingerprints[key] = memory.memory_id - def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None: + def _register_memory_fingerprints(self, memories: list[MemoryChunk]) -> None: for memory in 
memories: fingerprint = self._build_memory_fingerprint(memory) key = self._fingerprint_key(memory.user_id, fingerprint) @@ -1302,9 +1302,9 @@ class MemorySystem: @staticmethod def _fingerprint_key(user_id: str, fingerprint: str) -> str: - return f"{str(user_id)}:{fingerprint}" + return f"{user_id!s}:{fingerprint}" - def get_system_stats(self) -> Dict[str, Any]: + def get_system_stats(self) -> dict[str, Any]: """获取系统统计信息""" return { "status": self.status.value, @@ -1314,7 +1314,7 @@ class MemorySystem: "config": asdict(self.config), } - def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float: + def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: dict[str, Any]) -> float: """根据查询和上下文为记忆计算匹配分数""" tokens_query = self._tokenize_text(query_text) tokens_memory = self._tokenize_text(memory.text_content) @@ -1338,7 +1338,7 @@ class MemorySystem: final_score = base_score * 0.7 + keyword_overlap * 0.15 + importance_boost + confidence_boost return max(0.0, min(1.0, final_score)) - def _tokenize_text(self, text: str) -> Set[str]: + def _tokenize_text(self, text: str) -> set[str]: """简单分词,兼容中英文""" if not text: return set() @@ -1450,7 +1450,7 @@ def get_memory_system() -> MemorySystem: return memory_system -async def initialize_memory_system(llm_model: Optional[LLMRequest] = None): +async def initialize_memory_system(llm_model: LLMRequest | None = None): """初始化全局记忆系统""" global memory_system if memory_system is None: diff --git a/src/chat/memory_system/vector_memory_storage_v2.py b/src/chat/memory_system/vector_memory_storage_v2.py index 3c924ba30..7fcae93c8 100644 --- a/src/chat/memory_system/vector_memory_storage_v2.py +++ b/src/chat/memory_system/vector_memory_storage_v2.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 基于Vector DB的统一记忆存储系统 V2 使用ChromaDB作为底层存储,替代JSON存储方式 @@ -11,20 +10,21 @@ - 自动清理过期记忆 """ -import time -import orjson import asyncio import threading -from typing import Dict, List, Optional, Tuple, Any +import time from dataclasses import dataclass from datetime import datetime +from typing import Any -from src.common.logger import get_logger -from src.common.vector_db import vector_db_service -from src.chat.utils.utils import get_embedding -from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel +import orjson + +from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndex, MemoryMetadataIndexEntry +from src.chat.utils.utils import get_embedding +from src.common.logger import get_logger +from src.common.vector_db import vector_db_service logger = get_logger(__name__) @@ -32,7 +32,7 @@ logger = get_logger(__name__) _ENUM_MAPPINGS_CACHE = {} -def _build_enum_mapping(enum_class: type) -> Dict[str, Any]: +def _build_enum_mapping(enum_class: type) -> dict[str, Any]: """构建枚举类的完整映射表 Args: @@ -145,7 +145,7 @@ class VectorMemoryStorage: """基于Vector DB的记忆存储系统""" - def __init__(self, config: Optional[VectorStorageConfig] = None): + def __init__(self, config: VectorStorageConfig | None = None): # 默认从全局配置读取,如果没有传入config if config is None: try: @@ -163,15 +163,15 @@ class VectorMemoryStorage: self.vector_db_service = vector_db_service # 内存缓存 - self.memory_cache: Dict[str, MemoryChunk] = {} - self.cache_timestamps: Dict[str, float] = {} + self.memory_cache: dict[str, MemoryChunk] = {} + 
self.cache_timestamps: dict[str, float] = {} self._cache = self.memory_cache # 别名,兼容旧代码 # 元数据索引管理器(JSON文件索引) self.metadata_index = MemoryMetadataIndex() # 遗忘引擎 - self.forgetting_engine: Optional[MemoryForgettingEngine] = None + self.forgetting_engine: MemoryForgettingEngine | None = None if self.config.enable_forgetting: self.forgetting_engine = MemoryForgettingEngine() @@ -267,7 +267,7 @@ class VectorMemoryStorage: except Exception as e: logger.error(f"自动清理失败: {e}") - def _memory_to_vector_format(self, memory: MemoryChunk) -> Dict[str, Any]: + def _memory_to_vector_format(self, memory: MemoryChunk) -> dict[str, Any]: """将MemoryChunk转换为向量存储格式""" try: # 获取memory_id @@ -323,7 +323,7 @@ class VectorMemoryStorage: logger.error(f"转换记忆 {memory_id} 到向量格式失败: {e}", exc_info=True) raise - def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]: + def _vector_result_to_memory(self, document: str, metadata: dict[str, Any]) -> MemoryChunk | None: """将Vector DB结果转换为MemoryChunk""" try: # 从元数据中恢复完整记忆 @@ -440,7 +440,7 @@ class VectorMemoryStorage: logger.warning(f"不支持的{enum_class.__name__}值类型: {type(value)},使用默认值") return default - def _get_from_cache(self, memory_id: str) -> Optional[MemoryChunk]: + def _get_from_cache(self, memory_id: str) -> MemoryChunk | None: """从缓存获取记忆""" if not self.config.enable_caching: return None @@ -472,7 +472,7 @@ class VectorMemoryStorage: self.memory_cache[memory_id] = memory self.cache_timestamps[memory_id] = time.time() - async def store_memories(self, memories: List[MemoryChunk]) -> int: + async def store_memories(self, memories: list[MemoryChunk]) -> int: """批量存储记忆""" if not memories: return 0 @@ -603,11 +603,11 @@ class VectorMemoryStorage: self, query_text: str, limit: int = 10, - similarity_threshold: Optional[float] = None, - filters: Optional[Dict[str, Any]] = None, + similarity_threshold: float | None = None, + filters: dict[str, Any] | None = None, # 新增:元数据过滤参数(用于JSON索引粗筛) - metadata_filters: Optional[Dict[str, Any]] = None, - ) -> List[Tuple[MemoryChunk, float]]: + metadata_filters: dict[str, Any] | None = None, + ) -> list[tuple[MemoryChunk, float]]: """ 搜索相似记忆(混合索引模式) @@ -632,7 +632,7 @@ class VectorMemoryStorage: try: # === 阶段一:JSON元数据粗筛(可选) === - candidate_ids: Optional[List[str]] = None + candidate_ids: list[str] | None = None if metadata_filters: logger.debug(f"[JSON元数据粗筛] 开始,过滤条件: {metadata_filters}") candidate_ids = self.metadata_index.search( @@ -746,7 +746,7 @@ class VectorMemoryStorage: logger.error(f"搜索相似记忆失败: {e}") return [] - async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]: + async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None: """根据ID获取记忆""" # 首先尝试从缓存获取 memory = self._get_from_cache(memory_id) @@ -772,7 +772,7 @@ class VectorMemoryStorage: return None - async def get_memories_by_filters(self, filters: Dict[str, Any], limit: int = 100) -> List[MemoryChunk]: + async def get_memories_by_filters(self, filters: dict[str, Any], limit: int = 100) -> list[MemoryChunk]: """根据过滤条件获取记忆""" try: results = vector_db_service.get(collection_name=self.config.memory_collection, where=filters, limit=limit) @@ -848,7 +848,7 @@ class VectorMemoryStorage: logger.error(f"删除记忆 {memory_id} 失败: {e}") return False - async def delete_memories_by_filters(self, filters: Dict[str, Any]) -> int: + async def delete_memories_by_filters(self, filters: dict[str, Any]) -> int: """根据过滤条件批量删除记忆""" try: # 先获取要删除的记忆ID @@ -880,7 +880,7 @@ class VectorMemoryStorage: logger.error(f"批量删除记忆失败: {e}") 
return 0 - async def perform_forgetting_check(self) -> Dict[str, Any]: + async def perform_forgetting_check(self) -> dict[str, Any]: """执行遗忘检查""" if not self.forgetting_engine: return {"error": "遗忘引擎未启用"} @@ -925,7 +925,7 @@ class VectorMemoryStorage: logger.error(f"执行遗忘检查失败: {e}") return {"error": str(e)} - def get_storage_stats(self) -> Dict[str, Any]: + def get_storage_stats(self) -> dict[str, Any]: """获取存储统计信息""" try: current_total = vector_db_service.count(self.config.memory_collection) @@ -960,7 +960,7 @@ class VectorMemoryStorage: _global_vector_storage = None -def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> VectorMemoryStorage: +def get_vector_memory_storage(config: VectorStorageConfig | None = None) -> VectorMemoryStorage: """获取全局Vector记忆存储实例""" global _global_vector_storage @@ -974,15 +974,15 @@ def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> V class VectorMemoryStorageAdapter: """适配器类,提供与原UnifiedMemoryStorage兼容的接口""" - def __init__(self, config: Optional[VectorStorageConfig] = None): + def __init__(self, config: VectorStorageConfig | None = None): self.storage = VectorMemoryStorage(config) - async def store_memories(self, memories: List[MemoryChunk]) -> int: + async def store_memories(self, memories: list[MemoryChunk]) -> int: return await self.storage.store_memories(memories) async def search_similar_memories( - self, query_text: str, limit: int = 10, scope_id: Optional[str] = None, filters: Optional[Dict[str, Any]] = None - ) -> List[Tuple[str, float]]: + self, query_text: str, limit: int = 10, scope_id: str | None = None, filters: dict[str, Any] | None = None + ) -> list[tuple[str, float]]: results = await self.storage.search_similar_memories(query_text, limit, filters=filters) # 转换为原格式:(memory_id, similarity) return [ @@ -990,7 +990,7 @@ class VectorMemoryStorageAdapter: for memory, similarity in results ] - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: return self.storage.get_storage_stats() diff --git a/src/chat/message_manager/__init__.py b/src/chat/message_manager/__init__.py index fe5e90785..c8bd18a08 100644 --- a/src/chat/message_manager/__init__.py +++ b/src/chat/message_manager/__init__.py @@ -3,14 +3,14 @@ 提供统一的消息管理、上下文管理和流循环调度功能 """ -from .message_manager import MessageManager, message_manager from .context_manager import SingleStreamContextManager from .distribution_manager import StreamLoopManager, stream_loop_manager +from .message_manager import MessageManager, message_manager __all__ = [ "MessageManager", - "message_manager", "SingleStreamContextManager", "StreamLoopManager", + "message_manager", "stream_loop_manager", ] diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index 5f3212065..ceefa99b2 100644 --- a/src/chat/message_manager/context_manager.py +++ b/src/chat/message_manager/context_manager.py @@ -6,13 +6,14 @@ import asyncio import time -from typing import Dict, List, Optional, Any +from typing import Any +from src.chat.energy_system import energy_manager +from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.message_manager_data_model import StreamContext from src.common.logger import get_logger from src.config.config import global_config -from src.common.data_models.database_data_model import DatabaseMessages -from src.chat.energy_system import energy_manager + from .distribution_manager import stream_loop_manager logger = get_logger("context_manager") @@ 
-21,7 +22,7 @@ logger = get_logger("context_manager") class SingleStreamContextManager: """单流上下文管理器 - 每个实例只管理一个 stream 的上下文""" - def __init__(self, stream_id: str, context: StreamContext, max_context_size: Optional[int] = None): + def __init__(self, stream_id: str, context: StreamContext, max_context_size: int | None = None): self.stream_id = stream_id self.context = context @@ -66,7 +67,7 @@ class SingleStreamContextManager: logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True) return False - async def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool: + async def update_message(self, message_id: str, updates: dict[str, Any]) -> bool: """更新上下文中的消息 Args: @@ -84,7 +85,7 @@ class SingleStreamContextManager: logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True) return False - def get_messages(self, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]: + def get_messages(self, limit: int | None = None, include_unread: bool = True) -> list[DatabaseMessages]: """获取上下文消息 Args: @@ -117,7 +118,7 @@ class SingleStreamContextManager: logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True) return [] - def get_unread_messages(self) -> List[DatabaseMessages]: + def get_unread_messages(self) -> list[DatabaseMessages]: """获取未读消息""" try: return self.context.get_unread_messages() @@ -125,7 +126,7 @@ class SingleStreamContextManager: logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True) return [] - def mark_messages_as_read(self, message_ids: List[str]) -> bool: + def mark_messages_as_read(self, message_ids: list[str]) -> bool: """标记消息为已读""" try: if not hasattr(self.context, "mark_message_as_read"): @@ -168,7 +169,7 @@ class SingleStreamContextManager: logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True) return False - def get_statistics(self) -> Dict[str, Any]: + def get_statistics(self) -> dict[str, Any]: """获取流统计信息""" try: current_time = time.time() @@ -285,7 +286,7 @@ class SingleStreamContextManager: logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True) return False - async def update_message_async(self, message_id: str, updates: Dict[str, Any]) -> bool: + async def update_message_async(self, message_id: str, updates: dict[str, Any]) -> bool: """异步实现的 update_message:更新消息并在需要时 await 能量更新。""" try: self.context.update_message_info(message_id, **updates) @@ -327,7 +328,7 @@ class SingleStreamContextManager: """更新流能量""" try: history_messages = self.context.get_history_messages(limit=self.max_context_size) - messages: List[DatabaseMessages] = list(history_messages) + messages: list[DatabaseMessages] = list(history_messages) if include_unread: messages.extend(self.get_unread_messages()) diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index 69f3e662d..152c40362 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -5,12 +5,12 @@ import asyncio import time -from typing import Dict, Optional, Any +from typing import Any +from src.chat.chatter_manager import ChatterManager +from src.chat.energy_system import energy_manager from src.common.logger import get_logger from src.config.config import global_config -from src.chat.energy_system import energy_manager -from src.chat.chatter_manager import ChatterManager from src.plugin_system.apis.chat_api import get_chat_manager logger = get_logger("stream_loop_manager") @@ -19,13 +19,13 @@ logger = 
get_logger("stream_loop_manager") class StreamLoopManager: """流循环管理器 - 每个流一个独立的无限循环任务""" - def __init__(self, max_concurrent_streams: Optional[int] = None): + def __init__(self, max_concurrent_streams: int | None = None): # 流循环任务管理 - self.stream_loops: Dict[str, asyncio.Task] = {} + self.stream_loops: dict[str, asyncio.Task] = {} self.loop_lock = asyncio.Lock() # 统计信息 - self.stats: Dict[str, Any] = { + self.stats: dict[str, Any] = { "active_streams": 0, "total_loops": 0, "total_process_cycles": 0, @@ -37,13 +37,13 @@ class StreamLoopManager: self.max_concurrent_streams = max_concurrent_streams or global_config.chat.max_concurrent_distributions # 强制分发策略 - self.force_dispatch_unread_threshold: Optional[int] = getattr( + self.force_dispatch_unread_threshold: int | None = getattr( global_config.chat, "force_dispatch_unread_threshold", 20 ) self.force_dispatch_min_interval: float = getattr(global_config.chat, "force_dispatch_min_interval", 0.1) # Chatter管理器 - self.chatter_manager: Optional[ChatterManager] = None + self.chatter_manager: ChatterManager | None = None # 状态控制 self.is_running = False @@ -212,7 +212,7 @@ class StreamLoopManager: logger.info(f"流循环结束: {stream_id}") - async def _get_stream_context(self, stream_id: str) -> Optional[Any]: + async def _get_stream_context(self, stream_id: str) -> Any | None: """获取流上下文 Args: @@ -320,7 +320,7 @@ class StreamLoopManager: logger.debug(f"流 {stream_id} 使用默认间隔: {base_interval:.2f}s ({e})") return base_interval - def get_queue_status(self) -> Dict[str, Any]: + def get_queue_status(self) -> dict[str, Any]: """获取队列状态 Returns: @@ -374,14 +374,14 @@ class StreamLoopManager: except Exception: return 0 - def _needs_force_dispatch_for_context(self, context: Any, unread_count: Optional[int] = None) -> bool: + def _needs_force_dispatch_for_context(self, context: Any, unread_count: int | None = None) -> bool: if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0: return False count = unread_count if unread_count is not None else self._get_unread_count(context) return count > self.force_dispatch_unread_threshold - def get_performance_summary(self) -> Dict[str, Any]: + def get_performance_summary(self) -> dict[str, Any]: """获取性能摘要 Returns: diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index bd55bd43f..78e3363ff 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -6,19 +6,20 @@ import asyncio import random import time -from typing import Dict, Optional, Any, TYPE_CHECKING, List +from typing import TYPE_CHECKING, Any +from src.chat.chatter_manager import ChatterManager from src.chat.message_receive.chat_stream import ChatStream -from src.common.logger import get_logger +from src.chat.planner_actions.action_manager import ChatterActionManager from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats -from src.chat.chatter_manager import ChatterManager -from src.chat.planner_actions.action_manager import ChatterActionManager -from .sleep_manager.sleep_manager import SleepManager -from .sleep_manager.wakeup_manager import WakeUpManager +from src.common.logger import get_logger from src.config.config import global_config from src.plugin_system.apis.chat_api import get_chat_manager + from .distribution_manager import stream_loop_manager +from .sleep_manager.sleep_manager import SleepManager +from 
.sleep_manager.wakeup_manager import WakeUpManager if TYPE_CHECKING: pass @@ -32,7 +33,7 @@ class MessageManager: def __init__(self, check_interval: float = 5.0): self.check_interval = check_interval # 检查间隔(秒) self.is_running = False - self.manager_task: Optional[asyncio.Task] = None + self.manager_task: asyncio.Task | None = None # 统计信息 self.stats = MessageManagerStats() @@ -125,7 +126,7 @@ class MessageManager: except Exception as e: logger.error(f"更新消息 {message_id} 时发生错误: {e}") - async def bulk_update_messages(self, stream_id: str, updates: List[Dict[str, Any]]) -> int: + async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int: """批量更新消息信息,降低更新频率""" if not updates: return 0 @@ -214,7 +215,7 @@ class MessageManager: except Exception as e: logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}") - def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]: + def get_stream_stats(self, stream_id: str) -> StreamStats | None: """获取聊天流统计""" try: # 通过 ChatManager 获取 ChatStream @@ -243,7 +244,7 @@ class MessageManager: logger.error(f"获取聊天流 {stream_id} 统计时发生错误: {e}") return None - def get_manager_stats(self) -> Dict[str, Any]: + def get_manager_stats(self) -> dict[str, Any]: """获取管理器统计""" return { "total_streams": self.stats.total_streams, @@ -278,7 +279,7 @@ class MessageManager: except Exception as e: logger.error(f"清理不活跃聊天流时发生错误: {e}") - async def _check_and_handle_interruption(self, chat_stream: Optional[ChatStream] = None): + async def _check_and_handle_interruption(self, chat_stream: ChatStream | None = None): """检查并处理消息打断""" if not global_config.chat.interruption_enabled: return diff --git a/src/chat/message_manager/sleep_manager/sleep_manager.py b/src/chat/message_manager/sleep_manager/sleep_manager.py index b0cf79b1b..6aeab8037 100644 --- a/src/chat/message_manager/sleep_manager/sleep_manager.py +++ b/src/chat/message_manager/sleep_manager/sleep_manager.py @@ -1,12 +1,13 @@ import asyncio import random from datetime import datetime, timedelta -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from src.common.logger import get_logger from src.config.config import global_config + from .notification_sender import NotificationSender -from .sleep_state import SleepState, SleepContext +from .sleep_state import SleepContext, SleepState from .time_checker import TimeChecker if TYPE_CHECKING: @@ -92,7 +93,7 @@ class SleepManager: elif current_state == SleepState.WOKEN_UP: self._handle_woken_up(now, is_in_theoretical_sleep, wakeup_manager) - def _handle_awake_to_sleep(self, now: datetime, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]): + def _handle_awake_to_sleep(self, now: datetime, activity: str | None, wakeup_manager: Optional["WakeUpManager"]): """处理从“清醒”到“准备入睡”的状态转换。""" if activity: logger.info(f"进入理论休眠时间 '{activity}',开始进行睡眠决策...") @@ -181,7 +182,7 @@ class SleepManager: self, now: datetime, is_in_theoretical_sleep: bool, - activity: Optional[str], + activity: str | None, wakeup_manager: Optional["WakeUpManager"], ): """处理“正在睡觉”状态下的逻辑。""" diff --git a/src/chat/message_manager/sleep_manager/sleep_state.py b/src/chat/message_manager/sleep_manager/sleep_state.py index 105302169..21a9f11bb 100644 --- a/src/chat/message_manager/sleep_manager/sleep_state.py +++ b/src/chat/message_manager/sleep_manager/sleep_state.py @@ -1,6 +1,5 @@ +from datetime import date, datetime from enum import Enum, auto -from datetime import datetime, date -from typing import Optional from src.common.logger import get_logger from 
src.manager.local_store_manager import local_storage @@ -29,10 +28,10 @@ class SleepContext: def __init__(self): """初始化睡眠上下文,并从本地存储加载初始状态。""" self.current_state: SleepState = SleepState.AWAKE - self.sleep_buffer_end_time: Optional[datetime] = None + self.sleep_buffer_end_time: datetime | None = None self.total_delayed_minutes_today: float = 0.0 - self.last_sleep_check_date: Optional[date] = None - self.re_sleep_attempt_time: Optional[datetime] = None + self.last_sleep_check_date: date | None = None + self.re_sleep_attempt_time: datetime | None = None self.load() def save(self): diff --git a/src/chat/message_manager/sleep_manager/time_checker.py b/src/chat/message_manager/sleep_manager/time_checker.py index 773830c3a..0ea099039 100644 --- a/src/chat/message_manager/sleep_manager/time_checker.py +++ b/src/chat/message_manager/sleep_manager/time_checker.py @@ -1,6 +1,6 @@ -from datetime import datetime, time, timedelta -from typing import Optional, List, Dict, Any import random +from datetime import datetime, time, timedelta +from typing import Any from src.common.logger import get_logger from src.config.config import global_config @@ -37,11 +37,11 @@ class TimeChecker: return self._daily_sleep_offset, self._daily_wake_offset @staticmethod - def get_today_schedule() -> Optional[List[Dict[str, Any]]]: + def get_today_schedule() -> list[dict[str, Any]] | None: """从全局 ScheduleManager 获取今天的日程安排。""" return schedule_manager.today_schedule - def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]: + def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, str | None]: if global_config.sleep_system.sleep_by_schedule: if self.get_today_schedule(): return self._is_in_schedule_sleep_time(now_time) @@ -50,7 +50,7 @@ class TimeChecker: else: return self._is_in_sleep_time(now_time) - def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]: + def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, str | None]: """检查当前时间是否落在日程表的任何一个睡眠活动中""" sleep_keywords = ["休眠", "睡觉", "梦乡"] today_schedule = self.get_today_schedule() @@ -79,7 +79,7 @@ class TimeChecker: continue return False, None - def _is_in_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]: + def _is_in_sleep_time(self, now_time: time) -> tuple[bool, str | None]: """检查当前时间是否在固定的睡眠时间内(应用偏移量)""" try: start_time_str = global_config.sleep_system.fixed_sleep_time diff --git a/src/chat/message_manager/sleep_manager/wakeup_manager.py b/src/chat/message_manager/sleep_manager/wakeup_manager.py index 5fc68ff41..d390d9d3d 100644 --- a/src/chat/message_manager/sleep_manager/wakeup_manager.py +++ b/src/chat/message_manager/sleep_manager/wakeup_manager.py @@ -1,9 +1,10 @@ import asyncio import time -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING + +from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext from src.common.logger import get_logger from src.config.config import global_config -from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext if TYPE_CHECKING: from .sleep_manager import SleepManager @@ -27,9 +28,9 @@ class WakeUpManager: """ self.sleep_manager = sleep_manager self.context = WakeUpContext() # 使用新的上下文管理器 - self.angry_chat_id: Optional[str] = None + self.angry_chat_id: str | None = None self.last_decay_time = time.time() - self._decay_task: Optional[asyncio.Task] = None + self._decay_task: asyncio.Task | None = None self.is_running = False self.last_log_time = 0 self.log_interval = 30 
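The time_checker.py hunks above revolve around deciding whether the current time falls inside a configured sleep window, optionally shifted by daily offsets. Below is a minimal, self-contained sketch of that kind of check, including windows that cross midnight; the function name, signature and example times are illustrative assumptions and do not reflect the project's actual API or configuration.

from datetime import datetime, time, timedelta

def in_sleep_window(now: time, start: time, end: time,
                    start_offset_min: int = 0, end_offset_min: int = 0) -> bool:
    """Return True if `now` lies in [start, end); the window may cross midnight.

    The minute offsets shift the nominal boundaries, mimicking a per-day jitter
    (hypothetical stand-ins for values like _daily_sleep_offset / _daily_wake_offset).
    """
    # Apply the offsets via datetime arithmetic on an arbitrary anchor date,
    # then fall back to plain time-of-day comparisons.
    anchor = datetime(2000, 1, 1).date()
    start = (datetime.combine(anchor, start) + timedelta(minutes=start_offset_min)).time()
    end = (datetime.combine(anchor, end) + timedelta(minutes=end_offset_min)).time()

    if start <= end:
        # Window stays within one day, e.g. 13:00-15:00.
        return start <= now < end
    # Window crosses midnight, e.g. 23:30-07:00.
    return now >= start or now < end

if __name__ == "__main__":
    # 01:00 is inside a 23:30-07:00 window; 12:00 is not.
    assert in_sleep_window(time(1, 0), time(23, 30), time(7, 0))
    assert not in_sleep_window(time(12, 0), time(23, 30), time(7, 0))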
@@ -104,9 +105,7 @@ class WakeUpManager: logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}") self.context.save() - def add_wakeup_value( - self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None - ) -> bool: + def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: str | None = None) -> bool: """ 增加唤醒度值 diff --git a/src/chat/message_receive/__init__.py b/src/chat/message_receive/__init__.py index 44b9eee36..32a3fe9f5 100644 --- a/src/chat/message_receive/__init__.py +++ b/src/chat/message_receive/__init__.py @@ -2,9 +2,8 @@ from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.storage import MessageStorage - __all__ = [ - "get_emoji_manager", - "get_chat_manager", "MessageStorage", + "get_chat_manager", + "get_emoji_manager", ] diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 47d1f26e2..2007d01ec 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -1,25 +1,24 @@ -import traceback import os import re +import traceback +from typing import Any -from typing import Dict, Any, Optional from maim_message import UserInfo -from src.common.logger import get_logger -from src.config.config import global_config -from src.mood.mood_manager import mood_manager # 导入情绪管理器 -from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream -from src.chat.message_receive.message import MessageRecv, MessageRecvS4U -from src.chat.message_receive.storage import MessageStorage -from src.chat.message_manager import message_manager -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.plugin_system.core import component_registry, event_manager, global_announcement_manager -from src.plugin_system.base import BaseCommand, EventType -from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor -from src.chat.utils.utils import is_mentioned_bot_in_message - # 导入反注入系统 from src.chat.antipromptinjector import initialize_anti_injector +from src.chat.message_manager import message_manager +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.message import MessageRecv, MessageRecvS4U +from src.chat.message_receive.storage import MessageStorage +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.chat.utils.utils import is_mentioned_bot_in_message +from src.common.logger import get_logger +from src.config.config import global_config +from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor +from src.mood.mood_manager import mood_manager # 导入情绪管理器 +from src.plugin_system.base import BaseCommand, EventType +from src.plugin_system.core import component_registry, event_manager, global_announcement_manager # 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录) PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) @@ -219,7 +218,7 @@ class ChatBot: logger.error(traceback.format_exc()) try: - await plus_command_instance.send_text(f"命令执行出错: {str(e)}") + await plus_command_instance.send_text(f"命令执行出错: {e!s}") except Exception as send_error: logger.error(f"发送错误消息失败: {send_error}") @@ -286,7 +285,7 @@ class ChatBot: logger.error(traceback.format_exc()) try: - await command_instance.send_text(f"命令执行出错: {str(e)}") + await command_instance.send_text(f"命令执行出错: {e!s}") except Exception as send_error: 
logger.error(f"发送错误消息失败: {send_error}") @@ -338,7 +337,7 @@ class ChatBot: except Exception as e: logger.error(f"处理适配器响应时出错: {e}") - async def do_s4u(self, message_data: Dict[str, Any]): + async def do_s4u(self, message_data: dict[str, Any]): message = MessageRecvS4U(message_data) group_info = message.message_info.group_info user_info = message.message_info.user_info @@ -359,7 +358,7 @@ class ChatBot: return - async def message_process(self, message_data: Dict[str, Any]) -> None: + async def message_process(self, message_data: dict[str, Any]) -> None: """处理转化后的统一格式消息""" try: # 首先处理可能的切片消息重组 @@ -458,7 +457,7 @@ class ChatBot: # TODO:暂不可用 # 确认从接口发来的message是否有自定义的prompt模板信息 if message.message_info.template_info and not message.message_info.template_info.template_default: - template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore + template_group_name: str | None = message.message_info.template_info.template_name # type: ignore template_items = message.message_info.template_info.template_items async with global_prompt_manager.async_message_scope(template_group_name): if isinstance(template_items, dict): diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 559490694..40833b285 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -1,17 +1,18 @@ import asyncio +import copy import hashlib import time -import copy -from typing import Dict, Optional, TYPE_CHECKING -from rich.traceback import install -from maim_message import GroupInfo, UserInfo +from typing import TYPE_CHECKING -from src.common.logger import get_logger +from maim_message import GroupInfo, UserInfo +from rich.traceback import install from sqlalchemy import select -from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.dialects.mysql import insert as mysql_insert -from src.common.database.sqlalchemy_models import ChatStreams # 新增导入 +from sqlalchemy.dialects.sqlite import insert as sqlite_insert + from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import ChatStreams # 新增导入 +from src.common.logger import get_logger from src.config.config import global_config # 新增导入 # 避免循环导入,使用TYPE_CHECKING进行类型提示 @@ -33,8 +34,8 @@ class ChatStream: stream_id: str, platform: str, user_info: UserInfo, - group_info: Optional[GroupInfo] = None, - data: Optional[dict] = None, + group_info: GroupInfo | None = None, + data: dict | None = None, ): self.stream_id = stream_id self.platform = platform @@ -47,7 +48,7 @@ class ChatStream: # 使用StreamContext替代ChatMessageContext from src.common.data_models.message_manager_data_model import StreamContext - from src.plugin_system.base.component_types import ChatType, ChatMode + from src.plugin_system.base.component_types import ChatMode, ChatType # 创建StreamContext self.stream_context: StreamContext = StreamContext( @@ -133,11 +134,11 @@ class ChatStream: # 恢复stream_context信息 if "stream_context_chat_type" in data: - from src.plugin_system.base.component_types import ChatType, ChatMode + from src.plugin_system.base.component_types import ChatMode, ChatType instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"]) if "stream_context_chat_mode" in data: - from src.plugin_system.base.component_types import ChatType, ChatMode + from src.plugin_system.base.component_types import ChatMode, ChatType instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"]) @@ 
-163,9 +164,10 @@ class ChatStream: def set_context(self, message: "MessageRecv"): """设置聊天消息上下文""" # 将MessageRecv转换为DatabaseMessages并设置到stream_context - from src.common.data_models.database_data_model import DatabaseMessages import json + from src.common.data_models.database_data_model import DatabaseMessages + # 安全获取message_info中的数据 message_info = getattr(message, "message_info", {}) user_info = getattr(message_info, "user_info", {}) @@ -248,7 +250,7 @@ class ChatStream: f"interest_value: {db_message.interest_value}" ) - def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]: + def _safe_get_actions(self, message: "MessageRecv") -> list | None: """安全获取消息的actions字段""" try: actions = getattr(message, "actions", None) @@ -278,7 +280,7 @@ class ChatStream: logger.warning(f"获取actions字段失败: {e}") return None - def _extract_reply_from_segment(self, segment) -> Optional[str]: + def _extract_reply_from_segment(self, segment) -> str | None: """从消息段中提取reply_to信息""" try: if hasattr(segment, "type") and segment.type == "seglist": @@ -391,8 +393,8 @@ class ChatManager: def __init__(self): if not self._initialized: - self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream - self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message + self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream + self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message # try: # async with get_db_session() as session: # db.connect(reuse_if_open=True) @@ -414,7 +416,7 @@ class ChatManager: await self.load_all_streams() logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流") except Exception as e: - logger.error(f"聊天管理器启动失败: {str(e)}") + logger.error(f"聊天管理器启动失败: {e!s}") async def _auto_save_task(self): """定期自动保存所有聊天流""" @@ -424,7 +426,7 @@ class ChatManager: await self._save_all_streams() logger.info("聊天流自动保存完成") except Exception as e: - logger.error(f"聊天流自动保存失败: {str(e)}") + logger.error(f"聊天流自动保存失败: {e!s}") def register_message(self, message: "MessageRecv"): """注册消息到聊天流""" @@ -437,9 +439,7 @@ class ChatManager: # logger.debug(f"注册消息到聊天流: {stream_id}") @staticmethod - def _generate_stream_id( - platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None - ) -> str: + def _generate_stream_id(platform: str, user_info: UserInfo | None, group_info: GroupInfo | None = None) -> str: """生成聊天流唯一ID""" if not user_info and not group_info: raise ValueError("用户信息或群组信息必须提供") @@ -462,7 +462,7 @@ class ChatManager: return hashlib.md5(key.encode()).hexdigest() async def get_or_create_stream( - self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None + self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None ) -> ChatStream: """获取或创建聊天流 @@ -572,7 +572,7 @@ class ChatManager: await self._save_stream(stream) return stream - def get_stream(self, stream_id: str) -> Optional[ChatStream]: + def get_stream(self, stream_id: str) -> ChatStream | None: """通过stream_id获取聊天流""" stream = self.streams.get(stream_id) if not stream: @@ -582,13 +582,13 @@ class ChatManager: return stream def get_stream_by_info( - self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None - ) -> Optional[ChatStream]: + self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None + ) -> ChatStream | None: """通过信息获取聊天流""" stream_id = self._generate_stream_id(platform, user_info, group_info) return self.streams.get(stream_id) - def get_stream_name(self, stream_id: str) -> Optional[str]: + def 
get_stream_name(self, stream_id: str) -> str | None: """根据 stream_id 获取聊天流名称""" stream = self.get_stream(stream_id) if not stream: diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index fee932b62..7953ff862 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,20 +1,19 @@ import base64 import time -from abc import abstractmethod, ABCMeta +from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Optional, Any +from typing import Any, Optional import urllib3 -from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase +from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo from rich.traceback import install +from src.chat.message_receive.chat_stream import ChatStream from src.chat.utils.utils_image import get_image_manager from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available from src.chat.utils.utils_voice import get_voice_text from src.common.logger import get_logger from src.config.config import global_config -from src.chat.message_receive.chat_stream import ChatStream - install(extra_lines=3) @@ -41,8 +40,8 @@ class Message(MessageBase, metaclass=ABCMeta): message_id: str, chat_stream: "ChatStream", user_info: UserInfo, - message_segment: Optional[Seg] = None, - timestamp: Optional[float] = None, + message_segment: Seg | None = None, + timestamp: float | None = None, reply: Optional["MessageRecv"] = None, processed_plain_text: str = "", ): @@ -264,7 +263,7 @@ class MessageRecv(Message): logger.warning("视频消息中没有base64数据") return "[收到视频消息,但数据异常]" except Exception as e: - logger.error(f"视频处理失败: {str(e)}") + logger.error(f"视频处理失败: {e!s}") import traceback logger.error(f"错误详情: {traceback.format_exc()}") @@ -278,7 +277,7 @@ class MessageRecv(Message): logger.info("未启用视频识别") return "[视频]" except Exception as e: - logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") + logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") return f"[处理失败的{segment.type}消息]" @@ -291,7 +290,7 @@ class MessageRecvS4U(MessageRecv): self.is_superchat = False self.gift_info = None self.gift_name = None - self.gift_count: Optional[str] = None + self.gift_count: str | None = None self.superchat_info = None self.superchat_price = None self.superchat_message_text = None @@ -444,7 +443,7 @@ class MessageRecvS4U(MessageRecv): logger.warning("视频消息中没有base64数据") return "[收到视频消息,但数据异常]" except Exception as e: - logger.error(f"视频处理失败: {str(e)}") + logger.error(f"视频处理失败: {e!s}") import traceback logger.error(f"错误详情: {traceback.format_exc()}") @@ -458,7 +457,7 @@ class MessageRecvS4U(MessageRecv): logger.info("未启用视频识别") return "[视频]" except Exception as e: - logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") + logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") return f"[处理失败的{segment.type}消息]" @@ -471,10 +470,10 @@ class MessageProcessBase(Message): message_id: str, chat_stream: "ChatStream", bot_user_info: UserInfo, - message_segment: Optional[Seg] = None, + message_segment: Seg | None = None, reply: Optional["MessageRecv"] = None, thinking_start_time: float = 0, - timestamp: Optional[float] = None, + timestamp: float | None = None, ): # 调用父类初始化,传递时间戳 super().__init__( @@ -533,9 +532,9 @@ class MessageProcessBase(Message): return f"[回复<{self.reply.message_info.user_info.user_nickname}> 的消息:{self.reply.processed_plain_text}]" # type: ignore return None else: - return 
f"[{seg.type}:{str(seg.data)}]" + return f"[{seg.type}:{seg.data!s}]" except Exception as e: - logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}") + logger.error(f"处理消息段失败: {e!s}, 类型: {seg.type}, 数据: {seg.data}") return f"[处理失败的{seg.type}消息]" def _generate_detailed_text(self) -> str: @@ -565,7 +564,7 @@ class MessageSending(MessageProcessBase): is_emoji: bool = False, thinking_start_time: float = 0, apply_set_reply_logic: bool = False, - reply_to: Optional[str] = None, + reply_to: str | None = None, ): # 调用父类初始化 super().__init__( @@ -635,11 +634,11 @@ class MessageSet: self.messages.append(message) self.messages.sort(key=lambda x: x.message_info.time) # type: ignore - def get_message_by_index(self, index: int) -> Optional[MessageSending]: + def get_message_by_index(self, index: int) -> MessageSending | None: """通过索引获取消息""" return self.messages[index] if 0 <= index < len(self.messages) else None - def get_message_by_time(self, target_time: float) -> Optional[MessageSending]: + def get_message_by_time(self, target_time: float) -> MessageSending | None: """获取最接近指定时间的消息""" if not self.messages: return None diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 5a654e867..1382adfb8 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,14 +1,15 @@ import re import traceback -import orjson -from typing import Union -from src.common.database.sqlalchemy_models import Messages, Images -from src.common.logger import get_logger -from .chat_stream import ChatStream -from .message import MessageSending, MessageRecv +import orjson +from sqlalchemy import desc, select, update + from src.common.database.sqlalchemy_database_api import get_db_session -from sqlalchemy import select, update, desc +from src.common.database.sqlalchemy_models import Images, Messages +from src.common.logger import get_logger + +from .chat_stream import ChatStream +from .message import MessageRecv, MessageSending logger = get_logger("message_storage") @@ -32,7 +33,7 @@ class MessageStorage: return [] @staticmethod - async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None: + async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None: """存储消息到数据库""" try: # 过滤敏感信息的正则模式 @@ -299,6 +300,7 @@ class MessageStorage: try: async with get_db_session() as session: from sqlalchemy import select, update + from src.common.database.sqlalchemy_models import Messages # 查找需要修复的记录:interest_value为0、null或很小的值 diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index a881549f5..bd23402e2 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -3,12 +3,11 @@ import traceback from rich.traceback import install -from src.common.message.api import get_global_api -from src.common.logger import get_logger from src.chat.message_receive.message import MessageSending from src.chat.message_receive.storage import MessageStorage -from src.chat.utils.utils import truncate_message -from src.chat.utils.utils import calculate_typing_time +from src.chat.utils.utils import calculate_typing_time, truncate_message +from src.common.logger import get_logger +from src.common.message.api import get_global_api install(extra_lines=3) @@ -27,7 +26,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool: return True except Exception as e: - logger.error(f"发送消息 
'{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}") + logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {e!s}") traceback.print_exc() raise e # 重新抛出其他异常 diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 21a00ee52..9adde80cb 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -1,19 +1,17 @@ import asyncio -import traceback import time -from typing import Dict, Optional, Type, Any, Tuple +import traceback +from typing import Any - -from src.chat.utils.timer_calculator import Timer -from src.person_info.person_info import get_person_info_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.utils.timer_calculator import Timer from src.common.logger import get_logger from src.config.config import global_config -from src.plugin_system.core.component_registry import component_registry -from src.plugin_system.base.component_types import ComponentType, ActionInfo +from src.person_info.person_info import get_person_info_manager +from src.plugin_system.apis import database_api, generator_api, message_api, send_api from src.plugin_system.base.base_action import BaseAction -from src.plugin_system.apis import generator_api, database_api, send_api, message_api - +from src.plugin_system.base.component_types import ActionInfo, ComponentType +from src.plugin_system.core.component_registry import component_registry logger = get_logger("action_manager") @@ -29,7 +27,7 @@ class ChatterActionManager: """初始化动作管理器""" # 当前正在使用的动作集合,默认加载默认动作 - self._using_actions: Dict[str, ActionInfo] = {} + self._using_actions: dict[str, ActionInfo] = {} # 初始化时将默认动作加载到使用中的动作 self._using_actions = component_registry.get_default_actions() @@ -48,8 +46,8 @@ class ChatterActionManager: chat_stream: ChatStream, log_prefix: str, shutting_down: bool = False, - action_message: Optional[dict] = None, - ) -> Optional[BaseAction]: + action_message: dict | None = None, + ) -> BaseAction | None: """ 创建动作处理器实例 @@ -68,7 +66,7 @@ class ChatterActionManager: """ try: # 获取组件类 - 明确指定查询Action类型 - component_class: Type[BaseAction] = component_registry.get_component_class( + component_class: type[BaseAction] = component_registry.get_component_class( action_name, ComponentType.ACTION ) # type: ignore if not component_class: @@ -107,7 +105,7 @@ class ChatterActionManager: logger.error(traceback.format_exc()) return None - def get_using_actions(self) -> Dict[str, ActionInfo]: + def get_using_actions(self) -> dict[str, ActionInfo]: """获取当前正在使用的动作集合""" return self._using_actions.copy() @@ -140,10 +138,10 @@ class ChatterActionManager: self, action_name: str, chat_id: str, - target_message: Optional[dict] = None, + target_message: dict | None = None, reasoning: str = "", - action_data: Optional[dict] = None, - thinking_id: Optional[str] = None, + action_data: dict | None = None, + thinking_id: str | None = None, log_prefix: str = "", clear_unread_messages: bool = True, ) -> Any: @@ -437,10 +435,10 @@ class ChatterActionManager: response_set, loop_start_time, action_message, - cycle_timers: Dict[str, float], + cycle_timers: dict[str, float], thinking_id, actions, - ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: + ) -> tuple[dict[str, Any], str, dict[str, float]]: """ 发送并存储回复信息 @@ -488,7 +486,7 @@ class ChatterActionManager: ) # 构建循环信息 - loop_info: Dict[str, Any] = { + loop_info: dict[str, Any] = { "loop_plan_info": { "action_result": actions, 
}, diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 4e144d3f4..4f3e4b099 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -1,17 +1,17 @@ -import random import asyncio import hashlib +import random import time -from typing import List, Any, Dict, TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Any +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.planner_actions.action_manager import ChatterActionManager +from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat +from src.common.data_models.message_manager_data_model import StreamContext from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.chat.message_receive.chat_stream import get_chat_manager -from src.common.data_models.message_manager_data_model import StreamContext -from src.chat.planner_actions.action_manager import ChatterActionManager -from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages -from src.plugin_system.base.component_types import ActionInfo, ActionActivationType +from src.plugin_system.base.component_types import ActionActivationType, ActionInfo from src.plugin_system.core.global_announcement_manager import global_announcement_manager if TYPE_CHECKING: @@ -59,18 +59,17 @@ class ActionModifier: """ logger.debug(f"{self.log_prefix}开始完整动作修改流程") - removals_s1: List[Tuple[str, str]] = [] - removals_s2: List[Tuple[str, str]] = [] - removals_s3: List[Tuple[str, str]] = [] + removals_s1: list[tuple[str, str]] = [] + removals_s2: list[tuple[str, str]] = [] + removals_s3: list[tuple[str, str]] = [] self.action_manager.restore_actions() all_actions = self.action_manager.get_using_actions() # === 第0阶段:根据聊天类型过滤动作 === - from src.plugin_system.base.component_types import ChatType - from src.plugin_system.core.component_registry import component_registry - from src.plugin_system.base.component_types import ComponentType from src.chat.utils.utils import get_chat_type_and_target_info + from src.plugin_system.base.component_types import ChatType, ComponentType + from src.plugin_system.core.component_registry import component_registry # 获取聊天类型 is_group_chat, _ = get_chat_type_and_target_info(self.chat_id) @@ -167,8 +166,8 @@ class ActionModifier: logger.info(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}") - def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: StreamContext): - type_mismatched_actions: List[Tuple[str, str]] = [] + def _check_action_associated_types(self, all_actions: dict[str, ActionInfo], chat_context: StreamContext): + type_mismatched_actions: list[tuple[str, str]] = [] for action_name, action_info in all_actions.items(): if action_info.associated_types and not chat_context.check_types(action_info.associated_types): associated_types_str = ", ".join(action_info.associated_types) @@ -179,9 +178,9 @@ class ActionModifier: async def _get_deactivated_actions_by_type( self, - actions_with_info: Dict[str, ActionInfo], + actions_with_info: dict[str, ActionInfo], chat_content: str = "", - ) -> List[tuple[str, str]]: + ) -> list[tuple[str, str]]: """ 根据激活类型过滤,返回需要停用的动作列表及原因 @@ -254,9 +253,9 @@ class ActionModifier: async def _process_llm_judge_actions_parallel( self, - llm_judge_actions: 
Dict[str, Any], + llm_judge_actions: dict[str, Any], chat_content: str = "", - ) -> Dict[str, bool]: + ) -> dict[str, bool]: """ 并行处理LLM判定actions,支持智能缓存 diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 063fc1bf1..0d4b0b574 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -3,42 +3,41 @@ 使用重构后的统一Prompt系统替换原有的复杂提示词构建逻辑 """ -import traceback -import time import asyncio import random import re - -from typing import List, Optional, Dict, Any, Tuple +import time +import traceback from datetime import datetime -from src.mais4u.mai_think import mai_thinking_manager -from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.individuality.individuality import get_individuality -from src.llm_models.utils_model import LLMRequest -from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending +from typing import Any + +from src.chat.express.expression_selector import expression_selector from src.chat.message_receive.chat_stream import ChatStream -from src.chat.utils.memory_mappings import get_memory_type_chinese_label +from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo from src.chat.message_receive.uni_message_sender import HeartFCSender -from src.chat.utils.timer_calculator import Timer -from src.chat.utils.utils import get_chat_type_and_target_info -from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import ( build_readable_messages, get_raw_msg_before_timestamp_with_chat, replace_user_references_sync, ) -from src.chat.express.expression_selector import expression_selector +from src.chat.utils.memory_mappings import get_memory_type_chinese_label + +# 导入新的统一Prompt系统 +from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager +from src.chat.utils.timer_calculator import Timer +from src.chat.utils.utils import get_chat_type_and_target_info +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.individuality.individuality import get_individuality +from src.llm_models.utils_model import LLMRequest +from src.mais4u.mai_think import mai_thinking_manager # 旧记忆系统已被移除 # 旧记忆系统已被移除 from src.mood.mood_manager import mood_manager from src.person_info.person_info import get_person_info_manager -from src.plugin_system.base.component_types import ActionInfo, EventType from src.plugin_system.apis import llm_api - -# 导入新的统一Prompt系统 -from src.chat.utils.prompt import PromptParameters +from src.plugin_system.base.component_types import ActionInfo, EventType logger = get_logger("replyer") @@ -248,12 +247,12 @@ class DefaultReplyer: self, reply_to: str = "", extra_info: str = "", - available_actions: Optional[Dict[str, ActionInfo]] = None, + available_actions: dict[str, ActionInfo] | None = None, enable_tool: bool = True, from_plugin: bool = True, - stream_id: Optional[str] = None, - reply_message: Optional[Dict[str, Any]] = None, - ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: + stream_id: str | None = None, + reply_message: dict[str, Any] | None = None, + ) -> tuple[bool, dict[str, Any] | None, str | None]: # sourcery skip: merge-nested-ifs """ 回复器 (Replier): 负责生成回复文本的核心逻辑。 @@ -353,7 +352,7 @@ class DefaultReplyer: reason: str = "", reply_to: str = "", return_prompt: bool = False, - ) -> Tuple[bool, Optional[str], Optional[str]]: + ) -> tuple[bool, str | None, str | None]: 
""" 表达器 (Expressor): 负责重写和优化回复文本。 @@ -722,7 +721,7 @@ class DefaultReplyer: logger.error(f"工具信息获取失败: {e}") return "" - def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: + def _parse_reply_target(self, target_message: str) -> tuple[str, str]: """解析回复目标消息 - 使用共享工具""" from src.chat.utils.prompt import Prompt @@ -731,7 +730,7 @@ class DefaultReplyer: return "未知用户", "(无消息内容)" return Prompt.parse_reply_target(target_message) - async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str: + async def build_keywords_reaction_prompt(self, target: str | None) -> str: """构建关键词反应提示 Args: @@ -766,14 +765,14 @@ class DefaultReplyer: keywords_reaction_prompt += f"{reaction}," break except re.error as e: - logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}") + logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {e!s}") continue except Exception as e: - logger.error(f"关键词检测与反应时发生异常: {str(e)}", exc_info=True) + logger.error(f"关键词检测与反应时发生异常: {e!s}", exc_info=True) return keywords_reaction_prompt - async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]: + async def _time_and_run_task(self, coroutine, name: str) -> tuple[str, Any, float]: """计时并运行异步任务的辅助函数 Args: @@ -790,8 +789,8 @@ class DefaultReplyer: return name, result, duration async def build_s4u_chat_history_prompts( - self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str - ) -> Tuple[str, str]: + self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str + ) -> tuple[str, str]: """ 构建 s4u 风格的已读/未读历史消息 prompt @@ -907,8 +906,8 @@ class DefaultReplyer: return await self._fallback_build_chat_history_prompts(message_list_before_now, target_user_id, sender) async def _fallback_build_chat_history_prompts( - self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str - ) -> Tuple[str, str]: + self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str + ) -> tuple[str, str]: """ 回退的已读/未读历史消息构建方法 """ @@ -1000,15 +999,15 @@ class DefaultReplyer: return read_history_prompt, unread_history_prompt - async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]: + async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]: """为消息获取兴趣度评分""" interest_scores = {} try: + from src.common.data_models.database_data_model import DatabaseMessages from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ( chatter_interest_scoring_system as interest_scoring_system, ) - from src.common.data_models.database_data_model import DatabaseMessages # 转换消息格式 db_messages = [] @@ -1094,9 +1093,9 @@ class DefaultReplyer: self, reply_to: str, extra_info: str = "", - available_actions: Optional[Dict[str, ActionInfo]] = None, + available_actions: dict[str, ActionInfo] | None = None, enable_tool: bool = True, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: dict[str, Any] | None = None, ) -> str: """ 构建回复器上下文 @@ -1417,7 +1416,7 @@ class DefaultReplyer: raw_reply: str, reason: str, reply_to: str, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: dict[str, Any] | None = None, ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if chat_stream = self.chat_stream chat_id = chat_stream.stream_id @@ -1553,7 +1552,7 @@ class DefaultReplyer: is_emoji: bool, thinking_start_time: float, display_message: str, - anchor_message: Optional[MessageRecv] = None, + anchor_message: 
MessageRecv | None = None, ) -> MessageSending: """构建单个发送消息""" @@ -1644,7 +1643,7 @@ class DefaultReplyer: logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...") return "" except Exception as e: - logger.error(f"获取知识库内容时发生异常: {str(e)}") + logger.error(f"获取知识库内容时发生异常: {e!s}") return "" async def build_relation_info(self, sender: str, target: str): @@ -1660,10 +1659,9 @@ class DefaultReplyer: # 使用AFC关系追踪器获取关系信息 try: - from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker - # 创建关系追踪器实例 from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system + from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker relationship_tracker = ChatterRelationshipTracker(chatter_interest_scoring_system) if relationship_tracker: @@ -1704,7 +1702,7 @@ class DefaultReplyer: logger.error(f"获取AFC关系信息失败: {e}") return f"你与{sender}是普通朋友关系。" - async def _store_chat_memory_async(self, reply_to: str, reply_message: Optional[Dict[str, Any]] = None): + async def _store_chat_memory_async(self, reply_to: str, reply_message: dict[str, Any] | None = None): """ 异步存储聊天记忆(从build_memory_block迁移而来) diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py index 2f64ab07f..55a422c1b 100644 --- a/src/chat/replyer/replyer_manager.py +++ b/src/chat/replyer/replyer_manager.py @@ -1,22 +1,20 @@ -from typing import Dict, Optional - -from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.replyer.default_generator import DefaultReplyer +from src.common.logger import get_logger logger = get_logger("ReplyerManager") class ReplyerManager: def __init__(self): - self._repliers: Dict[str, DefaultReplyer] = {} + self._repliers: dict[str, DefaultReplyer] = {} def get_replyer( self, - chat_stream: Optional[ChatStream] = None, - chat_id: Optional[str] = None, + chat_stream: ChatStream | None = None, + chat_id: str | None = None, request_type: str = "replyer", - ) -> Optional[DefaultReplyer]: + ) -> DefaultReplyer | None: """ 获取或创建回复器实例。 diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 8503e369a..65c123338 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -1,18 +1,19 @@ -import time # 导入 time 模块以获取当前时间 import random import re +import time # 导入 time 模块以获取当前时间 +from collections.abc import Callable +from typing import Any -from typing import List, Dict, Any, Tuple, Optional, Callable from rich.traceback import install +from sqlalchemy import and_, select -from src.config.config import global_config -from src.common.message_repository import find_messages, count_messages -from src.common.database.sqlalchemy_models import ActionRecords, Images -from src.person_info.person_info import PersonInfoManager, get_person_info_manager -from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids +from src.chat.utils.utils import assign_message_ids, translate_timestamp_to_human_readable from src.common.database.sqlalchemy_database_api import get_db_session -from sqlalchemy import select, and_ +from src.common.database.sqlalchemy_models import ActionRecords, Images from src.common.logger import get_logger +from src.common.message_repository import count_messages, find_messages +from src.config.config import global_config +from src.person_info.person_info import PersonInfoManager, get_person_info_manager 
logger = get_logger("chat_message_builder") @@ -22,7 +23,7 @@ install(extra_lines=3) def replace_user_references_sync( content: str, platform: str, - name_resolver: Optional[Callable[[str, str], str]] = None, + name_resolver: Callable[[str, str], str] | None = None, replace_bot_name: bool = True, ) -> str: """ @@ -100,7 +101,7 @@ def replace_user_references_sync( async def replace_user_references_async( content: str, platform: str, - name_resolver: Optional[Callable[[str, str], Any]] = None, + name_resolver: Callable[[str, str], Any] | None = None, replace_bot_name: bool = True, ) -> str: """ @@ -174,7 +175,7 @@ async def replace_user_references_async( async def get_raw_msg_by_timestamp( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 @@ -194,7 +195,7 @@ async def get_raw_msg_by_timestamp_with_chat( limit_mode: str = "latest", filter_bot=False, filter_command=False, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。 @@ -220,7 +221,7 @@ async def get_raw_msg_by_timestamp_with_chat_inclusive( limit: int = 0, limit_mode: str = "latest", filter_bot=False, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。 @@ -239,10 +240,10 @@ async def get_raw_msg_by_timestamp_with_chat_users( chat_id: str, timestamp_start: float, timestamp_end: float, - person_ids: List[str], + person_ids: list[str], limit: int = 0, limit_mode: str = "latest", -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取某些特定用户在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。 @@ -263,7 +264,7 @@ async def get_actions_by_timestamp_with_chat( timestamp_end: float = time.time(), limit: int = 0, limit_mode: str = "latest", -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表""" from src.common.logger import get_logger @@ -372,7 +373,7 @@ async def get_actions_by_timestamp_with_chat( async def get_actions_by_timestamp_with_chat_inclusive( chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表""" async with get_db_session() as session: if limit > 0: @@ -423,7 +424,7 @@ async def get_actions_by_timestamp_with_chat_inclusive( async def get_raw_msg_by_timestamp_random( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息 """ @@ -441,7 +442,7 @@ async def get_raw_msg_by_timestamp_random( async def get_raw_msg_by_timestamp_with_users( timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest" -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。 @@ -452,7 +453,7 @@ async def get_raw_msg_by_timestamp_with_users( return await 
find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> list[dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ @@ -463,7 +464,7 @@ async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List async def get_raw_msg_before_timestamp_with_chat( chat_id: str, timestamp: float, limit: int = 0 -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ @@ -474,7 +475,7 @@ async def get_raw_msg_before_timestamp_with_chat( async def get_raw_msg_before_timestamp_with_users( timestamp: float, person_ids: list, limit: int = 0 -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ @@ -483,9 +484,7 @@ async def get_raw_msg_before_timestamp_with_users( return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -async def num_new_messages_since( - chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None -) -> int: +async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float | None = None) -> int: """ 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。 @@ -517,16 +516,16 @@ async def num_new_messages_since_with_users( async def _build_readable_messages_internal( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, timestamp_mode: str = "relative", truncate: bool = False, - pic_id_mapping: Optional[Dict[str, str]] = None, + pic_id_mapping: dict[str, str] | None = None, pic_counter: int = 1, show_pic: bool = True, - message_id_list: Optional[List[Dict[str, Any]]] = None, -) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: + message_id_list: list[dict[str, Any]] | None = None, +) -> tuple[str, list[tuple[float, str, str]], dict[str, str], int]: """ 内部辅助函数,构建可读消息字符串和原始消息详情列表。 @@ -545,7 +544,7 @@ async def _build_readable_messages_internal( if not messages: return "", [], pic_id_mapping or {}, pic_counter - message_details_raw: List[Tuple[float, str, str, bool]] = [] + message_details_raw: list[tuple[float, str, str, bool]] = [] # 使用传入的映射字典,如果没有则创建新的 if pic_id_mapping is None: @@ -672,7 +671,7 @@ async def _build_readable_messages_internal( message_details_with_flags.append((timestamp, name, content, is_action)) # 应用截断逻辑 (如果 truncate 为 True) - message_details: List[Tuple[float, str, str, bool]] = [] + message_details: list[tuple[float, str, str, bool]] = [] n_messages = len(message_details_with_flags) if truncate and n_messages > 0: for i, (timestamp, name, content, is_action) in enumerate(message_details_with_flags): @@ -809,7 +808,7 @@ async def _build_readable_messages_internal( ) -async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: +async def build_pic_mapping_info(pic_id_mapping: dict[str, str]) -> str: # sourcery skip: use-contextlib-suppress """ 构建图片映射信息字符串,显示图片的具体描述内容 @@ -847,7 +846,7 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: return "\n".join(mapping_lines) -def build_readable_actions(actions: List[Dict[str, Any]]) -> str: +def build_readable_actions(actions: list[dict[str, Any]]) -> str: """ 将动作列表转换为可读的文本格式。 格式: 
在()分钟前,你使用了(action_name),具体内容是:(action_prompt_display) @@ -922,12 +921,12 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str: async def build_readable_messages_with_list( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, timestamp_mode: str = "relative", truncate: bool = False, -) -> Tuple[str, List[Tuple[float, str, str]]]: +) -> tuple[str, list[tuple[float, str, str]]]: """ 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 允许通过参数控制格式化行为。 @@ -943,7 +942,7 @@ async def build_readable_messages_with_list( async def build_readable_messages_with_id( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, timestamp_mode: str = "relative", @@ -951,7 +950,7 @@ async def build_readable_messages_with_id( truncate: bool = False, show_actions: bool = False, show_pic: bool = True, -) -> Tuple[str, List[Dict[str, Any]]]: +) -> tuple[str, list[dict[str, Any]]]: """ 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 允许通过参数控制格式化行为。 @@ -980,7 +979,7 @@ async def build_readable_messages_with_id( async def build_readable_messages( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, timestamp_mode: str = "relative", @@ -988,7 +987,7 @@ async def build_readable_messages( truncate: bool = False, show_actions: bool = True, show_pic: bool = True, - message_id_list: Optional[List[Dict[str, Any]]] = None, + message_id_list: list[dict[str, Any]] | None = None, ) -> str: # sourcery skip: extract-method """ 将消息列表转换为可读的文本格式。 @@ -1148,7 +1147,7 @@ async def build_readable_messages( return "".join(result_parts) -async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: +async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str: """ 构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。 处理 回复 和 @ 字段,将bbb映射为匿名占位符。 @@ -1261,7 +1260,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: return formatted_string -async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: +async def get_person_id_list(messages: list[dict[str, Any]]) -> list[str]: """ 从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。 diff --git a/src/chat/utils/memory_mappings.py b/src/chat/utils/memory_mappings.py index 4da20fdb5..b82771f8e 100644 --- a/src/chat/utils/memory_mappings.py +++ b/src/chat/utils/memory_mappings.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 记忆系统相关的映射表和工具函数 提供记忆类型、置信度、重要性等的中文标签映射 diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index baf77a143..d869ec7a2 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -3,19 +3,20 @@ 将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类 """ -import re import asyncio -import time import contextvars -from dataclasses import dataclass, field -from typing import Dict, Any, Optional, List, Literal, Tuple +import re +import time from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from typing import Any, Literal, Optional from rich.traceback import install + +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.chat_message_builder import build_readable_messages from src.common.logger import get_logger from src.config.config import global_config -from src.chat.utils.chat_message_builder import build_readable_messages -from src.chat.message_receive.chat_stream import get_chat_manager from src.person_info.person_info import get_person_info_manager 
install(extra_lines=3) @@ -50,11 +51,11 @@ class PromptParameters: debug_mode: bool = False # 聊天历史和上下文 - chat_target_info: Optional[Dict[str, Any]] = None - message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list) - message_list_before_short: List[Dict[str, Any]] = field(default_factory=list) + chat_target_info: dict[str, Any] | None = None + message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list) + message_list_before_short: list[dict[str, Any]] = field(default_factory=list) chat_talking_prompt_short: str = "" - target_user_info: Optional[Dict[str, Any]] = None + target_user_info: dict[str, Any] | None = None # 已构建的内容块 expression_habits_block: str = "" @@ -77,12 +78,12 @@ class PromptParameters: action_descriptions: str = "" # 可用动作信息 - available_actions: Optional[Dict[str, Any]] = None + available_actions: dict[str, Any] | None = None # 动态生成的聊天场景提示 chat_scene: str = "" - def validate(self) -> List[str]: + def validate(self) -> list[str]: """参数验证""" errors = [] if not self.chat_id: @@ -98,22 +99,22 @@ class PromptContext: """提示词上下文管理器""" def __init__(self): - self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} + self._context_prompts: dict[str, dict[str, "Prompt"]] = {} self._current_context_var = contextvars.ContextVar("current_context", default=None) self._context_lock = asyncio.Lock() @property - def _current_context(self) -> Optional[str]: + def _current_context(self) -> str | None: """获取当前协程的上下文ID""" return self._current_context_var.get() @_current_context.setter - def _current_context(self, value: Optional[str]): + def _current_context(self, value: str | None): """设置当前协程的上下文ID""" self._current_context_var.set(value) # type: ignore @asynccontextmanager - async def async_scope(self, context_id: Optional[str] = None): + async def async_scope(self, context_id: str | None = None): """创建一个异步的临时提示模板作用域""" if context_id is not None: try: @@ -159,7 +160,7 @@ class PromptContext: return self._context_prompts[current_context][name] return None - async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None: + async def register_async(self, prompt: "Prompt", context_id: str | None = None) -> None: """异步注册提示模板到指定作用域""" async with self._context_lock: if target_context := context_id or self._current_context: @@ -177,7 +178,7 @@ class PromptManager: self._lock = asyncio.Lock() @asynccontextmanager - async def async_message_scope(self, message_id: Optional[str] = None): + async def async_message_scope(self, message_id: str | None = None): """为消息处理创建异步临时作用域""" async with self._context.async_scope(message_id): yield self @@ -236,8 +237,8 @@ class Prompt: def __init__( self, template: str, - name: Optional[str] = None, - parameters: Optional[PromptParameters] = None, + name: str | None = None, + parameters: PromptParameters | None = None, should_register: bool = True, ): """ @@ -277,7 +278,7 @@ class Prompt: """将临时标记还原为实际的花括号字符""" return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}") - def _parse_template_args(self, template: str) -> List[str]: + def _parse_template_args(self, template: str) -> list[str]: """解析模板参数""" template_args = [] processed_template = self._process_escaped_braces(template) @@ -321,7 +322,7 @@ class Prompt: logger.error(f"构建Prompt失败: {e}") raise RuntimeError(f"构建Prompt失败: {e}") from e - async def _build_context_data(self) -> Dict[str, Any]: + async def _build_context_data(self) -> dict[str, Any]: """构建智能上下文数据""" # 并行执行所有构建任务 start_time = time.time() @@ -401,7 
+402,7 @@ class Prompt: default_result = self._get_default_result_for_task(task_name) results.append(default_result) except Exception as e: - logger.error(f"构建任务{task_name}失败: {str(e)}") + logger.error(f"构建任务{task_name}失败: {e!s}") default_result = self._get_default_result_for_task(task_name) results.append(default_result) @@ -411,7 +412,7 @@ class Prompt: task_name = task_names[i] if i < len(task_names) else f"task_{i}" if isinstance(result, Exception): - logger.error(f"构建任务{task_name}失败: {str(result)}") + logger.error(f"构建任务{task_name}失败: {result!s}") elif isinstance(result, dict): context_data.update(result) @@ -453,7 +454,7 @@ class Prompt: return context_data - async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None: + async def _build_s4u_chat_context(self, context_data: dict[str, Any]) -> None: """构建S4U模式的聊天上下文""" if not self.parameters.message_list_before_now_long: return @@ -468,7 +469,7 @@ class Prompt: context_data["read_history_prompt"] = read_history_prompt context_data["unread_history_prompt"] = unread_history_prompt - async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None: + async def _build_normal_chat_context(self, context_data: dict[str, Any]) -> None: """构建normal模式的聊天上下文""" if not self.parameters.chat_talking_prompt_short: return @@ -477,8 +478,8 @@ class Prompt: {self.parameters.chat_talking_prompt_short}""" async def _build_s4u_chat_history_prompts( - self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str - ) -> Tuple[str, str]: + self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str + ) -> tuple[str, str]: """构建S4U风格的已读/未读历史消息prompt""" try: # 动态导入default_generator以避免循环导入 @@ -492,7 +493,7 @@ class Prompt: except Exception as e: logger.error(f"构建S4U历史消息prompt失败: {e}") - async def _build_expression_habits(self) -> Dict[str, Any]: + async def _build_expression_habits(self) -> dict[str, Any]: """构建表达习惯""" use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id) if not use_expression: @@ -533,7 +534,7 @@ class Prompt: logger.error(f"构建表达习惯失败: {e}") return {"expression_habits_block": ""} - async def _build_memory_block(self) -> Dict[str, Any]: + async def _build_memory_block(self) -> dict[str, Any]: """构建记忆块""" if not global_config.memory.enable_memory: return {"memory_block": ""} @@ -653,7 +654,7 @@ class Prompt: logger.error(f"构建记忆块失败: {e}") return {"memory_block": ""} - async def _build_memory_block_fast(self) -> Dict[str, Any]: + async def _build_memory_block_fast(self) -> dict[str, Any]: """快速构建记忆块(简化版本,用于未预构建时的后备方案)""" if not global_config.memory.enable_memory: return {"memory_block": ""} @@ -677,7 +678,7 @@ class Prompt: logger.warning(f"快速构建记忆块失败: {e}") return {"memory_block": ""} - async def _build_relation_info(self) -> Dict[str, Any]: + async def _build_relation_info(self) -> dict[str, Any]: """构建关系信息""" try: relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to) @@ -686,7 +687,7 @@ class Prompt: logger.error(f"构建关系信息失败: {e}") return {"relation_info_block": ""} - async def _build_tool_info(self) -> Dict[str, Any]: + async def _build_tool_info(self) -> dict[str, Any]: """构建工具信息""" if not global_config.tool.enable_tool: return {"tool_info_block": ""} @@ -734,7 +735,7 @@ class Prompt: logger.error(f"构建工具信息失败: {e}") return {"tool_info_block": ""} - async def _build_knowledge_info(self) -> Dict[str, Any]: + async def 
_build_knowledge_info(self) -> dict[str, Any]: """构建知识信息""" if not global_config.lpmm_knowledge.enable: return {"knowledge_prompt": ""} @@ -783,7 +784,7 @@ class Prompt: logger.error(f"构建知识信息失败: {e}") return {"knowledge_prompt": ""} - async def _build_cross_context(self) -> Dict[str, Any]: + async def _build_cross_context(self) -> dict[str, Any]: """构建跨群上下文""" try: cross_context = await Prompt.build_cross_context( @@ -794,7 +795,7 @@ class Prompt: logger.error(f"构建跨群上下文失败: {e}") return {"cross_context_block": ""} - async def _format_with_context(self, context_data: Dict[str, Any]) -> str: + async def _format_with_context(self, context_data: dict[str, Any]) -> str: """使用上下文数据格式化模板""" if self.parameters.prompt_mode == "s4u": params = self._prepare_s4u_params(context_data) @@ -805,7 +806,7 @@ class Prompt: return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params) - def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: + def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]: """准备S4U模式的参数""" return { **context_data, @@ -834,7 +835,7 @@ class Prompt: or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", } - def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: + def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]: """准备Normal模式的参数""" return { **context_data, @@ -862,7 +863,7 @@ class Prompt: or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", } - def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: + def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]: """准备默认模式的参数""" return { "expression_habits_block": context_data.get("expression_habits_block", ""), @@ -905,7 +906,7 @@ class Prompt: result = self._restore_escaped_braces(processed_template) return result except (IndexError, KeyError) as e: - raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e + raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {e!s}") from e def __str__(self) -> str: """返回格式化后的结果或原始模板""" @@ -922,7 +923,7 @@ class Prompt: # ============================================================================= @staticmethod - def parse_reply_target(target_message: str) -> Tuple[str, str]: + def parse_reply_target(target_message: str) -> tuple[str, str]: """ 解析回复目标消息 - 统一实现 @@ -981,7 +982,7 @@ class Prompt: return await relationship_fetcher.build_relation_info(person_id, points_num=5) - def _get_default_result_for_task(self, task_name: str) -> Dict[str, Any]: + def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]: """ 为超时的任务提供默认结果 @@ -1008,7 +1009,7 @@ class Prompt: return {} @staticmethod - async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str: + async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: dict[str, Any] | None) -> str: """ 构建跨群聊上下文 - 统一实现 @@ -1071,7 +1072,7 @@ class Prompt: # 工厂函数 def create_prompt( - template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs + template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs ) -> Prompt: """快速创建Prompt实例的工厂函数""" if parameters is None: @@ -1080,7 +1081,7 @@ def create_prompt( async def create_prompt_async( - template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs + template: str, name: str | None = 
None, parameters: PromptParameters | None = None, **kwargs ) -> Prompt: """异步创建Prompt实例""" prompt = create_prompt(template, name, parameters, **kwargs) diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 1c879a01b..96433d21a 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -1,11 +1,11 @@ import asyncio from collections import defaultdict from datetime import datetime, timedelta -from typing import Any, Dict, Tuple, List +from typing import Any +from src.common.database.sqlalchemy_database_api import db_get, db_query, db_save +from src.common.database.sqlalchemy_models import LLMUsage, Messages, OnlineTime from src.common.logger import get_logger -from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages -from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get from src.manager.async_task_manager import AsyncTask from src.manager.local_store_manager import local_storage @@ -150,7 +150,7 @@ class StatisticOutputTask(AsyncTask): # 延迟300秒启动,运行间隔300秒 super().__init__(task_name="Statistics Data Output Task", wait_before_start=0, run_interval=300) - self.name_mapping: Dict[str, Tuple[str, float]] = {} + self.name_mapping: dict[str, tuple[str, float]] = {} """ 联系人/群聊名称映射 {聊天ID: (联系人/群聊名称, 记录时间(timestamp))} 注:设计记录时间的目的是方便更新名称,使联系人/群聊名称保持最新 @@ -170,7 +170,7 @@ class StatisticOutputTask(AsyncTask): deploy_time = datetime(2000, 1, 1) local_storage["deploy_time"] = now.timestamp() - self.stat_period: List[Tuple[str, timedelta, str]] = [ + self.stat_period: list[tuple[str, timedelta, str]] = [ ("all_time", now - deploy_time, "自部署以来"), # 必须保留"all_time" ("last_7_days", timedelta(days=7), "最近7天"), ("last_24_hours", timedelta(days=1), "最近24小时"), @@ -181,7 +181,7 @@ class StatisticOutputTask(AsyncTask): 统计时间段 [(统计名称, 统计时间段, 统计描述), ...] 
""" - def _statistic_console_output(self, stats: Dict[str, Any], now: datetime): + def _statistic_console_output(self, stats: dict[str, Any], now: datetime): """ 输出统计数据到控制台 :param stats: 统计数据 @@ -239,7 +239,7 @@ class StatisticOutputTask(AsyncTask): # -- 以下为统计数据收集方法 -- @staticmethod - async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: + async def _collect_model_request_for_period(collect_period: list[tuple[str, datetime]]) -> dict[str, Any]: """ 收集指定时间段的LLM请求统计数据 @@ -393,8 +393,8 @@ class StatisticOutputTask(AsyncTask): @staticmethod async def _collect_online_time_for_period( - collect_period: List[Tuple[str, datetime]], now: datetime - ) -> Dict[str, Any]: + collect_period: list[tuple[str, datetime]], now: datetime + ) -> dict[str, Any]: """ 收集指定时间段的在线时间统计数据 @@ -452,7 +452,7 @@ class StatisticOutputTask(AsyncTask): break return stats - async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: + async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]: """ 收集指定时间段的消息统计数据 @@ -523,7 +523,7 @@ class StatisticOutputTask(AsyncTask): break return stats - async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: + async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]: """ 收集各时间段的统计数据 :param now: 基准当前时间 @@ -533,7 +533,7 @@ class StatisticOutputTask(AsyncTask): if "last_full_statistics" in local_storage: # 如果存在上次完整统计数据,则使用该数据进行增量统计 - last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore + last_stat: dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射 last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据 @@ -620,7 +620,7 @@ class StatisticOutputTask(AsyncTask): # -- 以下为统计数据格式化方法 -- @staticmethod - def _format_total_stat(stats: Dict[str, Any]) -> str: + def _format_total_stat(stats: dict[str, Any]) -> str: """ 格式化总统计数据 """ @@ -636,7 +636,7 @@ class StatisticOutputTask(AsyncTask): return "\n".join(output) @staticmethod - def _format_model_classified_stat(stats: Dict[str, Any]) -> str: + def _format_model_classified_stat(stats: dict[str, Any]) -> str: """ 格式化按模型分类的统计数据 """ @@ -662,7 +662,7 @@ class StatisticOutputTask(AsyncTask): output.append("") return "\n".join(output) - def _format_chat_stat(self, stats: Dict[str, Any]) -> str: + def _format_chat_stat(self, stats: dict[str, Any]) -> str: """ 格式化聊天统计数据 """ @@ -1007,7 +1007,7 @@ class StatisticOutputTask(AsyncTask): async def _generate_chart_data(self, stat: dict[str, Any]) -> dict: """生成图表数据 (异步)""" now = datetime.now() - chart_data: Dict[str, Any] = {} + chart_data: dict[str, Any] = {} time_ranges = [ ("6h", 6, 10), @@ -1023,16 +1023,16 @@ class StatisticOutputTask(AsyncTask): async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: start_time = now - timedelta(hours=hours) - time_points: List[datetime] = [] + time_points: list[datetime] = [] current_time = start_time while current_time <= now: time_points.append(current_time) current_time += timedelta(minutes=interval_minutes) total_cost_data = [0.0] * len(time_points) - cost_by_model: Dict[str, List[float]] = {} - cost_by_module: Dict[str, List[float]] = {} - message_by_chat: Dict[str, List[int]] = {} + cost_by_model: dict[str, list[float]] = {} + cost_by_module: dict[str, list[float]] = {} + 
message_by_chat: dict[str, list[int]] = {} time_labels = [t.strftime("%H:%M") for t in time_points] interval_seconds = interval_minutes * 60 diff --git a/src/chat/utils/timer_calculator.py b/src/chat/utils/timer_calculator.py index d9479af16..acdadc956 100644 --- a/src/chat/utils/timer_calculator.py +++ b/src/chat/utils/timer_calculator.py @@ -1,8 +1,8 @@ import asyncio - -from time import perf_counter +from collections.abc import Callable from functools import wraps -from typing import Optional, Dict, Callable +from time import perf_counter + from rich.traceback import install install(extra_lines=3) @@ -75,12 +75,12 @@ class Timer: 3. 直接实例化:如果不调用 __enter__,打印对象时将显示当前 perf_counter 的值 """ - __slots__ = ("name", "storage", "elapsed", "auto_unit", "start") + __slots__ = ("auto_unit", "elapsed", "name", "start", "storage") def __init__( self, - name: Optional[str] = None, - storage: Optional[Dict[str, float]] = None, + name: str | None = None, + storage: dict[str, float] | None = None, auto_unit: bool = True, do_type_check: bool = False, ): @@ -103,7 +103,7 @@ class Timer: if storage is not None and not isinstance(storage, dict): raise TimerTypeError("storage", "Optional[dict]", type(storage)) - def __call__(self, func: Optional[Callable] = None) -> Callable: + def __call__(self, func: Callable | None = None) -> Callable: """装饰器模式""" if func is None: return lambda f: Timer(name=self.name or f.__name__, storage=self.storage, auto_unit=self.auto_unit)(f) diff --git a/src/chat/utils/typo_generator.py b/src/chat/utils/typo_generator.py index 9c3718b2b..1852679a3 100644 --- a/src/chat/utils/typo_generator.py +++ b/src/chat/utils/typo_generator.py @@ -2,15 +2,15 @@ 错别字生成器 - 基于拼音和字频的中文错别字生成工具 """ -import orjson import math import os import random import time -import jieba - from collections import defaultdict from pathlib import Path + +import jieba +import orjson from pypinyin import Style, pinyin from src.common.logger import get_logger @@ -51,7 +51,7 @@ class ChineseTypoGenerator: # 如果缓存文件存在,直接加载 if cache_file.exists(): - with open(cache_file, "r", encoding="utf-8") as f: + with open(cache_file, encoding="utf-8") as f: return orjson.loads(f.read()) # 使用内置的词频文件 @@ -59,7 +59,7 @@ class ChineseTypoGenerator: dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt") # 读取jieba的词典文件 - with open(dict_path, "r", encoding="utf-8") as f: + with open(dict_path, encoding="utf-8") as f: for line in f: word, freq = line.strip().split()[:2] # 对词中的每个字进行频率累加 @@ -254,7 +254,7 @@ class ChineseTypoGenerator: # 获取jieba词典和词频信息 dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt") valid_words = {} # 改用字典存储词语及其频率 - with open(dict_path, "r", encoding="utf-8") as f: + with open(dict_path, encoding="utf-8") as f: for line in f: parts = line.strip().split() if len(parts) >= 2: diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index ea3bdc89f..8659b3539 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -3,20 +3,21 @@ import random import re import string import time +from collections import Counter +from typing import Any + import jieba import numpy as np - -from collections import Counter from maim_message import UserInfo -from typing import Optional, Tuple, Dict, List, Any -from src.common.logger import get_logger -from src.common.message_repository import find_messages, count_messages -from src.config.config import global_config, model_config -from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.chat_stream import 
get_chat_manager +from src.chat.message_receive.message import MessageRecv +from src.common.logger import get_logger +from src.common.message_repository import count_messages, find_messages +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import PersonInfoManager, get_person_info_manager + from .typo_generator import ChineseTypoGenerator logger = get_logger("chat_utils") @@ -86,9 +87,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: if not is_mentioned: # 判断是否被回复 if re.match( - rf"\[回复 (.+?)\({str(global_config.bot.qq_account)}\):(.+?)\],说:", message.processed_plain_text + rf"\[回复 (.+?)\({global_config.bot.qq_account!s}\):(.+?)\],说:", message.processed_plain_text ) or re.match( - rf"\[回复<(.+?)(?=:{str(global_config.bot.qq_account)}>)\:{str(global_config.bot.qq_account)}>:(.+?)\],说:", + rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>:(.+?)\],说:", message.processed_plain_text, ): is_mentioned = True @@ -110,14 +111,14 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: return is_mentioned, reply_probability -async def get_embedding(text, request_type="embedding") -> Optional[List[float]]: +async def get_embedding(text, request_type="embedding") -> list[float] | None: """获取文本的embedding向量""" # 每次都创建新的LLMRequest实例以避免事件循环冲突 llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type) try: embedding, _ = await llm.get_embedding(text) except Exception as e: - logger.error(f"获取embedding失败: {str(e)}") + logger.error(f"获取embedding失败: {e!s}") embedding = None return embedding @@ -621,7 +622,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal" return time.strftime("%H:%M:%S", time.localtime(timestamp)) -def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: +def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None]: """ 获取聊天类型(是否群聊)和私聊对象信息。 @@ -670,7 +671,6 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: if loop.is_running(): # 如果事件循环在运行,从其他线程提交并等待结果 try: - fut = asyncio.run_coroutine_threadsafe( person_info_manager.get_value(person_id, "person_name"), loop ) @@ -706,7 +706,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: return is_group_chat, chat_target_info -def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]: +def assign_message_ids(messages: list[Any]) -> list[dict[str, Any]]: """ 为消息列表中的每个消息分配唯一的简短随机ID diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index ab0915842..29a918d87 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -1,29 +1,27 @@ import base64 +import hashlib +import io import os import time -import hashlib import uuid -import io -import numpy as np +from typing import Any -from typing import Optional, Tuple, Dict, Any +import numpy as np from PIL import Image from rich.traceback import install +from sqlalchemy import and_, select +from src.common.database.sqlalchemy_models import ImageDescriptions, Images, get_db_session from src.common.logger import get_logger -from src.common.database.sqlalchemy_models import Images, ImageDescriptions from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.common.database.sqlalchemy_models import get_db_session - -from sqlalchemy import select, and_ 
install(extra_lines=3) logger = get_logger("chat_image") -def is_image_message(message: Dict[str, Any]) -> bool: +def is_image_message(message: dict[str, Any]) -> bool: """ 判断消息是否为图片消息 @@ -69,7 +67,7 @@ class ImageManager: os.makedirs(self.IMAGE_DIR, exist_ok=True) @staticmethod - async def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: + async def _get_description_from_db(image_hash: str, description_type: str) -> str | None: """从数据库获取图片描述 Args: @@ -93,7 +91,7 @@ class ImageManager: ).scalar() return record.description if record else None except Exception as e: - logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}") + logger.error(f"从数据库获取描述失败 (SQLAlchemy): {e!s}") return None @staticmethod @@ -136,7 +134,7 @@ class ImageManager: await session.commit() # 会在上下文管理器中自动调用 except Exception as e: - logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}") + logger.error(f"保存描述到数据库失败 (SQLAlchemy): {e!s}") @staticmethod async def get_emoji_tag(image_base64: str) -> str: @@ -287,10 +285,10 @@ class ImageManager: session.add(new_img) await session.commit() except Exception as e: - logger.error(f"保存到Images表失败: {str(e)}") + logger.error(f"保存到Images表失败: {e!s}") except Exception as e: - logger.error(f"保存表情包文件或元数据失败: {str(e)}") + logger.error(f"保存表情包文件或元数据失败: {e!s}") else: logger.debug("偷取表情包功能已关闭,跳过保存。") @@ -300,7 +298,7 @@ class ImageManager: return f"[表情包:{final_emotion}]" except Exception as e: - logger.error(f"获取表情包描述失败: {str(e)}") + logger.error(f"获取表情包描述失败: {e!s}") return "[表情包(处理失败)]" async def get_image_description(self, image_base64: str) -> str: @@ -391,11 +389,11 @@ class ImageManager: logger.info(f"[VLM完成] 图片描述生成: {description}...") return f"[图片:{description}]" except Exception as e: - logger.error(f"获取图片描述失败: {str(e)}") + logger.error(f"获取图片描述失败: {e!s}") return "[图片(处理失败)]" @staticmethod - def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]: + def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> str | None: # sourcery skip: use-contextlib-suppress """将GIF转换为水平拼接的静态图像, 跳过相似的帧 @@ -512,10 +510,10 @@ class ImageManager: logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多") return None # 内存不够啦 except Exception as e: - logger.error(f"GIF转换失败: {str(e)}", exc_info=True) # 记录详细错误信息 + logger.error(f"GIF转换失败: {e!s}", exc_info=True) # 记录详细错误信息 return None # 其他错误也返回None - async def process_image(self, image_base64: str) -> Tuple[str, str]: + async def process_image(self, image_base64: str) -> tuple[str, str]: # sourcery skip: hoist-if-from-if """处理图片并返回图片ID和描述 @@ -604,7 +602,7 @@ class ImageManager: return image_id, f"[picid:{image_id}]" except Exception as e: - logger.error(f"处理图片失败: {str(e)}") + logger.error(f"处理图片失败: {e!s}") return "", "[图片]" @@ -637,4 +635,4 @@ def image_path_to_base64(image_path: str) -> str: if image_data := f.read(): return base64.b64encode(image_data).decode("utf-8") else: - raise IOError(f"读取图片文件失败: {image_path}") + raise OSError(f"读取图片文件失败: {image_path}") diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 19ec72cb6..6a6fc6245 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -1,29 +1,28 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ 视频分析器模块 - Rust优化版本 集成了Rust视频关键帧提取模块,提供高性能的视频分析功能 支持SIMD优化、多线程处理和智能关键帧检测 """ -import os -import tempfile import asyncio import base64 import hashlib +import io +import os +import tempfile import time +from pathlib import Path + import numpy as np from PIL 
import Image -from pathlib import Path -from typing import List, Tuple, Optional, Dict -import io - -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger -from src.common.database.sqlalchemy_models import get_db_session, Videos from sqlalchemy import select +from src.common.database.sqlalchemy_models import Videos, get_db_session +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest + logger = get_logger("utils_video") # Rust模块可用性检测 @@ -203,7 +202,7 @@ class VideoAnalyzer: hash_obj.update(video_data) return hash_obj.hexdigest() - async def _check_video_exists(self, video_hash: str) -> Optional[Videos]: + async def _check_video_exists(self, video_hash: str) -> Videos | None: """检查视频是否已经分析过""" try: async with get_db_session() as session: @@ -220,8 +219,8 @@ class VideoAnalyzer: return None async def _store_video_result( - self, video_hash: str, description: str, metadata: Optional[Dict] = None - ) -> Optional[Videos]: + self, video_hash: str, description: str, metadata: dict | None = None + ) -> Videos | None: """存储视频分析结果到数据库""" # 检查描述是否为错误信息,如果是则不保存 if description.startswith("❌"): @@ -281,7 +280,7 @@ class VideoAnalyzer: else: logger.warning(f"无效的分析模式: {mode}") - async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]: + async def extract_frames(self, video_path: str) -> list[tuple[str, float]]: """提取视频帧 - 智能选择最佳实现""" # 检查是否应该使用Rust实现 if RUST_VIDEO_AVAILABLE and self.frame_extraction_mode == "keyframe": @@ -303,7 +302,7 @@ class VideoAnalyzer: logger.info(f"🔄 抽帧模式为 {self.frame_extraction_mode},使用Python抽帧实现") return await self._extract_frames_python_fallback(video_path) - async def _extract_frames_rust_advanced(self, video_path: str) -> List[Tuple[str, float]]: + async def _extract_frames_rust_advanced(self, video_path: str) -> list[tuple[str, float]]: """使用 Rust 高级接口的帧提取""" try: logger.info("🔄 使用 Rust 高级接口提取关键帧...") @@ -387,7 +386,7 @@ class VideoAnalyzer: logger.info("回退到基础 Rust 方法") return await self._extract_frames_rust(video_path) - async def _extract_frames_rust(self, video_path: str) -> List[Tuple[str, float]]: + async def _extract_frames_rust(self, video_path: str) -> list[tuple[str, float]]: """使用 Rust 实现的帧提取""" try: logger.info("🔄 使用 Rust 模块提取关键帧...") @@ -463,7 +462,7 @@ class VideoAnalyzer: logger.error(f"❌ Rust 帧提取失败: {e}") raise e - async def _extract_frames_python_fallback(self, video_path: str) -> List[Tuple[str, float]]: + async def _extract_frames_python_fallback(self, video_path: str) -> list[tuple[str, float]]: """Python降级抽帧实现 - 支持多种抽帧模式""" try: # 导入旧版本分析器 @@ -490,7 +489,7 @@ class VideoAnalyzer: logger.error(f"❌ Python降级抽帧失败: {e}") return [] - async def analyze_frames_batch(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: + async def analyze_frames_batch(self, frames: list[tuple[str, float]], user_question: str = None) -> str: """批量分析所有帧""" logger.info(f"开始批量分析{len(frames)}帧") @@ -526,7 +525,7 @@ class VideoAnalyzer: logger.error(f"❌ 视频识别失败: {e}") raise e - async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str: + async def _analyze_multiple_frames(self, frames: list[tuple[str, float]], prompt: str) -> str: """使用多图片分析方法""" logger.info(f"开始构建包含{len(frames)}帧的分析请求") @@ -566,7 +565,7 @@ class VideoAnalyzer: logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ") return api_response.content or 
"❌ 未获得响应内容" - async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: + async def analyze_frames_sequential(self, frames: list[tuple[str, float]], user_question: str = None) -> str: """逐帧分析并汇总""" logger.info(f"开始逐帧分析{len(frames)}帧") @@ -624,7 +623,7 @@ class VideoAnalyzer: # 如果汇总失败,返回各帧分析结果 return f"视频逐帧分析结果:\n\n{chr(10).join(frame_analyses)}" - async def analyze_video(self, video_path: str, user_question: str = None) -> Tuple[bool, str]: + async def analyze_video(self, video_path: str, user_question: str = None) -> tuple[bool, str]: """分析视频的主要方法 Returns: @@ -662,13 +661,13 @@ class VideoAnalyzer: return (True, result) except Exception as e: - error_msg = f"❌ 视频分析失败: {str(e)}" + error_msg = f"❌ 视频分析失败: {e!s}" logger.error(error_msg) return (False, error_msg) async def analyze_video_from_bytes( self, video_bytes: bytes, filename: str = None, user_question: str = None, prompt: str = None - ) -> Dict[str, str]: + ) -> dict[str, str]: """从字节数据分析视频 Args: @@ -778,7 +777,7 @@ class VideoAnalyzer: return {"summary": result} except Exception as e: - error_msg = f"❌ 从字节数据分析视频失败: {str(e)}" + error_msg = f"❌ 从字节数据分析视频失败: {e!s}" logger.error(error_msg) # 不保存错误信息到数据库,允许后续重试 @@ -802,7 +801,7 @@ class VideoAnalyzer: supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"} return Path(file_path).suffix.lower() in supported_formats - def get_processing_capabilities(self) -> Dict[str, any]: + def get_processing_capabilities(self) -> dict[str, any]: """获取处理能力信息""" if not RUST_VIDEO_AVAILABLE: return {"error": "Rust视频处理模块不可用", "available": False, "reason": "rust_video模块未安装或加载失败"} @@ -832,7 +831,7 @@ class VideoAnalyzer: logger.error(f"获取处理能力信息失败: {e}") return {"error": str(e), "available": False} - def _get_recommended_settings(self, cpu_features: Dict[str, bool]) -> Dict[str, any]: + def _get_recommended_settings(self, cpu_features: dict[str, bool]) -> dict[str, any]: """根据CPU特性推荐最佳设置""" settings = { "use_simd": any(cpu_features.values()), @@ -882,7 +881,7 @@ def is_video_analysis_available() -> bool: return False -def get_video_analysis_status() -> Dict[str, any]: +def get_video_analysis_status() -> dict[str, any]: """获取视频分析功能的详细状态信息 Returns: diff --git a/src/chat/utils/utils_video_legacy.py b/src/chat/utils/utils_video_legacy.py index 77ca88142..46eb13857 100644 --- a/src/chat/utils/utils_video_legacy.py +++ b/src/chat/utils/utils_video_legacy.py @@ -1,25 +1,25 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ 视频分析器模块 - 旧版本兼容模块 支持多种分析模式:批处理、逐帧、自动选择 包含Python原生的抽帧功能,作为Rust模块的降级方案 """ -import os -import cv2 import asyncio import base64 +import io +import os +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Any + +import cv2 import numpy as np from PIL import Image -from pathlib import Path -from typing import List, Tuple, Optional, Any -import io -from concurrent.futures import ThreadPoolExecutor -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("utils_video_legacy") @@ -30,7 +30,7 @@ def _extract_frames_worker( frame_quality: int, max_image_size: int, frame_extraction_mode: str, - frame_interval_seconds: Optional[float], + frame_interval_seconds: float | None, ) -> list[Any] | list[tuple[str, str]]: """线程池中提取视频帧的工作函数""" frames = [] @@ -221,7 
+221,7 @@ class LegacyVideoAnalyzer: f"✅ 旧版本视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}" ) - async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]: + async def extract_frames(self, video_path: str) -> list[tuple[str, float]]: """提取视频帧 - 支持多进程和单线程模式""" # 先获取视频信息 cap = cv2.VideoCapture(video_path) @@ -247,7 +247,7 @@ class LegacyVideoAnalyzer: else: return await self._extract_frames_fallback(video_path) - async def _extract_frames_multiprocess(self, video_path: str) -> List[Tuple[str, float]]: + async def _extract_frames_multiprocess(self, video_path: str) -> list[tuple[str, float]]: """线程池版本的帧提取""" loop = asyncio.get_event_loop() @@ -282,7 +282,7 @@ class LegacyVideoAnalyzer: logger.info("🔄 降级到单线程模式...") return await self._extract_frames_fallback(video_path) - async def _extract_frames_fallback(self, video_path: str) -> List[Tuple[str, float]]: + async def _extract_frames_fallback(self, video_path: str) -> list[tuple[str, float]]: """帧提取的降级方法 - 原始异步版本""" frames = [] extracted_count = 0 @@ -389,7 +389,7 @@ class LegacyVideoAnalyzer: logger.info(f"✅ 成功提取{len(frames)}帧") return frames - async def analyze_frames_batch(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: + async def analyze_frames_batch(self, frames: list[tuple[str, float]], user_question: str = None) -> str: """批量分析所有帧""" logger.info(f"开始批量分析{len(frames)}帧") @@ -441,7 +441,7 @@ class LegacyVideoAnalyzer: logger.error(f"❌ 降级分析也失败: {fallback_e}") raise - async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str: + async def _analyze_multiple_frames(self, frames: list[tuple[str, float]], prompt: str) -> str: """使用多图片分析方法""" logger.info(f"开始构建包含{len(frames)}帧的分析请求") @@ -481,7 +481,7 @@ class LegacyVideoAnalyzer: logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ") return api_response.content or "❌ 未获得响应内容" - async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: + async def analyze_frames_sequential(self, frames: list[tuple[str, float]], user_question: str = None) -> str: """逐帧分析并汇总""" logger.info(f"开始逐帧分析{len(frames)}帧") @@ -567,7 +567,7 @@ class LegacyVideoAnalyzer: return result except Exception as e: - error_msg = f"❌ 视频分析失败: {str(e)}" + error_msg = f"❌ 视频分析失败: {e!s}" logger.error(error_msg) return error_msg diff --git a/src/chat/utils/utils_voice.py b/src/chat/utils/utils_voice.py index 49ec10794..eae96e5f3 100644 --- a/src/chat/utils/utils_voice.py +++ b/src/chat/utils/utils_voice.py @@ -1,8 +1,8 @@ -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from rich.traceback import install from src.common.logger import get_logger -from rich.traceback import install +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest install(extra_lines=3) @@ -25,5 +25,5 @@ async def get_voice_text(voice_base64: str) -> str: return f"[语音:{text}]" except Exception as e: - logger.error(f"语音转文字失败: {str(e)}") + logger.error(f"语音转文字失败: {e!s}") return "[语音]" diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index c77d9e8bd..9afb70dcc 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -1,17 +1,19 @@ -import time -import orjson import hashlib +import time from pathlib import Path -import numpy as np +from typing import Any + import faiss -from typing import Any, Dict, Optional, Union -from src.common.logger import get_logger 
-from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config +import numpy as np +import orjson + from src.common.config_helpers import resolve_embedding_dimension -from src.common.database.sqlalchemy_models import CacheEntries from src.common.database.sqlalchemy_database_api import db_query, db_save +from src.common.database.sqlalchemy_models import CacheEntries +from src.common.logger import get_logger from src.common.vector_db import vector_db_service +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("cache_manager") @@ -40,14 +42,14 @@ class CacheManager: self.semantic_cache_collection_name = "semantic_cache" # L1 缓存 (内存) - self.l1_kv_cache: Dict[str, Dict[str, Any]] = {} + self.l1_kv_cache: dict[str, dict[str, Any]] = {} embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension) if not embedding_dim: embedding_dim = global_config.lpmm_knowledge.embedding_dimension self.embedding_dimension = embedding_dim self.l1_vector_index = faiss.IndexFlatIP(embedding_dim) - self.l1_vector_id_to_key: Dict[int, str] = {} + self.l1_vector_id_to_key: dict[int, str] = {} # L2 向量缓存 (使用新的服务) vector_db_service.get_or_create_collection(self.semantic_cache_collection_name) @@ -59,7 +61,7 @@ class CacheManager: logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)") @staticmethod - def _validate_embedding(embedding_result: Any) -> Optional[np.ndarray]: + def _validate_embedding(embedding_result: Any) -> np.ndarray | None: """ 验证和标准化嵌入向量格式 """ @@ -100,7 +102,7 @@ class CacheManager: return None @staticmethod - def _generate_key(tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path]) -> str: + def _generate_key(tool_name: str, function_args: dict[str, Any], tool_file_path: str | Path) -> str: """生成确定性的缓存键,包含文件修改时间以实现自动失效。""" try: tool_file_path = Path(tool_file_path) @@ -124,10 +126,10 @@ class CacheManager: async def get( self, tool_name: str, - function_args: Dict[str, Any], - tool_file_path: Union[str, Path], - semantic_query: Optional[str] = None, - ) -> Optional[Any]: + function_args: dict[str, Any], + tool_file_path: str | Path, + semantic_query: str | None = None, + ) -> Any | None: """ 从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。 """ @@ -251,11 +253,11 @@ class CacheManager: async def set( self, tool_name: str, - function_args: Dict[str, Any], - tool_file_path: Union[str, Path], + function_args: dict[str, Any], + tool_file_path: str | Path, data: Any, - ttl: Optional[int] = None, - semantic_query: Optional[str] = None, + ttl: int | None = None, + semantic_query: str | None = None, ): """将结果存入所有缓存层。""" if ttl is None: diff --git a/src/common/config_helpers.py b/src/common/config_helpers.py index 5a2134fe1..f5460fece 100644 --- a/src/common/config_helpers.py +++ b/src/common/config_helpers.py @@ -1,11 +1,9 @@ from __future__ import annotations -from typing import Optional - from src.config.config import global_config, model_config -def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: bool = True) -> Optional[int]: +def resolve_embedding_dimension(fallback: int | None = None, *, sync_global: bool = True) -> int | None: """获取当前配置的嵌入向量维度。 优先顺序: @@ -14,7 +12,7 @@ def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: 3. 
调用方提供的 fallback """ - candidates: list[Optional[int]] = [] + candidates: list[int | None] = [] try: embedding_task = getattr(model_config.model_task_config, "embedding", None) @@ -30,7 +28,7 @@ def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: candidates.append(fallback) - resolved: Optional[int] = next((int(dim) for dim in candidates if dim and int(dim) > 0), None) + resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None) if resolved and sync_global: try: diff --git a/src/common/data_models/bot_interest_data_model.py b/src/common/data_models/bot_interest_data_model.py index 819b50a8f..fe152ca2e 100644 --- a/src/common/data_models/bot_interest_data_model.py +++ b/src/common/data_models/bot_interest_data_model.py @@ -4,8 +4,8 @@ """ from dataclasses import dataclass, field -from typing import List, Dict, Optional, Any from datetime import datetime +from typing import Any from . import BaseDataModel @@ -16,12 +16,12 @@ class BotInterestTag(BaseDataModel): tag_name: str weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0) - embedding: Optional[List[float]] = None # 标签的embedding向量 + embedding: list[float] | None = None # 标签的embedding向量 created_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now) is_active: bool = True - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典格式""" return { "tag_name": self.tag_name, @@ -33,7 +33,7 @@ class BotInterestTag(BaseDataModel): } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag": + def from_dict(cls, data: dict[str, Any]) -> "BotInterestTag": """从字典创建对象""" return cls( tag_name=data["tag_name"], @@ -51,16 +51,16 @@ class BotPersonalityInterests(BaseDataModel): personality_id: str personality_description: str # 人设描述文本 - interest_tags: List[BotInterestTag] = field(default_factory=list) + interest_tags: list[BotInterestTag] = field(default_factory=list) embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型 last_updated: datetime = field(default_factory=datetime.now) version: int = 1 # 版本号,用于追踪更新 - def get_active_tags(self) -> List[BotInterestTag]: + def get_active_tags(self) -> list[BotInterestTag]: """获取活跃的兴趣标签""" return [tag for tag in self.interest_tags if tag.is_active] - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典格式""" return { "personality_id": self.personality_id, @@ -72,7 +72,7 @@ class BotPersonalityInterests(BaseDataModel): } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests": + def from_dict(cls, data: dict[str, Any]) -> "BotPersonalityInterests": """从字典创建对象""" return cls( personality_id=data["personality_id"], @@ -89,14 +89,14 @@ class InterestMatchResult(BaseDataModel): """兴趣匹配结果""" message_id: str - matched_tags: List[str] = field(default_factory=list) - match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score + matched_tags: list[str] = field(default_factory=list) + match_scores: dict[str, float] = field(default_factory=dict) # tag_name -> score overall_score: float = 0.0 - top_tag: Optional[str] = None + top_tag: str | None = None confidence: float = 0.0 # 匹配置信度 (0.0-1.0) - matched_keywords: List[str] = field(default_factory=list) + matched_keywords: list[str] = field(default_factory=list) - def add_match(self, tag_name: str, score: float, keywords: List[str] = None): + def add_match(self, tag_name: str, score: float, keywords: list[str] = None): """添加匹配结果""" 
self.matched_tags.append(tag_name) self.match_scores[tag_name] = score @@ -131,7 +131,7 @@ class InterestMatchResult(BaseDataModel): else: self.confidence = 0.0 - def get_top_matches(self, top_n: int = 3) -> List[tuple]: + def get_top_matches(self, top_n: int = 3) -> list[tuple]: """获取前N个最佳匹配""" sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True) return sorted_matches[:top_n] diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 4578d1481..f1bc0ef67 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -1,6 +1,6 @@ import json -from typing import Optional, Any, Dict from dataclasses import dataclass, field +from typing import Any from . import BaseDataModel @@ -10,7 +10,7 @@ class DatabaseUserInfo(BaseDataModel): platform: str = field(default_factory=str) user_id: str = field(default_factory=str) user_nickname: str = field(default_factory=str) - user_cardname: Optional[str] = None + user_cardname: str | None = None # def __post_init__(self): # assert isinstance(self.platform, str), "platform must be a string" @@ -25,7 +25,7 @@ class DatabaseUserInfo(BaseDataModel): class DatabaseGroupInfo(BaseDataModel): group_id: str = field(default_factory=str) group_name: str = field(default_factory=str) - group_platform: Optional[str] = None + group_platform: str | None = None # def __post_init__(self): # assert isinstance(self.group_id, str), "group_id must be a string" @@ -42,7 +42,7 @@ class DatabaseChatInfo(BaseDataModel): create_time: float = field(default_factory=float) last_active_time: float = field(default_factory=float) user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo) - group_info: Optional[DatabaseGroupInfo] = None + group_info: DatabaseGroupInfo | None = None # def __post_init__(self): # assert isinstance(self.stream_id, str), "stream_id must be a string" @@ -62,41 +62,41 @@ class DatabaseMessages(BaseDataModel): message_id: str = "", time: float = 0.0, chat_id: str = "", - reply_to: Optional[str] = None, - interest_value: Optional[float] = None, - key_words: Optional[str] = None, - key_words_lite: Optional[str] = None, - is_mentioned: Optional[bool] = None, - is_at: Optional[bool] = None, - reply_probability_boost: Optional[float] = None, - processed_plain_text: Optional[str] = None, - display_message: Optional[str] = None, - priority_mode: Optional[str] = None, - priority_info: Optional[str] = None, - additional_config: Optional[str] = None, + reply_to: str | None = None, + interest_value: float | None = None, + key_words: str | None = None, + key_words_lite: str | None = None, + is_mentioned: bool | None = None, + is_at: bool | None = None, + reply_probability_boost: float | None = None, + processed_plain_text: str | None = None, + display_message: str | None = None, + priority_mode: str | None = None, + priority_info: str | None = None, + additional_config: str | None = None, is_emoji: bool = False, is_picid: bool = False, is_command: bool = False, is_notify: bool = False, - selected_expressions: Optional[str] = None, + selected_expressions: str | None = None, is_read: bool = False, user_id: str = "", user_nickname: str = "", - user_cardname: Optional[str] = None, + user_cardname: str | None = None, user_platform: str = "", - chat_info_group_id: Optional[str] = None, - chat_info_group_name: Optional[str] = None, - chat_info_group_platform: Optional[str] = None, + chat_info_group_id: str | None = None, + 
chat_info_group_name: str | None = None, + chat_info_group_platform: str | None = None, chat_info_user_id: str = "", chat_info_user_nickname: str = "", - chat_info_user_cardname: Optional[str] = None, + chat_info_user_cardname: str | None = None, chat_info_user_platform: str = "", chat_info_stream_id: str = "", chat_info_platform: str = "", chat_info_create_time: float = 0.0, chat_info_last_active_time: float = 0.0, # 新增字段 - actions: Optional[list] = None, + actions: list | None = None, should_reply: bool = False, **kwargs: Any, ): @@ -132,7 +132,7 @@ class DatabaseMessages(BaseDataModel): self.selected_expressions = selected_expressions self.is_read = is_read - self.group_info: Optional[DatabaseGroupInfo] = None + self.group_info: DatabaseGroupInfo | None = None self.user_info = DatabaseUserInfo( user_id=user_id, user_nickname=user_nickname, @@ -172,7 +172,7 @@ class DatabaseMessages(BaseDataModel): # assert isinstance(self.interest_value, float) or self.interest_value is None, ( # "interest_value must be a float or None" # ) - def flatten(self) -> Dict[str, Any]: + def flatten(self) -> dict[str, Any]: """ 将消息数据模型转换为字典格式,便于存储或传输 """ @@ -255,7 +255,7 @@ class DatabaseMessages(BaseDataModel): """ return self.actions or [] - def get_message_summary(self) -> Dict[str, Any]: + def get_message_summary(self) -> dict[str, Any]: """ 获取消息摘要信息 diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index ba45ab3c4..e9ed04162 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -1,30 +1,32 @@ from dataclasses import dataclass, field -from typing import Optional, Dict, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from src.plugin_system.base.component_types import ChatType + from . 
import BaseDataModel if TYPE_CHECKING: - from .database_data_model import DatabaseMessages from src.plugin_system.base.component_types import ActionInfo, ChatMode + from .database_data_model import DatabaseMessages + @dataclass class TargetPersonInfo(BaseDataModel): platform: str = field(default_factory=str) user_id: str = field(default_factory=str) user_nickname: str = field(default_factory=str) - person_id: Optional[str] = None - person_name: Optional[str] = None + person_id: str | None = None + person_name: str | None = None @dataclass class ActionPlannerInfo(BaseDataModel): action_type: str = field(default_factory=str) - reasoning: Optional[str] = None - action_data: Optional[Dict] = None + reasoning: str | None = None + action_data: dict | None = None action_message: Optional["DatabaseMessages"] = None - available_actions: Optional[Dict[str, "ActionInfo"]] = None + available_actions: dict[str, "ActionInfo"] | None = None @dataclass @@ -36,7 +38,7 @@ class InterestScore(BaseDataModel): interest_match_score: float relationship_score: float mentioned_score: float - details: Dict[str, str] + details: dict[str, str] @dataclass @@ -50,10 +52,10 @@ class Plan(BaseDataModel): chat_type: "ChatType" # Generator 填充 - available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict) - chat_history: List["DatabaseMessages"] = field(default_factory=list) - target_info: Optional[TargetPersonInfo] = None + available_actions: dict[str, "ActionInfo"] = field(default_factory=dict) + chat_history: list["DatabaseMessages"] = field(default_factory=list) + target_info: TargetPersonInfo | None = None # Filter 填充 - llm_prompt: Optional[str] = None - decided_actions: Optional[List[ActionPlannerInfo]] = None + llm_prompt: str | None = None + decided_actions: list[ActionPlannerInfo] | None = None diff --git a/src/common/data_models/llm_data_model.py b/src/common/data_models/llm_data_model.py index a59b65391..147c2b22b 100644 --- a/src/common/data_models/llm_data_model.py +++ b/src/common/data_models/llm_data_model.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional, List, Tuple, TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any from . import BaseDataModel @@ -9,10 +9,10 @@ if TYPE_CHECKING: @dataclass class LLMGenerationDataModel(BaseDataModel): - content: Optional[str] = None - reasoning: Optional[str] = None - model: Optional[str] = None - tool_calls: Optional[List["ToolCall"]] = None - prompt: Optional[str] = None - selected_expressions: Optional[List[int]] = None - reply_set: Optional[List[Tuple[str, Any]]] = None + content: str | None = None + reasoning: str | None = None + model: str | None = None + tool_calls: list["ToolCall"] | None = None + prompt: str | None = None + selected_expressions: list[int] | None = None + reply_set: list[tuple[str, Any]] | None = None diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index a72b7564c..b836101cc 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -7,11 +7,12 @@ import asyncio import time from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional + +from src.common.logger import get_logger +from src.plugin_system.base.component_types import ChatMode, ChatType from . 
import BaseDataModel -from src.plugin_system.base.component_types import ChatMode, ChatType -from src.common.logger import get_logger if TYPE_CHECKING: from .database_data_model import DatabaseMessages @@ -34,11 +35,11 @@ class StreamContext(BaseDataModel): stream_id: str chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊 chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式 - unread_messages: List["DatabaseMessages"] = field(default_factory=list) - history_messages: List["DatabaseMessages"] = field(default_factory=list) + unread_messages: list["DatabaseMessages"] = field(default_factory=list) + history_messages: list["DatabaseMessages"] = field(default_factory=list) last_check_time: float = field(default_factory=time.time) is_active: bool = True - processing_task: Optional[asyncio.Task] = None + processing_task: asyncio.Task | None = None interruption_count: int = 0 # 打断计数器 last_interruption_time: float = 0.0 # 上次打断时间 afc_threshold_adjustment: float = 0.0 # afc阈值调整量 @@ -49,8 +50,8 @@ class StreamContext(BaseDataModel): # 新增字段以替代ChatMessageContext功能 current_message: Optional["DatabaseMessages"] = None - priority_mode: Optional[str] = None - priority_info: Optional[dict] = None + priority_mode: str | None = None + priority_info: dict | None = None def add_message(self, message: "DatabaseMessages"): """添加消息到上下文""" @@ -150,11 +151,11 @@ class StreamContext(BaseDataModel): self.unread_messages.remove(msg) break - def get_unread_messages(self) -> List["DatabaseMessages"]: + def get_unread_messages(self) -> list["DatabaseMessages"]: """获取未读消息""" return [msg for msg in self.unread_messages if not msg.is_read] - def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]: + def get_history_messages(self, limit: int = 20) -> list["DatabaseMessages"]: """获取历史消息""" # 优先返回最近的历史消息和所有未读消息 recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages @@ -230,7 +231,7 @@ class StreamContext(BaseDataModel): """设置当前消息""" self.current_message = message - def get_template_name(self) -> Optional[str]: + def get_template_name(self) -> str | None: """获取模板名称""" if ( self.current_message @@ -336,11 +337,11 @@ class StreamContext(BaseDataModel): return False return True - def get_priority_mode(self) -> Optional[str]: + def get_priority_mode(self) -> str | None: """获取优先级模式""" return self.priority_mode - def get_priority_info(self) -> Optional[dict]: + def get_priority_info(self) -> dict | None: """获取优先级信息""" return self.priority_info diff --git a/src/common/database/database.py b/src/common/database/database.py index 92c851edb..63f632aa5 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -1,10 +1,11 @@ import os + from rich.traceback import install -from src.common.logger import get_logger # SQLAlchemy相关导入 from src.common.database.sqlalchemy_init import initialize_database_compat -from src.common.database.sqlalchemy_models import get_engine, get_db_session +from src.common.database.sqlalchemy_models import get_db_session, get_engine +from src.common.logger import get_logger install(extra_lines=3) diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 330846983..38c972236 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -6,31 +6,31 @@ import time import traceback -from typing import Dict, List, Any, Union, Optional +from typing import Any -from sqlalchemy import desc, asc, func, and_, select 
+from sqlalchemy import and_, asc, desc, func, select from sqlalchemy.exc import SQLAlchemyError from src.common.database.sqlalchemy_models import ( - get_db_session, - Messages, ActionRecords, - PersonInfo, - ChatStreams, - LLMUsage, - Emoji, - Images, - ImageDescriptions, - OnlineTime, - Memory, - Expression, - ThinkingLog, - GraphNodes, - GraphEdges, - Schedule, - MaiZoneScheduleStatus, CacheEntries, + ChatStreams, + Emoji, + Expression, + GraphEdges, + GraphNodes, + ImageDescriptions, + Images, + LLMUsage, + MaiZoneScheduleStatus, + Memory, + Messages, + OnlineTime, + PersonInfo, + Schedule, + ThinkingLog, UserRelationships, + get_db_session, ) from src.common.logger import get_logger @@ -59,7 +59,7 @@ MODEL_MAPPING = { } -async def build_filters(model_class, filters: Dict[str, Any]): +async def build_filters(model_class, filters: dict[str, Any]): """构建查询过滤条件""" conditions = [] @@ -98,13 +98,13 @@ async def build_filters(model_class, filters: Dict[str, Any]): async def db_query( model_class, - data: Optional[Dict[str, Any]] = None, - query_type: Optional[str] = "get", - filters: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None, - order_by: Optional[List[str]] = None, - single_result: Optional[bool] = False, -) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: + data: dict[str, Any] | None = None, + query_type: str | None = "get", + filters: dict[str, Any] | None = None, + limit: int | None = None, + order_by: list[str] | None = None, + single_result: bool | None = False, +) -> list[dict[str, Any]] | dict[str, Any] | None: """执行异步数据库查询操作 Args: @@ -263,8 +263,8 @@ async def db_query( async def db_save( - model_class, data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None -) -> Optional[Dict[str, Any]]: + model_class, data: dict[str, Any], key_field: str | None = None, key_value: Any | None = None +) -> dict[str, Any] | None: """异步保存数据到数据库(创建或更新) Args: @@ -325,11 +325,11 @@ async def db_save( async def db_get( model_class, - filters: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None, - order_by: Optional[str] = None, - single_result: Optional[bool] = False, -) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: + filters: dict[str, Any] | None = None, + limit: int | None = None, + order_by: str | None = None, + single_result: bool | None = False, +) -> list[dict[str, Any]] | dict[str, Any] | None: """异步从数据库获取记录 Args: @@ -359,9 +359,9 @@ async def store_action_info( action_prompt_display: str = "", action_done: bool = True, thinking_id: str = "", - action_data: Optional[dict] = None, + action_data: dict | None = None, action_name: str = "", -) -> Optional[Dict[str, Any]]: +) -> dict[str, Any] | None: """异步存储动作信息到数据库 Args: diff --git a/src/common/database/sqlalchemy_init.py b/src/common/database/sqlalchemy_init.py index 7d3f97136..daf61f3a5 100644 --- a/src/common/database/sqlalchemy_init.py +++ b/src/common/database/sqlalchemy_init.py @@ -4,10 +4,10 @@ 提供统一的异步数据库初始化接口 """ -from typing import Optional from sqlalchemy.exc import SQLAlchemyError -from src.common.logger import get_logger + from src.common.database.sqlalchemy_models import Base, get_engine, initialize_database +from src.common.logger import get_logger logger = get_logger("sqlalchemy_init") @@ -71,7 +71,7 @@ async def create_all_tables() -> bool: return False -async def get_database_info() -> Optional[dict]: +async def get_database_info() -> dict | None: """ 异步获取数据库信息 diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py 
index 2f78e56d0..c89848ee3 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -6,11 +6,12 @@ import datetime import os import time +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import Optional, Any, Dict, AsyncGenerator +from typing import Any -from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime, text -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker +from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Mapped, mapped_column @@ -423,7 +424,7 @@ class Expression(Base): last_active_time: Mapped[float] = mapped_column(Float, nullable=False) chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) type: Mapped[str] = mapped_column(Text, nullable=False) - create_date: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + create_date: Mapped[float | None] = mapped_column(Float, nullable=True) __table_args__ = (Index("idx_expression_chat_id", "chat_id"),) @@ -710,7 +711,7 @@ async def initialize_database(): config = global_config.database # 配置引擎参数 - engine_kwargs: Dict[str, Any] = { + engine_kwargs: dict[str, Any] = { "echo": False, # 生产环境关闭SQL日志 "future": True, } @@ -759,12 +760,12 @@ async def initialize_database(): @asynccontextmanager -async def get_db_session() -> AsyncGenerator[Optional[AsyncSession], None]: +async def get_db_session() -> AsyncGenerator[AsyncSession | None, None]: """ 异步数据库会话上下文管理器。 在初始化失败时会yield None,调用方需要检查会话是否为None。 """ - session: Optional[AsyncSession] = None + session: AsyncSession | None = None SessionLocal = None try: _, SessionLocal = await initialize_database() diff --git a/src/common/logger.py b/src/common/logger.py index 2830c127d..a28628a46 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -1,16 +1,16 @@ # 使用基于时间戳的文件处理器,简单的轮转份数限制 import logging -import orjson import threading import time +from collections.abc import Callable +from datetime import datetime, timedelta +from pathlib import Path + +import orjson import structlog import tomlkit -from pathlib import Path -from typing import Callable, Optional -from datetime import datetime, timedelta - # 创建logs目录 LOG_DIR = Path("logs") LOG_DIR.mkdir(exist_ok=True) @@ -212,7 +212,7 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress try: if config_path.exists(): - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: config = tomlkit.load(f) return config.get("log", default_config) except Exception as e: @@ -942,7 +942,7 @@ raw_logger: structlog.stdlib.BoundLogger = structlog.get_logger() binds: dict[str, Callable] = {} -def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger: +def get_logger(name: str | None) -> structlog.stdlib.BoundLogger: """获取logger实例,支持按名称绑定""" if name is None: return raw_logger diff --git a/src/common/message/__init__.py b/src/common/message/__init__.py index 160456b0f..79f346c04 100644 --- a/src/common/message/__init__.py +++ b/src/common/message/__init__.py @@ -4,7 +4,6 @@ __version__ = "0.1.0" from .api import get_global_api - __all__ = [ "get_global_api", ] diff --git a/src/common/message/api.py b/src/common/message/api.py index 37b7a7ddc..2d797a5a8 100644 --- 
a/src/common/message/api.py +++ b/src/common/message/api.py @@ -1,10 +1,12 @@ -from src.common.server import get_global_server import importlib.metadata -from maim_message import MessageServer -from src.common.logger import get_logger -from src.config.config import global_config import os +from maim_message import MessageServer + +from src.common.logger import get_logger +from src.common.server import get_global_server +from src.config.config import global_config + global_api = None diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 57f179c36..f9a874859 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,15 +1,15 @@ import traceback +from typing import Any -from typing import List, Optional, Any, Dict -from sqlalchemy import not_, select, func - +from sqlalchemy import func, not_, select from sqlalchemy.orm import DeclarativeBase -from src.config.config import global_config + +from src.common.database.sqlalchemy_database_api import get_db_session # from src.common.database.database_model import Messages from src.common.database.sqlalchemy_models import Messages -from src.common.database.sqlalchemy_database_api import get_db_session from src.common.logger import get_logger +from src.config.config import global_config logger = get_logger(__name__) @@ -18,7 +18,7 @@ class Base(DeclarativeBase): pass -def _model_to_dict(instance: Base) -> Dict[str, Any]: +def _model_to_dict(instance: Base) -> dict[str, Any]: """ 将 SQLAlchemy 模型实例转换为字典。 """ @@ -32,12 +32,12 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]: async def find_messages( message_filter: dict[str, Any], - sort: Optional[List[tuple[str, int]]] = None, + sort: list[tuple[str, int]] | None = None, limit: int = 0, limit_mode: str = "latest", filter_bot=False, filter_command=False, -) -> List[dict[str, Any]]: +) -> list[dict[str, Any]]: """ 根据提供的过滤器、排序和限制条件查找消息。 diff --git a/src/common/remote.py b/src/common/remote.py index 95202f810..f6396a037 100644 --- a/src/common/remote.py +++ b/src/common/remote.py @@ -1,13 +1,13 @@ import asyncio import base64 import json +import platform +from datetime import datetime, timezone import aiohttp -import platform - -from datetime import datetime, timezone from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding, rsa + from src.common.logger import get_logger from src.common.tcp_connector import get_tcp_connector from src.config.config import global_config diff --git a/src/common/server.py b/src/common/server.py index 64299274b..ec6ff932a 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -1,20 +1,20 @@ import os -from typing import Optional -from fastapi import FastAPI, APIRouter +from fastapi import APIRouter, FastAPI from fastapi.middleware.cors import CORSMiddleware # 新增导入 from rich.traceback import install -from uvicorn import Config, Server as UvicornServer +from uvicorn import Config +from uvicorn import Server as UvicornServer install(extra_lines=3) class Server: - def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"): + def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MaiMCore"): self.app = FastAPI(title=app_name) self._host: str = "127.0.0.1" self._port: int = 8080 - self._server: Optional[UvicornServer] = None + self._server: UvicornServer | None = None self.set_address(host, port) # 配置 CORS @@ -57,7 +57,7 @@ class Server: """ 
self.app.include_router(router, prefix=prefix) - def set_address(self, host: Optional[str] = None, port: Optional[int] = None): + def set_address(self, host: str | None = None, port: int | None = None): """设置服务器地址和端口""" if host: self._host = host @@ -76,7 +76,7 @@ class Server: raise except Exception as e: await self.shutdown() - raise RuntimeError(f"服务器运行错误: {str(e)}") from e + raise RuntimeError(f"服务器运行错误: {e!s}") from e finally: await self.shutdown() diff --git a/src/common/tcp_connector.py b/src/common/tcp_connector.py index dd966e648..868b0c3f2 100644 --- a/src/common/tcp_connector.py +++ b/src/common/tcp_connector.py @@ -1,6 +1,7 @@ import ssl -import certifi + import aiohttp +import certifi ssl_context = ssl.create_default_context(cafile=certifi.where()) diff --git a/src/common/vector_db/__init__.py b/src/common/vector_db/__init__.py index a913c2232..65e0a8025 100644 --- a/src/common/vector_db/__init__.py +++ b/src/common/vector_db/__init__.py @@ -18,4 +18,4 @@ def get_vector_db_service() -> VectorDBBase: # 全局向量数据库服务实例 vector_db_service: VectorDBBase = get_vector_db_service() -__all__ = ["vector_db_service", "VectorDBBase"] +__all__ = ["VectorDBBase", "vector_db_service"] diff --git a/src/common/vector_db/base.py b/src/common/vector_db/base.py index 132ea15cb..04449e24a 100644 --- a/src/common/vector_db/base.py +++ b/src/common/vector_db/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any class VectorDBBase(ABC): @@ -36,10 +36,10 @@ class VectorDBBase(ABC): def add( self, collection_name: str, - embeddings: List[List[float]], - documents: Optional[List[str]] = None, - metadatas: Optional[List[Dict[str, Any]]] = None, - ids: Optional[List[str]] = None, + embeddings: list[list[float]], + documents: list[str] | None = None, + metadatas: list[dict[str, Any]] | None = None, + ids: list[str] | None = None, ) -> None: """ 向指定集合中添加数据。 @@ -57,11 +57,11 @@ class VectorDBBase(ABC): def query( self, collection_name: str, - query_embeddings: List[List[float]], + query_embeddings: list[list[float]], n_results: int = 1, - where: Optional[Dict[str, Any]] = None, + where: dict[str, Any] | None = None, **kwargs: Any, - ) -> Dict[str, List[Any]]: + ) -> dict[str, list[Any]]: """ 在指定集合中查询相似向量。 @@ -81,8 +81,8 @@ class VectorDBBase(ABC): def delete( self, collection_name: str, - ids: Optional[List[str]] = None, - where: Optional[Dict[str, Any]] = None, + ids: list[str] | None = None, + where: dict[str, Any] | None = None, ) -> None: """ 从指定集合中删除数据。 @@ -98,13 +98,13 @@ class VectorDBBase(ABC): def get( self, collection_name: str, - ids: Optional[List[str]] = None, - where: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - where_document: Optional[Dict[str, Any]] = None, - include: Optional[List[str]] = None, - ) -> Dict[str, Any]: + ids: list[str] | None = None, + where: dict[str, Any] | None = None, + limit: int | None = None, + offset: int | None = None, + where_document: dict[str, Any] | None = None, + include: list[str] | None = None, + ) -> dict[str, Any]: """ 根据条件从集合中获取数据。 diff --git a/src/common/vector_db/chromadb_impl.py b/src/common/vector_db/chromadb_impl.py index a0267dfed..1934c812e 100644 --- a/src/common/vector_db/chromadb_impl.py +++ b/src/common/vector_db/chromadb_impl.py @@ -1,12 +1,13 @@ import threading -from typing import Any, Dict, List, Optional +from typing import Any import chromadb from chromadb.config import Settings -from .base import VectorDBBase from 
src.common.logger import get_logger +from .base import VectorDBBase + logger = get_logger("chromadb_impl") @@ -38,7 +39,7 @@ class ChromaDBImpl(VectorDBBase): self.client = chromadb.PersistentClient( path=path, settings=Settings(anonymized_telemetry=False) ) - self._collections: Dict[str, Any] = {} + self._collections: dict[str, Any] = {} self._initialized = True logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}") except Exception as e: @@ -65,10 +66,10 @@ class ChromaDBImpl(VectorDBBase): def add( self, collection_name: str, - embeddings: List[List[float]], - documents: Optional[List[str]] = None, - metadatas: Optional[List[Dict[str, Any]]] = None, - ids: Optional[List[str]] = None, + embeddings: list[list[float]], + documents: list[str] | None = None, + metadatas: list[dict[str, Any]] | None = None, + ids: list[str] | None = None, ) -> None: collection = self.get_or_create_collection(collection_name) if collection: @@ -85,11 +86,11 @@ class ChromaDBImpl(VectorDBBase): def query( self, collection_name: str, - query_embeddings: List[List[float]], + query_embeddings: list[list[float]], n_results: int = 1, - where: Optional[Dict[str, Any]] = None, + where: dict[str, Any] | None = None, **kwargs: Any, - ) -> Dict[str, List[Any]]: + ) -> dict[str, list[Any]]: collection = self.get_or_create_collection(collection_name) if collection: try: @@ -120,7 +121,7 @@ class ChromaDBImpl(VectorDBBase): logger.error(f"回退查询也失败: {fallback_e}") return {} - def _process_where_condition(self, where: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def _process_where_condition(self, where: dict[str, Any]) -> dict[str, Any] | None: """ 处理where条件,转换为ChromaDB支持的格式 ChromaDB支持的格式: @@ -174,13 +175,13 @@ class ChromaDBImpl(VectorDBBase): def get( self, collection_name: str, - ids: Optional[List[str]] = None, - where: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - where_document: Optional[Dict[str, Any]] = None, - include: Optional[List[str]] = None, - ) -> Dict[str, Any]: + ids: list[str] | None = None, + where: dict[str, Any] | None = None, + limit: int | None = None, + offset: int | None = None, + where_document: dict[str, Any] | None = None, + include: list[str] | None = None, + ) -> dict[str, Any]: """根据条件从集合中获取数据""" collection = self.get_or_create_collection(collection_name) if collection: @@ -217,8 +218,8 @@ class ChromaDBImpl(VectorDBBase): def delete( self, collection_name: str, - ids: Optional[List[str]] = None, - where: Optional[Dict[str, Any]] = None, + ids: list[str] | None = None, + where: dict[str, Any] | None = None, ) -> None: collection = self.get_or_create_collection(collection_name) if collection: diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index eb5d1a1f1..de7479efb 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -1,6 +1,7 @@ -from typing import List, Dict, Any, Literal, Union, Optional -from pydantic import Field from threading import Lock +from typing import Any, Literal + +from pydantic import Field from src.config.config_base import ValidatedConfigBase @@ -10,7 +11,7 @@ class APIProvider(ValidatedConfigBase): name: str = Field(..., min_length=1, description="API提供商名称") base_url: str = Field(..., description="API基础URL") - api_key: Union[str, List[str]] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询") + api_key: str | list[str] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询") client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field( default="openai", 
description="客户端类型(如openai/google等,默认为openai)" ) @@ -70,7 +71,7 @@ class ModelInfo(ValidatedConfigBase): price_in: float = Field(default=0.0, ge=0, description="每M token输入价格") price_out: float = Field(default=0.0, ge=0, description="每M token输出价格") force_stream_mode: bool = Field(default=False, description="是否强制使用流式输出模式") - extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数(用于API调用时的额外配置)") + extra_params: dict[str, Any] = Field(default_factory=dict, description="额外参数(用于API调用时的额外配置)") anti_truncation: bool = Field(default=False, description="是否启用反截断功能,防止模型输出被截断") @classmethod @@ -101,11 +102,11 @@ class ModelInfo(ValidatedConfigBase): class TaskConfig(ValidatedConfigBase): """任务配置类""" - model_list: List[str] = Field(..., description="任务使用的模型列表") + model_list: list[str] = Field(..., description="任务使用的模型列表") max_tokens: int = Field(default=800, description="任务最大输出token数") temperature: float = Field(default=0.7, description="模型温度") concurrency_count: int = Field(default=1, description="并发请求数量") - embedding_dimension: Optional[int] = Field( + embedding_dimension: int | None = Field( default=None, description="嵌入模型输出向量维度,仅在嵌入任务中使用", ge=1, @@ -168,9 +169,9 @@ class ModelTaskConfig(ValidatedConfigBase): class APIAdapterConfig(ValidatedConfigBase): """API Adapter配置类""" - models: List[ModelInfo] = Field(..., min_length=1, description="模型列表") + models: list[ModelInfo] = Field(..., min_length=1, description="模型列表") model_task_config: ModelTaskConfig = Field(..., description="模型任务配置") - api_providers: List[APIProvider] = Field(..., min_length=1, description="API提供商列表") + api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表") def __init__(self, **data): super().__init__(**data) diff --git a/src/config/config.py b/src/config/config.py index 375d513df..846643477 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,60 +1,58 @@ import os -import tomlkit import shutil import sys - from datetime import datetime -from tomlkit import TOMLDocument -from tomlkit.items import Table, KeyType -from rich.traceback import install -from typing import List, Optional + +import tomlkit from pydantic import Field +from rich.traceback import install +from tomlkit import TOMLDocument +from tomlkit.items import KeyType, Table from src.common.logger import get_logger from src.config.config_base import ValidatedConfigBase from src.config.official_configs import ( - DatabaseConfig, + AffinityFlowConfig, + AntiPromptInjectionConfig, BotConfig, - PersonalityConfig, - ExpressionConfig, ChatConfig, - NormalChatConfig, - EmojiConfig, - MemoryConfig, - MoodConfig, - KeywordReactionConfig, ChineseTypoConfig, + CommandConfig, + CrossContextConfig, + CustomPromptConfig, + DatabaseConfig, + DebugConfig, + DependencyManagementConfig, + EmojiConfig, + ExperimentalConfig, + ExpressionConfig, + KeywordReactionConfig, + LPMMKnowledgeConfig, + MaimMessageConfig, + MemoryConfig, + MessageReceiveConfig, + MoodConfig, + NormalChatConfig, + PermissionConfig, + PersonalityConfig, + PlanningSystemConfig, + ProactiveThinkingConfig, + RelationshipConfig, ResponsePostProcessConfig, ResponseSplitterConfig, - ExperimentalConfig, - MessageReceiveConfig, - MaimMessageConfig, - LPMMKnowledgeConfig, - RelationshipConfig, - ToolConfig, - VoiceConfig, - DebugConfig, - CustomPromptConfig, - VideoAnalysisConfig, - DependencyManagementConfig, - WebSearchConfig, - AntiPromptInjectionConfig, SleepSystemConfig, - CrossContextConfig, - PermissionConfig, - CommandConfig, - PlanningSystemConfig, - 
AffinityFlowConfig, - ProactiveThinkingConfig, + ToolConfig, + VideoAnalysisConfig, + VoiceConfig, + WebSearchConfig, ) from .api_ada_configs import ( - ModelTaskConfig, - ModelInfo, APIProvider, + ModelInfo, + ModelTaskConfig, ) - install(extra_lines=3) @@ -148,11 +146,11 @@ def compare_default_values(new, old, path=None, logs=None, changes=None): return logs, changes -def _get_version_from_toml(toml_path) -> Optional[str]: +def _get_version_from_toml(toml_path) -> str | None: """从TOML文件中获取版本号""" if not os.path.exists(toml_path): return None - with open(toml_path, "r", encoding="utf-8") as f: + with open(toml_path, encoding="utf-8") as f: doc = tomlkit.load(f) if "inner" in doc and "version" in doc["inner"]: # type: ignore return doc["inner"]["version"] # type: ignore @@ -264,17 +262,17 @@ def _update_config_generic(config_name: str, template_name: str): # 先读取 compare 下的模板(如果有),用于默认值变动检测 if os.path.exists(compare_path): - with open(compare_path, "r", encoding="utf-8") as f: + with open(compare_path, encoding="utf-8") as f: compare_config = tomlkit.load(f) # 读取当前模板 - with open(template_path, "r", encoding="utf-8") as f: + with open(template_path, encoding="utf-8") as f: new_config = tomlkit.load(f) # 检查默认值变化并处理(只有 compare_config 存在时才做) if compare_config: # 读取旧配置 - with open(old_config_path, "r", encoding="utf-8") as f: + with open(old_config_path, encoding="utf-8") as f: old_config = tomlkit.load(f) logs, changes = compare_default_values(new_config, compare_config) if logs: @@ -304,7 +302,7 @@ def _update_config_generic(config_name: str, template_name: str): # 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次) if old_config is None: - with open(old_config_path, "r", encoding="utf-8") as f: + with open(old_config_path, encoding="utf-8") as f: old_config = tomlkit.load(f) # new_config 已经读取 @@ -350,7 +348,7 @@ def _update_config_generic(config_name: str, template_name: str): # 移除在新模板中已不存在的旧配置项 logger.info(f"开始移除{config_name}中已废弃的配置项...") - with open(template_path, "r", encoding="utf-8") as f: + with open(template_path, encoding="utf-8") as f: template_doc = tomlkit.load(f) _remove_obsolete_keys(new_config, template_doc) logger.info(f"已移除{config_name}中已废弃的配置项") @@ -428,9 +426,9 @@ class Config(ValidatedConfigBase): class APIAdapterConfig(ValidatedConfigBase): """API Adapter配置类""" - models: List[ModelInfo] = Field(..., min_items=1, description="模型列表") + models: list[ModelInfo] = Field(..., min_items=1, description="模型列表") model_task_config: ModelTaskConfig = Field(..., description="模型任务配置") - api_providers: List[APIProvider] = Field(..., min_items=1, description="API提供商列表") + api_providers: list[APIProvider] = Field(..., min_items=1, description="API提供商列表") def __init__(self, **data): super().__init__(**data) @@ -494,7 +492,7 @@ def load_config(config_path: str) -> Config: Config对象 """ # 读取配置文件 - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: config_data = tomlkit.load(f) # 创建Config对象(各个配置类会自动进行 Pydantic 验证) @@ -517,7 +515,7 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig: APIAdapterConfig对象 """ # 读取配置文件 - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: config_data = tomlkit.load(f) config_dict = dict(config_data) diff --git a/src/config/config_base.py b/src/config/config_base.py index 764ec5b71..a80740a46 100644 --- a/src/config/config_base.py +++ b/src/config/config_base.py @@ -1,6 +1,8 @@ -from dataclasses import dataclass, fields, MISSING -from typing import TypeVar, Type, Any, 
get_origin, get_args, Literal +from dataclasses import MISSING, dataclass, fields +from typing import Any, Literal, TypeVar, get_args, get_origin + from pydantic import BaseModel, ValidationError +from typing_extensions import Self T = TypeVar("T", bound="ConfigBase") @@ -19,7 +21,7 @@ class ConfigBase: """配置类的基类""" @classmethod - def from_dict(cls: Type[T], data: dict[str, Any]) -> T: + def from_dict(cls, data: dict[str, Any]) -> Self: """从字典加载配置字段""" if not isinstance(data, dict): raise TypeError(f"Expected a dictionary, got {type(data).__name__}") @@ -53,7 +55,7 @@ class ConfigBase: return cls() @classmethod - def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any: + def _convert_field(cls, value: Any, field_type: type[Any]) -> Any: """ 转换字段值为指定类型 diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 6a1613baa..ecdb5d5b5 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1,4 +1,5 @@ -from typing import Literal, Optional, List +from typing import Literal + from pydantic import Field from src.config.config_base import ValidatedConfigBase @@ -42,7 +43,7 @@ class BotConfig(ValidatedConfigBase): platform: str = Field(..., description="平台") qq_account: int = Field(..., description="QQ账号") nickname: str = Field(..., description="昵称") - alias_names: List[str] = Field(default_factory=list, description="别名列表") + alias_names: list[str] = Field(default_factory=list, description="别名列表") class PersonalityConfig(ValidatedConfigBase): @@ -54,7 +55,7 @@ class PersonalityConfig(ValidatedConfigBase): background_story: str = Field( default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述" ) - safety_guidelines: List[str] = Field( + safety_guidelines: list[str] = Field( default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则" ) reply_style: str = Field(default="", description="表达风格") @@ -63,7 +64,7 @@ class PersonalityConfig(ValidatedConfigBase): compress_identity: bool = Field(default=True, description="是否压缩身份") # 回复规则配置 - reply_targeting_rules: List[str] = Field( + reply_targeting_rules: list[str] = Field( default_factory=lambda: [ "拒绝任何包含骚扰、冒犯、暴力、色情或危险内容的请求。", "在拒绝时,请使用符合你人设的、坚定的语气。", @@ -72,7 +73,7 @@ class PersonalityConfig(ValidatedConfigBase): description="安全与互动底线规则,Bot在任何情况下都必须遵守的原则", ) - message_targeting_analysis: List[str] = Field( + message_targeting_analysis: list[str] = Field( default_factory=lambda: [ "**直接针对你**:@你、回复你、明确询问你 → 必须回应", "**间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与", @@ -82,7 +83,7 @@ class PersonalityConfig(ValidatedConfigBase): description="消息针对性分析规则,用于判断是否需要回复", ) - reply_principles: List[str] = Field( + reply_principles: list[str] = Field( default_factory=lambda: [ "明确回应目标消息,而不是宽泛地评论。", "可以分享你的看法、提出相关问题,或者开个合适的玩笑。", @@ -111,7 +112,7 @@ class ChatConfig(ValidatedConfigBase): at_bot_inevitable_reply: bool = Field(default=False, description="@机器人的必然回复") allow_reply_self: bool = Field(default=False, description="是否允许回复自己说的话") focus_value: float = Field(default=1.0, description="专注值") - focus_mode_quiet_groups: List[str] = Field( + focus_mode_quiet_groups: list[str] = Field( default_factory=list, description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]', ) @@ -140,8 +141,8 @@ class ChatConfig(ValidatedConfigBase): class MessageReceiveConfig(ValidatedConfigBase): """消息接收配置类""" - ban_words: List[str] = Field(default_factory=lambda: list(), description="禁用词列表") - ban_msgs_regex: List[str] = Field(default_factory=lambda: list(), description="禁用消息正则列表") + ban_words: list[str] = 
Field(default_factory=lambda: list(), description="禁用词列表") + ban_msgs_regex: list[str] = Field(default_factory=lambda: list(), description="禁用消息正则列表") class NormalChatConfig(ValidatedConfigBase): @@ -155,16 +156,16 @@ class ExpressionRule(ValidatedConfigBase): use_expression: bool = Field(default=True, description="是否使用学到的表达") learn_expression: bool = Field(default=True, description="是否学习表达") learning_strength: float = Field(default=1.0, description="学习强度") - group: Optional[str] = Field(default=None, description="表达共享组") + group: str | None = Field(default=None, description="表达共享组") class ExpressionConfig(ValidatedConfigBase): """表达配置类""" - rules: List[ExpressionRule] = Field(default_factory=list, description="表达学习规则") + rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则") @staticmethod - def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]: + def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None: """ 解析流配置字符串并生成对应的 chat_id @@ -199,7 +200,7 @@ class ExpressionConfig(ValidatedConfigBase): except (ValueError, IndexError): return None - def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, float]: + def get_expression_config_for_chat(self, chat_stream_id: str | None = None) -> tuple[bool, bool, float]: """ 根据聊天流ID获取表达配置 @@ -362,7 +363,7 @@ class KeywordRuleConfig(ValidatedConfigBase): try: re.compile(pattern) except re.error as e: - raise ValueError(f"无效的正则表达式 '{pattern}': {str(e)}") from e + raise ValueError(f"无效的正则表达式 '{pattern}': {e!s}") from e class KeywordReactionConfig(ValidatedConfigBase): @@ -561,10 +562,10 @@ class SleepSystemConfig(ValidatedConfigBase): # --- 失眠机制相关参数 --- enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统") - insomnia_trigger_delay_minutes: List[int] = Field( + insomnia_trigger_delay_minutes: list[int] = Field( default_factory=lambda: [30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)" ) - insomnia_duration_minutes: List[int] = Field( + insomnia_duration_minutes: list[int] = Field( default_factory=lambda: [15, 45], description="单次失眠状态的持续时间范围(分钟)" ) sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值") @@ -590,7 +591,7 @@ class ContextGroup(ValidatedConfigBase): """上下文共享组配置""" name: str = Field(..., description="共享组的名称") - chat_ids: List[List[str]] = Field( + chat_ids: list[list[str]] = Field( ..., description='属于该组的聊天ID列表,格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]', ) @@ -600,20 +601,20 @@ class CrossContextConfig(ValidatedConfigBase): """跨群聊上下文共享配置""" enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能") - groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表") + groups: list[ContextGroup] = Field(default_factory=list, description="上下文共享组列表") class CommandConfig(ValidatedConfigBase): """命令系统配置类""" - command_prefixes: List[str] = Field(default_factory=lambda: ["/", "!", ".", "#"], description="支持的命令前缀列表") + command_prefixes: list[str] = Field(default_factory=lambda: ["/", "!", ".", "#"], description="支持的命令前缀列表") class PermissionConfig(ValidatedConfigBase): """权限系统配置类""" # Master用户配置(拥有最高权限,无视所有权限节点) - master_users: List[List[str]] = Field( + master_users: list[list[str]] = Field( default_factory=list, description="Master用户列表,格式: [[platform, user_id], ...]" ) @@ -668,10 +669,10 @@ class ProactiveThinkingConfig(ValidatedConfigBase): # --- 作用范围 --- enable_in_private: bool = Field(default=True, description="是否允许在私聊中主动发起对话") 
enable_in_group: bool = Field(default=True, description="是否允许在群聊中主动发起对话") - enabled_private_chats: List[str] = Field( + enabled_private_chats: list[str] = Field( default_factory=list, description='私聊白名单,为空则对所有私聊生效。格式: ["platform:user_id", ...]' ) - enabled_group_chats: List[str] = Field( + enabled_group_chats: list[str] = Field( default_factory=list, description='群聊白名单,为空则对所有群聊生效。格式: ["platform:group_id", ...]' ) diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index 4716921f9..83c24d4f6 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -1,13 +1,14 @@ -import orjson -import os import hashlib +import os import time +import orjson +from rich.traceback import install + from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import get_person_info_manager -from rich.traceback import install install(extra_lines=3) @@ -193,9 +194,9 @@ class Individuality: """从JSON文件中加载元信息""" if os.path.exists(self.meta_info_file_path): try: - with open(self.meta_info_file_path, "r", encoding="utf-8") as f: + with open(self.meta_info_file_path, encoding="utf-8") as f: return orjson.loads(f.read()) - except (orjson.JSONDecodeError, IOError) as e: + except (OSError, orjson.JSONDecodeError) as e: logger.error(f"读取meta_info文件失败: {e}, 将创建新文件。") return {} return {} @@ -206,16 +207,16 @@ class Individuality: os.makedirs(os.path.dirname(self.meta_info_file_path), exist_ok=True) with open(self.meta_info_file_path, "w", encoding="utf-8") as f: f.write(orjson.dumps(meta_info, option=orjson.OPT_INDENT_2).decode("utf-8")) - except IOError as e: + except OSError as e: logger.error(f"保存meta_info文件失败: {e}") def _load_personality_data(self) -> dict: """从JSON文件中加载personality数据""" if os.path.exists(self.personality_data_file_path): try: - with open(self.personality_data_file_path, "r", encoding="utf-8") as f: + with open(self.personality_data_file_path, encoding="utf-8") as f: return orjson.loads(f.read()) - except (orjson.JSONDecodeError, IOError) as e: + except (OSError, orjson.JSONDecodeError) as e: logger.error(f"读取personality_data文件失败: {e}, 将创建新文件。") return {} return {} @@ -227,7 +228,7 @@ class Individuality: with open(self.personality_data_file_path, "w", encoding="utf-8") as f: f.write(orjson.dumps(personality_data, option=orjson.OPT_INDENT_2).decode("utf-8")) logger.debug(f"已保存personality数据到文件: {self.personality_data_file_path}") - except IOError as e: + except OSError as e: logger.error(f"保存personality_data文件失败: {e}") def _get_personality_from_file(self) -> tuple[str, str]: diff --git a/src/individuality/not_using/offline_llm.py b/src/individuality/not_using/offline_llm.py index 2bafb69aa..752293ab8 100644 --- a/src/individuality/not_using/offline_llm.py +++ b/src/individuality/not_using/offline_llm.py @@ -1,13 +1,13 @@ import asyncio import os import time -from typing import Tuple, Union import aiohttp import requests +from rich.traceback import install + from src.common.logger import get_logger from src.common.tcp_connector import get_tcp_connector -from rich.traceback import install install(extra_lines=3) @@ -26,7 +26,7 @@ class LLMRequestOff: # logger.info(f"API URL: {self.base_url}") # 使用 logger 记录 base_url - def generate_response(self, prompt: str) -> Union[str, Tuple[str, str]]: + def generate_response(self, prompt: str) -> str | tuple[str, str]: """根据输入的提示生成模型的响应""" headers = {"Authorization": f"Bearer 
{self.api_key}", "Content-Type": "application/json"} @@ -67,16 +67,16 @@ class LLMRequestOff: except Exception as e: if retry < max_retries - 1: # 如果还有重试机会 wait_time = base_wait_time * (2**retry) - logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {e!s}") time.sleep(wait_time) else: - logger.error(f"请求失败: {str(e)}") - return f"请求失败: {str(e)}", "" + logger.error(f"请求失败: {e!s}") + return f"请求失败: {e!s}", "" logger.error("达到最大重试次数,请求仍然失败") return "达到最大重试次数,请求仍然失败", "" - async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]: + async def generate_response_async(self, prompt: str) -> str | tuple[str, str]: """异步方式根据输入的提示生成模型的响应""" headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} @@ -117,11 +117,11 @@ class LLMRequestOff: except Exception as e: if retry < max_retries - 1: # 如果还有重试机会 wait_time = base_wait_time * (2**retry) - logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {e!s}") await asyncio.sleep(wait_time) else: - logger.error(f"请求失败: {str(e)}") - return f"请求失败: {str(e)}", "" + logger.error(f"请求失败: {e!s}") + return f"请求失败: {e!s}", "" logger.error("达到最大重试次数,请求仍然失败") return "达到最大重试次数,请求仍然失败", "" diff --git a/src/individuality/not_using/per_bf_gen.py b/src/individuality/not_using/per_bf_gen.py index 9e4d0291f..4aea7e7de 100644 --- a/src/individuality/not_using/per_bf_gen.py +++ b/src/individuality/not_using/per_bf_gen.py @@ -1,10 +1,10 @@ -from typing import Dict, List -import orjson import os -from dotenv import load_dotenv -import sys -import toml import random +import sys + +import orjson +import toml +from dotenv import load_dotenv from tqdm import tqdm # 添加项目根目录到 Python 路径 @@ -13,13 +13,13 @@ sys.path.append(root_path) # 加载配置文件 config_path = os.path.join(root_path, "config", "bot_config.toml") -with open(config_path, "r", encoding="utf-8") as f: +with open(config_path, encoding="utf-8") as f: config = toml.load(f) # 现在可以导入src模块 from individuality.not_using.scene import get_scene_by_factor, PERSONALITY_SCENES # noqa E402 -from individuality.not_using.questionnaire import FACTOR_DESCRIPTIONS # noqa E402 -from individuality.not_using.offline_llm import LLMRequestOff # noqa E402 +from individuality.not_using.questionnaire import FACTOR_DESCRIPTIONS +from individuality.not_using.offline_llm import LLMRequestOff # 加载环境变量 env_path = os.path.join(root_path, ".env") @@ -75,7 +75,7 @@ def adapt_scene(scene: str) -> str: return adapted_scene except Exception as e: - print(f"场景改编过程出错:{str(e)},将使用原始场景") + print(f"场景改编过程出错:{e!s},将使用原始场景") return scene @@ -83,8 +83,8 @@ class PersonalityEvaluatorDirect: def __init__(self): self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} self.scenarios = [] - self.final_scores: Dict[str, float] = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} - self.dimension_counts = {trait: 0 for trait in self.final_scores} + self.final_scores: dict[str, float] = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} + self.dimension_counts = dict.fromkeys(self.final_scores, 0) # 为每个人格特质获取对应的场景 for trait in PERSONALITY_SCENES: @@ -112,7 +112,7 @@ class PersonalityEvaluatorDirect: self.llm = LLMRequestOff() - def evaluate_response(self, scenario: str, response: str, dimensions: List[str]) -> Dict[str, float]: + def evaluate_response(self, scenario: str, response: str, dimensions: list[str]) -> dict[str, float]: """ 使用 DeepSeek AI 评估用户对特定场景的反应 """ @@ -163,10 +163,10 @@ 
class PersonalityEvaluatorDirect: return {k: max(1, min(6, float(v))) for k, v in scores.items()} else: print("AI响应格式不正确,使用默认评分") - return {dim: 3.5 for dim in dimensions} + return dict.fromkeys(dimensions, 3.5) except Exception as e: - print(f"评估过程出错:{str(e)}") - return {dim: 3.5 for dim in dimensions} + print(f"评估过程出错:{e!s}") + return dict.fromkeys(dimensions, 3.5) def run_evaluation(self): """ diff --git a/src/individuality/not_using/scene.py b/src/individuality/not_using/scene.py index 929a9c426..9c16358e6 100644 --- a/src/individuality/not_using/scene.py +++ b/src/individuality/not_using/scene.py @@ -1,7 +1,8 @@ -import orjson import os from typing import Any +import orjson + def load_scenes() -> dict[str, Any]: """ @@ -13,7 +14,7 @@ def load_scenes() -> dict[str, Any]: current_dir = os.path.dirname(os.path.abspath(__file__)) json_path = os.path.join(current_dir, "template_scene.json") - with open(json_path, "r", encoding="utf-8") as f: + with open(json_path, encoding="utf-8") as f: return orjson.loads(f.read()) diff --git a/src/llm_models/exceptions.py b/src/llm_models/exceptions.py index 5b04f58c6..ad2b9a69d 100644 --- a/src/llm_models/exceptions.py +++ b/src/llm_models/exceptions.py @@ -1,6 +1,5 @@ from typing import Any - # 常见Error Code Mapping (以OpenAI API为例) error_code_mapping = { 400: "参数不正确", diff --git a/src/llm_models/model_client/aiohttp_gemini_client.py b/src/llm_models/model_client/aiohttp_gemini_client.py index 3d4dd8ca1..84470fb60 100644 --- a/src/llm_models/model_client/aiohttp_gemini_client.py +++ b/src/llm_models/model_client/aiohttp_gemini_client.py @@ -1,21 +1,24 @@ import asyncio -import orjson import io -from typing import Callable, Any, Coroutine, Optional -import aiohttp +from collections.abc import Callable, Coroutine +from typing import Any + +import aiohttp +import orjson -from src.config.api_ada_configs import ModelInfo, APIProvider from src.common.logger import get_logger -from .base_client import APIResponse, UsageRecord, BaseClient, client_registry +from src.config.api_ada_configs import APIProvider, ModelInfo + from ..exceptions import ( - RespParseException, NetworkConnectionError, - RespNotOkException, ReqAbortException, + RespNotOkException, + RespParseException, ) from ..payload_content.message import Message, RoleType from ..payload_content.resp_format import RespFormat, RespFormatType -from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall +from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam +from .base_client import APIResponse, BaseClient, UsageRecord, client_registry logger = get_logger("AioHTTP-Gemini客户端") @@ -210,7 +213,7 @@ class AiohttpGeminiStreamParser: chunk_data = orjson.loads(chunk_text) # 解析候选项 - if "candidates" in chunk_data and chunk_data["candidates"]: + if chunk_data.get("candidates"): candidate = chunk_data["candidates"][0] # 解析内容 @@ -266,7 +269,7 @@ class AiohttpGeminiStreamParser: async def _default_stream_response_handler( response: aiohttp.ClientResponse, interrupt_flag: asyncio.Event | None, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: +) -> tuple[APIResponse, tuple[int, int, int] | None]: """默认流式响应处理器""" parser = AiohttpGeminiStreamParser() @@ -290,13 +293,13 @@ async def _default_stream_response_handler( def _default_normal_response_parser( response_data: dict, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: +) -> tuple[APIResponse, tuple[int, int, int] | None]: """默认普通响应解析器""" api_response = APIResponse() try: # 解析候选项 - if "candidates" in response_data and 
response_data["candidates"]: + if response_data.get("candidates"): candidate = response_data["candidates"][0] # 解析文本内容 @@ -419,13 +422,12 @@ class AiohttpGeminiClient(BaseClient): max_tokens: int = 1024, temperature: float = 0.7, response_format: RespFormat | None = None, - stream_response_handler: Optional[ - Callable[ - [aiohttp.ClientResponse, asyncio.Event | None], - Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]], - ] - ] = None, - async_response_parser: Optional[Callable[[dict], tuple[APIResponse, Optional[tuple[int, int, int]]]]] = None, + stream_response_handler: Callable[ + [aiohttp.ClientResponse, asyncio.Event | None], + Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]], + ] + | None = None, + async_response_parser: Callable[[dict], tuple[APIResponse, tuple[int, int, int] | None]] | None = None, interrupt_flag: asyncio.Event | None = None, extra_params: dict[str, Any] | None = None, ) -> APIResponse: diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index eb74b0dfe..88f8601d6 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -1,12 +1,14 @@ import asyncio -from dataclasses import dataclass from abc import ABC, abstractmethod -from typing import Callable, Any, Optional +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from src.config.api_ada_configs import APIProvider, ModelInfo -from src.config.api_ada_configs import ModelInfo, APIProvider from ..payload_content.message import Message from ..payload_content.resp_format import RespFormat -from ..payload_content.tool_option import ToolOption, ToolCall +from ..payload_content.tool_option import ToolCall, ToolOption @dataclass @@ -75,9 +77,8 @@ class BaseClient(ABC): max_tokens: int = 1024, temperature: float = 0.7, response_format: RespFormat | None = None, - stream_response_handler: Optional[ - Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]] - ] = None, + stream_response_handler: Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]] + | None = None, async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None, interrupt_flag: asyncio.Event | None = None, extra_params: dict[str, Any] | None = None, diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 0ef79a89b..8005affaa 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -1,17 +1,17 @@ import asyncio -import io -import orjson -import re import base64 -from collections.abc import Iterable -from typing import Callable, Any, Coroutine, Optional -from json_repair import repair_json +import io +import re +from collections.abc import Callable, Coroutine, Iterable +from typing import Any +import orjson +from json_repair import repair_json from openai import ( - AsyncOpenAI, + NOT_GIVEN, APIConnectionError, APIStatusError, - NOT_GIVEN, + AsyncOpenAI, AsyncStream, ) from openai.types.chat import ( @@ -22,18 +22,19 @@ from openai.types.chat import ( ) from openai.types.chat.chat_completion_chunk import ChoiceDelta -from src.config.api_ada_configs import ModelInfo, APIProvider from src.common.logger import get_logger -from .base_client import APIResponse, UsageRecord, BaseClient, client_registry +from src.config.api_ada_configs import APIProvider, ModelInfo + from ..exceptions import ( - 
RespParseException, NetworkConnectionError, - RespNotOkException, ReqAbortException, + RespNotOkException, + RespParseException, ) from ..payload_content.message import Message, RoleType from ..payload_content.resp_format import RespFormat -from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall +from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam +from .base_client import APIResponse, BaseClient, UsageRecord, client_registry logger = get_logger("OpenAI客户端") @@ -241,7 +242,7 @@ def _build_stream_api_resp( async def _default_stream_response_handler( resp_stream: AsyncStream[ChatCompletionChunk], interrupt_flag: asyncio.Event | None, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: +) -> tuple[APIResponse, tuple[int, int, int] | None]: """ 流式响应处理函数 - 处理OpenAI API的流式响应 :param resp_stream: 流式响应对象 @@ -315,7 +316,7 @@ pattern = re.compile( def _default_normal_response_parser( resp: ChatCompletion, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: +) -> tuple[APIResponse, tuple[int, int, int] | None]: """ 解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象 :param resp: 响应对象 @@ -391,15 +392,13 @@ class OpenaiClient(BaseClient): max_tokens: int = 1024, temperature: float = 0.7, response_format: RespFormat | None = None, - stream_response_handler: Optional[ - Callable[ - [AsyncStream[ChatCompletionChunk], asyncio.Event | None], - Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]], - ] - ] = None, - async_response_parser: Optional[ - Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]] - ] = None, + stream_response_handler: Callable[ + [AsyncStream[ChatCompletionChunk], asyncio.Event | None], + Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]], + ] + | None = None, + async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int] | None]] + | None = None, interrupt_flag: asyncio.Event | None = None, extra_params: dict[str, Any] | None = None, ) -> APIResponse: @@ -514,17 +513,17 @@ class OpenaiClient(BaseClient): ) except APIConnectionError as e: # 添加详细的错误信息以便调试 - logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}") + logger.error(f"OpenAI API连接错误(嵌入模型): {e!s}") logger.error(f"错误类型: {type(e)}") if hasattr(e, "__cause__") and e.__cause__: - logger.error(f"底层错误: {str(e.__cause__)}") + logger.error(f"底层错误: {e.__cause__!s}") raise NetworkConnectionError() from e except APIStatusError as e: # 重封装APIError为RespNotOkException raise RespNotOkException(e.status_code) from e except Exception as e: # 添加通用异常处理和日志记录 - logger.error(f"获取嵌入时发生未知错误: {str(e)}") + logger.error(f"获取嵌入时发生未知错误: {e!s}") logger.error(f"错误类型: {type(e)}") raise diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py index 17d1fa30b..7a34349a3 100644 --- a/src/llm_models/payload_content/message.py +++ b/src/llm_models/payload_content/message.py @@ -1,6 +1,5 @@ from enum import Enum - # 设计这系列类的目的是为未来可能的扩展做准备 diff --git a/src/llm_models/payload_content/resp_format.py b/src/llm_models/payload_content/resp_format.py index e1baa3742..342fbf327 100644 --- a/src/llm_models/payload_content/resp_format.py +++ b/src/llm_models/payload_content/resp_format.py @@ -1,8 +1,8 @@ from enum import Enum -from typing import Optional, Any +from typing import Any from pydantic import BaseModel -from typing_extensions import TypedDict, Required +from typing_extensions import Required, TypedDict class RespFormatType(Enum): @@ -20,7 +20,7 @@ class JsonSchema(TypedDict, total=False): of 64. 
""" - description: Optional[str] + description: str | None """ A description of what the response format is for, used by the model to determine how to respond in the format. @@ -32,7 +32,7 @@ class JsonSchema(TypedDict, total=False): to build JSON schemas [here](https://json-schema.org/). """ - strict: Optional[bool] + strict: bool | None """ Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` @@ -100,7 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]: # 如果当前Schema是列表,则遍历每个元素 for i in range(len(sub_schema)): if isinstance(sub_schema[i], dict): - sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs) + sub_schema[i] = link_definitions_recursive(f"{path}/{i!s}", sub_schema[i], defs) else: # 否则为字典 if "$defs" in sub_schema: @@ -140,8 +140,7 @@ def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]: schema[idx] = _remove_title(item) elif isinstance(schema, dict): # 是字典,移除title字段,并对所有dict/list子元素递归调用 - if "$defs" in schema: - del schema["$defs"] + schema.pop("$defs", None) for key, value in schema.items(): if isinstance(value, (dict, list)): schema[key] = _remove_title(value) diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index c322e2ffb..bcac832f1 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -1,14 +1,15 @@ import base64 import io - -from PIL import Image from datetime import datetime -from src.common.logger import get_logger +from PIL import Image + from src.common.database.sqlalchemy_models import LLMUsage, get_db_session +from src.common.logger import get_logger from src.config.api_ada_configs import ModelInfo -from .payload_content.message import Message, MessageBuilder + from .model_client.base_client import UsageRecord +from .payload_content.message import Message, MessageBuilder logger = get_logger("消息压缩工具") @@ -38,7 +39,7 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * return image_data except Exception as e: - logger.error(f"图片转换格式失败: {str(e)}") + logger.error(f"图片转换格式失败: {e!s}") return image_data def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]: @@ -87,7 +88,7 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * return output_buffer.getvalue(), original_size, new_size except Exception as e: - logger.error(f"图片缩放失败: {str(e)}") + logger.error(f"图片缩放失败: {e!s}") import traceback logger.error(traceback.format_exc()) @@ -188,7 +189,7 @@ class LLMUsageRecorder: f"总计: {model_usage.total_tokens}" ) except Exception as e: - logger.error(f"记录token使用情况失败: {str(e)}") + logger.error(f"记录token使用情况失败: {e!s}") llm_usage_recorder = LLMUsageRecorder() diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index a8a68c2fb..afb2f13ed 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ @desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。 它被设计为一个高度容错和可扩展的系统,包含以下主要组件: @@ -19,24 +18,26 @@ 作为模块的统一入口(Facade),为上层业务逻辑提供了简洁的接口来发起文本、图像、语音等不同类型的LLM请求。 """ -import re import asyncio -import time import random +import re import string - +import time +from collections.abc import Callable, Coroutine from enum import Enum +from typing import Any + from rich.traceback import install -from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine from src.common.logger import get_logger -from src.config.config 
import model_config from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig -from .payload_content.message import MessageBuilder, Message -from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder -from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord -from .utils import compress_messages, llm_usage_recorder +from src.config.config import model_config + from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException +from .model_client.base_client import APIResponse, BaseClient, UsageRecord, client_registry +from .payload_content.message import Message, MessageBuilder +from .payload_content.tool_option import ToolCall, ToolOption, ToolOptionBuilder +from .utils import compress_messages, llm_usage_recorder install(extra_lines=3) @@ -139,7 +140,7 @@ class _ModelSelector: CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数 DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量 - def __init__(self, model_list: List[str], model_usage: Dict[str, Tuple[int, int, int]]): + def __init__(self, model_list: list[str], model_usage: dict[str, tuple[int, int, int]]): """ 初始化模型选择器。 @@ -153,7 +154,7 @@ class _ModelSelector: def select_best_available_model( self, failed_models_in_this_request: set, request_type: str - ) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: + ) -> tuple[ModelInfo, APIProvider, BaseClient] | None: """ 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。 @@ -306,7 +307,7 @@ class _PromptProcessor: return processed_prompt - def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]: + def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]: """ 处理响应内容,提取思维链并检查截断。 @@ -393,7 +394,7 @@ class _PromptProcessor: return " ".join(result) @staticmethod - def _extract_reasoning(content: str) -> Tuple[str, str]: + def _extract_reasoning(content: str) -> tuple[str, str]: """ 从模型返回的完整内容中提取被...标签包裹的思考过程, 并返回清理后的内容和思考过程。 @@ -462,7 +463,7 @@ class _RequestExecutor: RuntimeError: 如果达到最大重试次数。 """ retry_remain = api_provider.max_retry - compressed_messages: Optional[List[Message]] = None + compressed_messages: list[Message] | None = None while retry_remain > 0: try: @@ -487,7 +488,7 @@ class _RequestExecutor: return await client.get_audio_transcriptions(model_info=model_info, **kwargs) except Exception as e: - logger.debug(f"请求失败: {str(e)}") + logger.debug(f"请求失败: {e!s}") # 记录失败并更新模型的惩罚值 self.model_selector.update_failure_penalty(model_info.name, e) @@ -514,7 +515,7 @@ class _RequestExecutor: def _handle_exception( self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info - ) -> Tuple[int, Optional[List[Message]]]: + ) -> tuple[int, list[Message] | None]: """ 默认异常处理函数,决定是否重试。 @@ -532,12 +533,12 @@ class _RequestExecutor: logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误 - {e.message}") return -1, None else: - logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {str(e)}") + logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {e!s}") return -1, None def _handle_resp_not_ok( self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info - ) -> Tuple[int, Optional[List[Message]]]: + ) -> tuple[int, list[Message] | None]: """ 处理非200的HTTP响应异常。 @@ -583,7 +584,7 @@ class _RequestExecutor: logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}") return -1, None - def _check_retry(self, 
remain_try: int, interval: int, reason: str, model_name: str) -> Tuple[int, None]: + def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]: """ 辅助函数,根据剩余次数决定是否进行下一次重试。 @@ -620,7 +621,7 @@ class _RequestStrategy: model_selector: _ModelSelector, prompt_processor: _PromptProcessor, executor: _RequestExecutor, - model_list: List[str], + model_list: list[str], task_name: str, ): """ @@ -644,13 +645,13 @@ class _RequestStrategy: request_type: RequestType, raise_when_empty: bool = True, **kwargs, - ) -> Tuple[APIResponse, ModelInfo]: + ) -> tuple[APIResponse, ModelInfo]: """ 执行请求,动态选择最佳可用模型,并在模型失败时进行故障转移。 """ failed_models_in_this_request = set() max_attempts = len(self.model_list) - last_exception: Optional[Exception] = None + last_exception: Exception | None = None for attempt in range(max_attempts): selection_result = self.model_selector.select_best_available_model( @@ -787,9 +788,7 @@ class LLMRequest: """ self.task_name = request_type self.model_for_task = model_set - self.model_usage: Dict[str, Tuple[int, int, int]] = { - model: (0, 0, 0) for model in self.model_for_task.model_list - } + self.model_usage: dict[str, tuple[int, int, int]] = dict.fromkeys(self.model_for_task.model_list, (0, 0, 0)) """模型使用量记录,(total_tokens, penalty, usage_penalty)""" # 初始化辅助类 @@ -805,9 +804,9 @@ class LLMRequest: prompt: str, image_base64: str, image_format: str, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + temperature: float | None = None, + max_tokens: int | None = None, + ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]: """ 为图像生成响应。 @@ -855,7 +854,7 @@ class LLMRequest: return content, (reasoning, model_info.name, response.tool_calls) - async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: + async def generate_response_for_voice(self, voice_base64: str) -> str | None: """ 为语音生成响应(语音转文字)。 使用故障转移策略来确保即使主模型失败也能获得结果。 @@ -872,11 +871,11 @@ class LLMRequest: async def generate_response_async( self, prompt: str, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, + temperature: float | None = None, + max_tokens: int | None = None, + tools: list[dict[str, Any]] | None = None, raise_when_empty: bool = True, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]: """ 异步生成响应,支持并发请求。 @@ -914,11 +913,11 @@ class LLMRequest: async def _execute_single_text_request( self, prompt: str, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, + temperature: float | None = None, + max_tokens: int | None = None, + tools: list[dict[str, Any]] | None = None, raise_when_empty: bool = True, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]: """ 执行单次文本生成请求的内部方法。 这是 `generate_response_async` 的核心实现,处理单个请求的完整生命周期, @@ -956,7 +955,7 @@ class LLMRequest: return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls) - async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: + async def get_embedding(self, embedding_input: str) -> tuple[list[float], str]: """ 获取嵌入向量。 @@ -978,7 +977,7 @@ class LLMRequest: return response.embedding, model_info.name - async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], 
time_cost: float, endpoint: str): + async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str): """ 记录模型使用情况。 @@ -1009,7 +1008,7 @@ class LLMRequest: ) @staticmethod - def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: + def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None: """ 根据输入的字典列表构建并验证 `ToolOption` 对象列表。 @@ -1028,7 +1027,7 @@ class LLMRequest: if not tools: return None - tool_options: List[ToolOption] = [] + tool_options: list[ToolOption] = [] # 遍历每个工具定义 for tool in tools: try: diff --git a/src/main.py b/src/main.py index 4e91f1419..914647508 100644 --- a/src/main.py +++ b/src/main.py @@ -1,40 +1,40 @@ # 再用这个就写一行注释来混提交的我直接全部🌿飞😡 import asyncio -import time import signal import sys -from functools import partial +import time import traceback -from typing import Dict, Any +from functools import partial +from typing import Any from maim_message import MessageServer - -from src.common.remote import TelemetryHeartBeatTask -from src.manager.async_task_manager import async_task_manager -from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask -from src.chat.emoji_system.emoji_manager import get_emoji_manager -from src.chat.message_receive.chat_stream import get_chat_manager -from src.config.config import global_config -from src.chat.message_receive.bot import chat_bot -from src.common.logger import get_logger -from src.individuality.individuality import get_individuality, Individuality -from src.common.server import get_global_server, Server -from src.mood.mood_manager import mood_manager from rich.traceback import install -from src.schedule.schedule_manager import schedule_manager -from src.schedule.monthly_plan_manager import monthly_plan_manager -from src.plugin_system.core.event_manager import event_manager -from src.plugin_system.base.component_types import EventType -# from src.api.main import start_api_server -# 导入新的插件管理器 -from src.plugin_system.core.plugin_manager import plugin_manager - -# 导入消息API和traceback模块 -from src.common.message import get_global_api +from src.chat.emoji_system.emoji_manager import get_emoji_manager # 导入增强记忆系统管理器 from src.chat.memory_system.memory_manager import memory_manager +from src.chat.message_receive.bot import chat_bot +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask +from src.common.logger import get_logger + +# 导入消息API和traceback模块 +from src.common.message import get_global_api +from src.common.remote import TelemetryHeartBeatTask +from src.common.server import Server, get_global_server +from src.config.config import global_config +from src.individuality.individuality import Individuality, get_individuality +from src.manager.async_task_manager import async_task_manager +from src.mood.mood_manager import mood_manager +from src.plugin_system.base.component_types import EventType +from src.plugin_system.core.event_manager import event_manager + +# from src.api.main import start_api_server +# 导入新的插件管理器 +from src.plugin_system.core.plugin_manager import plugin_manager +from src.schedule.monthly_plan_manager import monthly_plan_manager +from src.schedule.schedule_manager import schedule_manager # 插件系统现在使用统一的插件加载器 @@ -115,8 +115,8 @@ class MainSystem: # 停止消息重组器 try: - from src.plugin_system.core.event_manager import event_manager from src.plugin_system import EventType + from src.plugin_system.core.event_manager 
import event_manager from src.utils.message_chunker import reassembler await event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM") @@ -151,7 +151,7 @@ class MainSystem: except Exception as e: logger.error(f"同步清理资源时出错: {e}") - async def _message_process_wrapper(self, message_data: Dict[str, Any]): + async def _message_process_wrapper(self, message_data: dict[str, Any]): """并行处理消息的包装器""" try: start_time = time.time() @@ -225,8 +225,8 @@ MoFox_Bot(第三方修改版) event_manager.init_default_events() # 初始化权限管理器 - from src.plugin_system.core.permission_manager import PermissionManager from src.plugin_system.apis.permission_api import permission_api + from src.plugin_system.core.permission_manager import PermissionManager permission_manager = PermissionManager() await permission_manager.initialize() diff --git a/src/mais4u/mai_think.py b/src/mais4u/mai_think.py index 4c34c4798..6725e43db 100644 --- a/src/mais4u/mai_think.py +++ b/src/mais4u/mai_think.py @@ -1,12 +1,13 @@ -from src.chat.message_receive.chat_stream import get_chat_manager import time -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config + +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.message import MessageRecvS4U -from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor -from src.mais4u.mais4u_chat.internal_manager import internal_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.common.logger import get_logger +from src.config.config import model_config +from src.llm_models.utils_model import LLMRequest +from src.mais4u.mais4u_chat.internal_manager import internal_manager +from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor logger = get_logger(__name__) diff --git a/src/mais4u/mais4u_chat/body_emotion_action_manager.py b/src/mais4u/mais4u_chat/body_emotion_action_manager.py index 38073baa4..423eeaf16 100644 --- a/src/mais4u/mais4u_chat/body_emotion_action_manager.py +++ b/src/mais4u/mais4u_chat/body_emotion_action_manager.py @@ -1,18 +1,18 @@ -import orjson import time +import orjson from json_repair import repair_json + from src.chat.message_receive.message import MessageRecv -from src.llm_models.utils_model import LLMRequest -from src.common.logger import get_logger from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive -from src.config.config import global_config, model_config from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest +from src.mais4u.s4u_config import s4u_config from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api -from src.mais4u.s4u_config import s4u_config - logger = get_logger("action") HEAD_CODE = { diff --git a/src/mais4u/mais4u_chat/context_web_manager.py b/src/mais4u/mais4u_chat/context_web_manager.py index 3bd107c55..422e6207b 100644 --- a/src/mais4u/mais4u_chat/context_web_manager.py +++ b/src/mais4u/mais4u_chat/context_web_manager.py @@ -1,10 +1,10 @@ import asyncio -import orjson from collections import deque from datetime import datetime -from typing import Dict, List, Optional -from aiohttp import web, WSMsgType + import aiohttp_cors +import orjson +from aiohttp import WSMsgType, web 
from src.chat.message_receive.message import MessageRecv from src.common.logger import get_logger @@ -57,8 +57,8 @@ class ContextWebManager: def __init__(self, max_messages: int = 10, port: int = 8765): self.max_messages = max_messages self.port = port - self.contexts: Dict[str, deque] = {} # chat_id -> deque of ContextMessage - self.websockets: List[web.WebSocketResponse] = [] + self.contexts: dict[str, deque] = {} # chat_id -> deque of ContextMessage + self.websockets: list[web.WebSocketResponse] = [] self.app = None self.runner = None self.site = None @@ -674,7 +674,7 @@ class ContextWebManager: # 全局实例 -_context_web_manager: Optional[ContextWebManager] = None +_context_web_manager: ContextWebManager | None = None def get_context_web_manager() -> ContextWebManager: diff --git a/src/mais4u/mais4u_chat/gift_manager.py b/src/mais4u/mais4u_chat/gift_manager.py index d489550c3..976476225 100644 --- a/src/mais4u/mais4u_chat/gift_manager.py +++ b/src/mais4u/mais4u_chat/gift_manager.py @@ -1,5 +1,5 @@ import asyncio -from typing import Dict, Tuple, Callable, Optional +from collections.abc import Callable from dataclasses import dataclass from src.chat.message_receive.message import MessageRecvS4U @@ -23,11 +23,11 @@ class GiftManager: def __init__(self): """初始化礼物管理器""" - self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {} + self.pending_gifts: dict[tuple[str, str], PendingGift] = {} self.debounce_timeout = 5.0 # 3秒防抖时间 async def handle_gift( - self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None + self, message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None = None ) -> bool: """处理礼物消息,返回是否应该立即处理 @@ -53,7 +53,7 @@ class GiftManager: await self._create_pending_gift(gift_key, message, callback) return False - async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None: + async def _merge_gift(self, gift_key: tuple[str, str], new_message: MessageRecvS4U) -> None: """合并礼物消息""" pending_gift = self.pending_gifts[gift_key] @@ -81,7 +81,7 @@ class GiftManager: logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}") async def _create_pending_gift( - self, gift_key: Tuple[str, str], message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] + self, gift_key: tuple[str, str], message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None ) -> None: """创建新的等待礼物""" try: @@ -100,7 +100,7 @@ class GiftManager: logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}") - async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None: + async def _gift_timeout(self, gift_key: tuple[str, str]) -> None: """礼物防抖超时处理""" try: # 等待防抖时间 diff --git a/src/mais4u/mais4u_chat/internal_manager.py b/src/mais4u/mais4u_chat/internal_manager.py index 4b3db3263..3e4a518d4 100644 --- a/src/mais4u/mais4u_chat/internal_manager.py +++ b/src/mais4u/mais4u_chat/internal_manager.py @@ -1,6 +1,6 @@ class InternalManager: def __init__(self): - self.now_internal_state = str() + self.now_internal_state = "" def set_internal_state(self, internal_state: str): self.now_internal_state = internal_state diff --git a/src/mais4u/mais4u_chat/s4u_chat.py b/src/mais4u/mais4u_chat/s4u_chat.py index 192e858b6..80bd91e22 100644 --- a/src/mais4u/mais4u_chat/s4u_chat.py +++ b/src/mais4u/mais4u_chat/s4u_chat.py @@ -1,25 +1,27 @@ import asyncio -import traceback -import time import random -from typing import Optional, Dict, Tuple, List # 导入类型提示 -from maim_message import UserInfo, Seg -from src.common.logger import 
get_logger -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from .s4u_stream_generator import S4UStreamGenerator -from src.chat.message_receive.message import MessageSending, MessageRecv, MessageRecvS4U -from src.config.config import global_config -from src.common.message.api import get_global_api -from src.chat.message_receive.storage import MessageStorage -from .s4u_watching_manager import watching_manager +import time +import traceback + import orjson -from .s4u_mood_manager import mood_manager -from src.person_info.relationship_builder_manager import relationship_builder_manager +from maim_message import Seg, UserInfo + +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.message import MessageRecv, MessageRecvS4U, MessageSending +from src.chat.message_receive.storage import MessageStorage +from src.common.logger import get_logger +from src.common.message.api import get_global_api +from src.config.config import global_config +from src.mais4u.constant_s4u import ENABLE_S4U from src.mais4u.s4u_config import s4u_config from src.person_info.person_info import PersonInfoManager +from src.person_info.relationship_builder_manager import relationship_builder_manager + +from .s4u_mood_manager import mood_manager +from .s4u_stream_generator import S4UStreamGenerator +from .s4u_watching_manager import watching_manager from .super_chat_manager import get_super_chat_manager from .yes_or_no import yes_or_no_head -from src.mais4u.constant_s4u import ENABLE_S4U logger = get_logger("S4U_chat") @@ -32,7 +34,7 @@ class MessageSenderContainer: self.original_message = original_message self.queue = asyncio.Queue() self.storage = MessageStorage() - self._task: Optional[asyncio.Task] = None + self._task: asyncio.Task | None = None self._paused_event = asyncio.Event() self._paused_event.set() # 默认设置为非暂停状态 @@ -158,7 +160,7 @@ class MessageSenderContainer: class S4UChatManager: def __init__(self): - self.s4u_chats: Dict[str, "S4UChat"] = {} + self.s4u_chats: dict[str, "S4UChat"] = {} def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat": if chat_stream.stream_id not in self.s4u_chats: @@ -196,16 +198,16 @@ class S4UChat: self._new_message_event = asyncio.Event() # 用于唤醒处理器 self._processing_task = asyncio.create_task(self._message_processor()) - self._current_generation_task: Optional[asyncio.Task] = None + self._current_generation_task: asyncio.Task | None = None # 当前消息的元数据:(队列类型, 优先级分数, 计数器, 消息对象) - self._current_message_being_replied: Optional[Tuple[str, float, int, MessageRecv]] = None + self._current_message_being_replied: tuple[str, float, int, MessageRecv] | None = None self._is_replying = False self.gpt = S4UStreamGenerator() self.gpt.chat_stream = self.chat_stream - self.interest_dict: Dict[str, float] = {} # 用户兴趣分 + self.interest_dict: dict[str, float] = {} # 用户兴趣分 - self.internal_message: List[MessageRecvS4U] = [] + self.internal_message: list[MessageRecvS4U] = [] self.msg_id = "" self.voice_done = "" diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index d235843d4..2031f7c56 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -1,16 +1,17 @@ import asyncio -import orjson import time +import orjson + from src.chat.message_receive.message import MessageRecv -from src.llm_models.utils_model import LLMRequest -from src.common.logger import get_logger from src.chat.utils.chat_message_builder import 
build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive -from src.config.config import global_config, model_config from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest +from src.mais4u.constant_s4u import ENABLE_S4U from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api -from src.mais4u.constant_s4u import ENABLE_S4U """ 情绪管理系统使用说明: diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py index ba8ee54eb..c7b855394 100644 --- a/src/mais4u/mais4u_chat/s4u_msg_processor.py +++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py @@ -1,33 +1,33 @@ import asyncio import math -from typing import Tuple + +from maim_message.message_base import GroupInfo + +from src.chat.message_receive.chat_stream import get_chat_manager # 旧的Hippocampus系统已被移除,现在使用增强记忆系统 # from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager from src.chat.message_receive.message import MessageRecv, MessageRecvS4U -from maim_message.message_base import GroupInfo from src.chat.message_receive.storage import MessageStorage -from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import is_mentioned_bot_in_message from src.common.logger import get_logger from src.config.config import global_config from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager -from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager -from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager from src.mais4u.mais4u_chat.gift_manager import gift_manager +from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager +from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager from src.mais4u.mais4u_chat.screen_manager import screen_manager from .s4u_chat import get_s4u_chat_manager - # from ..message_receive.message_buffer import message_buffer logger = get_logger("chat") -async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: +async def _calculate_interest(message: MessageRecv) -> tuple[float, bool]: """计算消息的兴趣度 Args: diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index 1c8782d23..b53a8b3f6 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -1,25 +1,27 @@ -from src.config.config import global_config -from src.common.logger import get_logger -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat -import time -from src.chat.utils.utils import get_recent_group_speaker +import asyncio # 旧的Hippocampus系统已被移除,现在使用增强记忆系统 # from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager import random +import time from datetime import datetime -import asyncio -from src.mais4u.s4u_config import s4u_config -from src.chat.message_receive.message import MessageRecvS4U -from src.person_info.relationship_fetcher import relationship_fetcher_manager -from src.person_info.person_info import PersonInfoManager, get_person_info_manager -from src.chat.message_receive.chat_stream import ChatStream -from 
src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager -from src.mais4u.mais4u_chat.screen_manager import screen_manager + from src.chat.express.expression_selector import expression_selector -from .s4u_mood_manager import mood_manager +from src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.message import MessageRecvS4U +from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.chat.utils.utils import get_recent_group_speaker +from src.common.logger import get_logger +from src.config.config import global_config from src.mais4u.mais4u_chat.internal_manager import internal_manager +from src.mais4u.mais4u_chat.screen_manager import screen_manager +from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager +from src.mais4u.s4u_config import s4u_config +from src.person_info.person_info import PersonInfoManager, get_person_info_manager +from src.person_info.relationship_fetcher import relationship_fetcher_manager + +from .s4u_mood_manager import mood_manager logger = get_logger("prompt") @@ -206,7 +208,7 @@ class PromptBuilder: limit=300, ) - talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}" + talk_type = f"{message.message_info.platform}:{message.chat_stream.user_info.user_id!s}" core_dialogue_list = [] background_dialogue_list = [] diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py index d4ec70edd..3f2ac4a80 100644 --- a/src/mais4u/mais4u_chat/s4u_stream_generator.py +++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py @@ -1,12 +1,12 @@ -from typing import AsyncGenerator -from src.mais4u.openai_client import AsyncOpenAIClient -from src.config.config import model_config -from src.chat.message_receive.message import MessageRecvS4U -from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder -from src.common.logger import get_logger import asyncio import re +from collections.abc import AsyncGenerator +from src.chat.message_receive.message import MessageRecvS4U +from src.common.logger import get_logger +from src.config.config import model_config +from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder +from src.mais4u.openai_client import AsyncOpenAIClient logger = get_logger("s4u_stream_generator") @@ -99,7 +99,7 @@ class S4UStreamGenerator: logger.info( f"{self.current_model_name}思考:{message_txt[:30] + '...' 
if len(message_txt) > 30 else message_txt}" - ) # noqa: E501 + ) current_client = self.client_1 self.current_model_name = self.model_1_name diff --git a/src/mais4u/mais4u_chat/screen_manager.py b/src/mais4u/mais4u_chat/screen_manager.py index 996e63990..60a7f914d 100644 --- a/src/mais4u/mais4u_chat/screen_manager.py +++ b/src/mais4u/mais4u_chat/screen_manager.py @@ -1,6 +1,6 @@ class ScreenManager: def __init__(self): - self.now_screen = str() + self.now_screen = "" def set_screen(self, screen_str: str): self.now_screen = screen_str diff --git a/src/mais4u/mais4u_chat/super_chat_manager.py b/src/mais4u/mais4u_chat/super_chat_manager.py index 5f0ee2ac2..df6245746 100644 --- a/src/mais4u/mais4u_chat/super_chat_manager.py +++ b/src/mais4u/mais4u_chat/super_chat_manager.py @@ -1,9 +1,9 @@ import asyncio import time from dataclasses import dataclass -from typing import Dict, List, Optional -from src.common.logger import get_logger + from src.chat.message_receive.message import MessageRecvS4U +from src.common.logger import get_logger # 全局SuperChat管理器实例 from src.mais4u.constant_s4u import ENABLE_S4U @@ -23,7 +23,7 @@ class SuperChatRecord: message_text: str timestamp: float expire_time: float - group_name: Optional[str] = None + group_name: str | None = None def is_expired(self) -> bool: """检查SuperChat是否已过期""" @@ -53,8 +53,8 @@ class SuperChatManager: """SuperChat管理器,负责管理和跟踪SuperChat消息""" def __init__(self): - self.super_chats: Dict[str, List[SuperChatRecord]] = {} # chat_id -> SuperChat列表 - self._cleanup_task: Optional[asyncio.Task] = None + self.super_chats: dict[str, list[SuperChatRecord]] = {} # chat_id -> SuperChat列表 + self._cleanup_task: asyncio.Task | None = None self._is_initialized = False logger.info("SuperChat管理器已初始化") @@ -186,7 +186,7 @@ class SuperChatManager: logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}") - def get_superchats_by_chat(self, chat_id: str) -> List[SuperChatRecord]: + def get_superchats_by_chat(self, chat_id: str) -> list[SuperChatRecord]: """获取指定聊天的所有有效SuperChat""" # 确保清理任务已启动 self._ensure_cleanup_task_started() @@ -198,7 +198,7 @@ class SuperChatManager: valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()] return valid_superchats - def get_all_valid_superchats(self) -> Dict[str, List[SuperChatRecord]]: + def get_all_valid_superchats(self) -> dict[str, list[SuperChatRecord]]: """获取所有有效的SuperChat""" # 确保清理任务已启动 self._ensure_cleanup_task_started() diff --git a/src/mais4u/mais4u_chat/yes_or_no.py b/src/mais4u/mais4u_chat/yes_or_no.py index c71c160d3..51fba0416 100644 --- a/src/mais4u/mais4u_chat/yes_or_no.py +++ b/src/mais4u/mais4u_chat/yes_or_no.py @@ -1,6 +1,6 @@ -from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger from src.config.config import model_config +from src.llm_models.utils_model import LLMRequest from src.plugin_system.apis import send_api logger = get_logger(__name__) diff --git a/src/mais4u/openai_client.py b/src/mais4u/openai_client.py index 2a5873dec..6f5e0484e 100644 --- a/src/mais4u/openai_client.py +++ b/src/mais4u/openai_client.py @@ -1,5 +1,6 @@ -from typing import AsyncGenerator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator from dataclasses import dataclass + from openai import AsyncOpenAI from openai.types.chat import ChatCompletion, ChatCompletionChunk @@ -11,14 +12,14 @@ class ChatMessage: role: str content: str - def to_dict(self) -> Dict[str, str]: + def to_dict(self) -> dict[str, str]: return 
{"role": self.role, "content": self.content} class AsyncOpenAIClient: """异步OpenAI客户端,支持流式传输""" - def __init__(self, api_key: str, base_url: Optional[str] = None): + def __init__(self, api_key: str, base_url: str | None = None): """ 初始化客户端 @@ -34,10 +35,10 @@ class AsyncOpenAIClient: async def chat_completion( self, - messages: List[Union[ChatMessage, Dict[str, str]]], + messages: list[ChatMessage | dict[str, str]], model: str = "gpt-3.5-turbo", temperature: float = 0.7, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, **kwargs, ) -> ChatCompletion: """ @@ -81,10 +82,10 @@ class AsyncOpenAIClient: async def chat_completion_stream( self, - messages: List[Union[ChatMessage, Dict[str, str]]], + messages: list[ChatMessage | dict[str, str]], model: str = "gpt-3.5-turbo", temperature: float = 0.7, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, **kwargs, ) -> AsyncGenerator[ChatCompletionChunk, None]: """ @@ -129,10 +130,10 @@ class AsyncOpenAIClient: async def get_stream_content( self, - messages: List[Union[ChatMessage, Dict[str, str]]], + messages: list[ChatMessage | dict[str, str]], model: str = "gpt-3.5-turbo", temperature: float = 0.7, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, **kwargs, ) -> AsyncGenerator[str, None]: """ @@ -156,10 +157,10 @@ class AsyncOpenAIClient: async def collect_stream_response( self, - messages: List[Union[ChatMessage, Dict[str, str]]], + messages: list[ChatMessage | dict[str, str]], model: str = "gpt-3.5-turbo", temperature: float = 0.7, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, **kwargs, ) -> str: """ @@ -199,7 +200,7 @@ class AsyncOpenAIClient: class ConversationManager: """对话管理器,用于管理对话历史""" - def __init__(self, client: AsyncOpenAIClient, system_prompt: Optional[str] = None): + def __init__(self, client: AsyncOpenAIClient, system_prompt: str | None = None): """ 初始化对话管理器 @@ -208,7 +209,7 @@ class ConversationManager: system_prompt: 系统提示词 """ self.client = client - self.messages: List[ChatMessage] = [] + self.messages: list[ChatMessage] = [] if system_prompt: self.messages.append(ChatMessage(role="system", content=system_prompt)) @@ -281,6 +282,6 @@ class ConversationManager: """获取消息数量""" return len(self.messages) - def get_conversation_history(self) -> List[Dict[str, str]]: + def get_conversation_history(self) -> list[dict[str, str]]: """获取对话历史""" return [msg.to_dict() for msg in self.messages] diff --git a/src/mais4u/s4u_config.py b/src/mais4u/s4u_config.py index 79a8f92c4..f42e871bc 100644 --- a/src/mais4u/s4u_config.py +++ b/src/mais4u/s4u_config.py @@ -1,13 +1,16 @@ import os -import tomlkit import shutil +from dataclasses import MISSING, dataclass, field, fields from datetime import datetime +from typing import Any, Literal, TypeVar, get_args, get_origin + +import tomlkit from tomlkit import TOMLDocument from tomlkit.items import Table -from dataclasses import dataclass, fields, MISSING, field -from typing import TypeVar, Type, Any, get_origin, get_args, Literal -from src.mais4u.constant_s4u import ENABLE_S4U +from typing_extensions import Self + from src.common.logger import get_logger +from src.mais4u.constant_s4u import ENABLE_S4U logger = get_logger("s4u_config") @@ -46,7 +49,7 @@ class S4UConfigBase: """S4U配置类的基类""" @classmethod - def from_dict(cls: Type[T], data: dict[str, Any]) -> T: + def from_dict(cls, data: dict[str, Any]) -> Self: """从字典加载配置字段""" data = table_to_dict(data) # 递归转dict,兼容tomlkit Table if not is_dict_like(data): @@ -81,7 +84,7 @@ class 
S4UConfigBase: return cls() @classmethod - def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any: + def _convert_field(cls, value: Any, field_type: type[Any]) -> Any: """转换字段值为指定类型""" # 如果是嵌套的 dataclass,递归调用 from_dict 方法 if isinstance(field_type, type) and issubclass(field_type, S4UConfigBase): @@ -271,9 +274,9 @@ def update_s4u_config(): return # 读取旧配置文件和模板文件 - with open(CONFIG_PATH, "r", encoding="utf-8") as f: + with open(CONFIG_PATH, encoding="utf-8") as f: old_config = tomlkit.load(f) - with open(TEMPLATE_PATH, "r", encoding="utf-8") as f: + with open(TEMPLATE_PATH, encoding="utf-8") as f: new_config = tomlkit.load(f) # 检查version是否相同 @@ -344,7 +347,7 @@ def load_s4u_config(config_path: str) -> S4UGlobalConfig: :return: S4UGlobalConfig对象 """ # 读取配置文件 - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: config_data = tomlkit.load(f) # 创建S4UGlobalConfig对象 diff --git a/src/manager/async_task_manager.py b/src/manager/async_task_manager.py index 92f6675bd..157849381 100644 --- a/src/manager/async_task_manager.py +++ b/src/manager/async_task_manager.py @@ -1,8 +1,7 @@ -from abc import abstractmethod, ABCMeta - import asyncio -from asyncio import Task, Event, Lock -from typing import Callable, Dict +from abc import ABCMeta, abstractmethod +from asyncio import Event, Lock, Task +from collections.abc import Callable from src.common.logger import get_logger @@ -46,7 +45,7 @@ class AsyncTaskManager: """异步任务管理器""" def __init__(self): - self.tasks: Dict[str, Task] = {} + self.tasks: dict[str, Task] = {} """任务列表""" self.abort_flag: Event = Event() @@ -116,7 +115,7 @@ class AsyncTaskManager: self.tasks[task.task_name] = task_inst # 将任务添加到任务列表 logger.debug(f"已启动任务 '{task.task_name}'") - def get_tasks_status(self) -> Dict[str, Dict[str, str]]: + def get_tasks_status(self) -> dict[str, dict[str, str]]: """ 获取所有任务的状态 """ diff --git a/src/manager/local_store_manager.py b/src/manager/local_store_manager.py index 63d191ef1..f5b5a28ca 100644 --- a/src/manager/local_store_manager.py +++ b/src/manager/local_store_manager.py @@ -1,6 +1,7 @@ -import orjson import os +import orjson + from src.common.logger import get_logger LOCAL_STORE_FILE_PATH = "data/local_store.json" @@ -24,7 +25,7 @@ class LocalStoreManager: """获取本地存储数据""" return self.store.get(item) - def __setitem__(self, key: str, value: str | list | dict | int | float | bool): + def __setitem__(self, key: str, value: str | list | dict | float | bool): """设置本地存储数据""" self.store[key] = value self.save_local_store() @@ -48,7 +49,7 @@ class LocalStoreManager: logger.info("正在阅读记事本......我在看,我真的在看!") logger.debug(f"加载本地存储数据: {self.file_path}") try: - with open(self.file_path, "r", encoding="utf-8") as f: + with open(self.file_path, encoding="utf-8") as f: self.store = orjson.loads(f.read()) logger.info("全都记起来了!") except orjson.JSONDecodeError: diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 66fcee96f..76f8a547e 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -2,17 +2,16 @@ import math import random import time +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.message import MessageRecv +from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import 
global_config, model_config -from src.chat.message_receive.message import MessageRecv -from src.common.data_models.database_data_model import DatabaseMessages -from src.chat.message_receive.chat_stream import get_chat_manager -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.llm_models.utils_model import LLMRequest from src.manager.async_task_manager import AsyncTask, async_task_manager - logger = get_logger("mood") diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 478d4c9fb..afde489dc 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -2,7 +2,8 @@ import copy import datetime import hashlib import time -from typing import Any, Callable, Dict, Union, Optional +from collections.abc import Callable +from typing import Any import orjson from json_repair import repair_json @@ -86,7 +87,7 @@ class PersonInfoManager: logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}") @staticmethod - def get_person_id(platform: str, user_id: Union[int, str]) -> str: + def get_person_id(platform: str, user_id: int | str) -> str: """获取唯一id(同步) 说明: 原来该方法为异步并在内部尝试执行数据库检查/迁移,导致在许多调用处未 await 时返回 coroutine 对象。 @@ -167,7 +168,7 @@ class PersonInfoManager: ) @staticmethod - async def create_person_info(person_id: str, data: Optional[dict] = None): + async def create_person_info(person_id: str, data: dict | None = None): """创建一个项""" if not person_id: logger.debug("创建失败,person_id不存在") @@ -228,7 +229,7 @@ class PersonInfoManager: await _db_create_async(final_data) @staticmethod - async def _safe_create_person_info(person_id: str, data: Optional[dict] = None): + async def _safe_create_person_info(person_id: str, data: dict | None = None): """安全地创建用户信息,处理竞态条件""" if not person_id: logger.debug("创建失败,person_id不存在") @@ -296,7 +297,7 @@ class PersonInfoManager: await _db_safe_create_async(final_data) - async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None): + async def update_one_field(self, person_id: str, field_name: str, value, data: dict | None = None): """更新某一个字段,会补全""" # 获取 SQLAlchemy 模型的所有字段名 model_fields = [column.name for column in PersonInfo.__table__.columns] @@ -628,7 +629,7 @@ class PersonInfoManager: async def get_specific_value_list( field_name: str, way: Callable[[Any], bool], - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ 获取满足条件的字段值字典 """ @@ -649,18 +650,18 @@ class PersonInfoManager: found_results[record.person_id] = value except Exception as e_query: logger.error( - f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True + f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {e_query!s}", exc_info=True ) return found_results try: return await _db_get_specific_async(field_name) except Exception as e: - logger.error(f"执行 get_specific_value_list 时出错: {str(e)}", exc_info=True) + logger.error(f"执行 get_specific_value_list 时出错: {e!s}", exc_info=True) return {} async def get_or_create_person( - self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None + self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: str | None = None ) -> str: """ 根据 platform 和 user_id 获取 person_id。 diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index 4dc478f6c..10f1d3d97 100644 --- a/src/person_info/relationship_builder.py +++ 
b/src/person_info/relationship_builder.py @@ -1,20 +1,21 @@ -import time -import traceback import os import pickle import random -from typing import List, Dict, Any -from src.config.config import global_config -from src.common.logger import get_logger -from src.person_info.relationship_manager import get_relationship_manager -from src.person_info.person_info import get_person_info_manager, PersonInfoManager +import time +import traceback +from typing import Any + from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import ( + get_raw_msg_before_timestamp_with_chat, get_raw_msg_by_timestamp_with_chat, get_raw_msg_by_timestamp_with_chat_inclusive, - get_raw_msg_before_timestamp_with_chat, num_new_messages_since, ) +from src.common.logger import get_logger +from src.config.config import global_config +from src.person_info.person_info import PersonInfoManager, get_person_info_manager +from src.person_info.relationship_manager import get_relationship_manager logger = get_logger("relationship_builder") @@ -45,7 +46,7 @@ class RelationshipBuilder: self.chat_id = chat_id # 新的消息段缓存结构: # {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]} - self.person_engaged_cache: Dict[str, List[Dict[str, Any]]] = {} + self.person_engaged_cache: dict[str, list[dict[str, Any]]] = {} # 持久化存储文件路径 self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl") @@ -401,7 +402,7 @@ class RelationshipBuilder: # 负责触发关系构建、整合消息段、更新用户印象 # ================================ - async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]): + async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: list[dict[str, Any]]): """基于消息段更新用户印象""" original_segment_count = len(segments) logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象") diff --git a/src/person_info/relationship_builder_manager.py b/src/person_info/relationship_builder_manager.py index f3bca25d2..61cad42e2 100644 --- a/src/person_info/relationship_builder_manager.py +++ b/src/person_info/relationship_builder_manager.py @@ -1,6 +1,7 @@ -from typing import Dict, Optional, List, Any +from typing import Any from src.common.logger import get_logger + from .relationship_builder import RelationshipBuilder logger = get_logger("relationship_builder_manager") @@ -13,7 +14,7 @@ class RelationshipBuilderManager: """ def __init__(self): - self.builders: Dict[str, RelationshipBuilder] = {} + self.builders: dict[str, RelationshipBuilder] = {} def get_or_create_builder(self, chat_id: str) -> RelationshipBuilder: """获取或创建关系构建器 @@ -30,7 +31,7 @@ class RelationshipBuilderManager: return self.builders[chat_id] - def get_builder(self, chat_id: str) -> Optional[RelationshipBuilder]: + def get_builder(self, chat_id: str) -> RelationshipBuilder | None: """获取关系构建器 Args: @@ -56,7 +57,7 @@ class RelationshipBuilderManager: return True return False - def get_all_chat_ids(self) -> List[str]: + def get_all_chat_ids(self) -> list[str]: """获取所有管理的聊天ID列表 Returns: @@ -64,7 +65,7 @@ class RelationshipBuilderManager: """ return list(self.builders.keys()) - def get_status(self) -> Dict[str, Any]: + def get_status(self) -> dict[str, Any]: """获取管理器状态 Returns: diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 90a353291..b0835fcb4 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py 
@@ -1,18 +1,17 @@ import time import traceback -import orjson +from typing import Any -from typing import List, Dict, Any +import orjson from json_repair import repair_json +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.message_receive.chat_stream import get_chat_manager from src.person_info.person_info import get_person_info_manager - logger = get_logger("relationship_fetcher") @@ -64,10 +63,10 @@ class RelationshipFetcher: self.chat_id = chat_id # 信息获取缓存:记录正在获取的信息请求 - self.info_fetching_cache: List[Dict[str, Any]] = [] + self.info_fetching_cache: list[dict[str, Any]] = [] # 信息结果缓存:存储已获取的信息结果,带TTL - self.info_fetched_cache: Dict[str, Dict[str, Any]] = {} + self.info_fetched_cache: dict[str, dict[str, Any]] = {} # 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}} # LLM模型配置 @@ -471,7 +470,7 @@ class RelationshipFetcherManager: """ def __init__(self): - self._fetchers: Dict[str, RelationshipFetcher] = {} + self._fetchers: dict[str, RelationshipFetcher] = {} def get_fetcher(self, chat_id: str) -> RelationshipFetcher: """获取或创建指定 chat_id 的 RelationshipFetcher @@ -499,7 +498,7 @@ class RelationshipFetcherManager: """清空所有 RelationshipFetcher""" self._fetchers.clear() - def get_active_chat_ids(self) -> List[str]: + def get_active_chat_ids(self) -> list[str]: """获取所有活跃的 chat_id 列表""" return list(self._fetchers.keys()) diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index a6ce8ab02..7792798f1 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -1,18 +1,21 @@ -from src.common.logger import get_logger -from .person_info import PersonInfoManager, get_person_info_manager -import time import random -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.chat.utils.chat_message_builder import build_readable_messages -import orjson -from json_repair import repair_json +import time from datetime import datetime from difflib import SequenceMatcher +from typing import Any + import jieba +import orjson +from json_repair import repair_json from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity -from typing import List, Dict, Any + +from src.chat.utils.chat_message_builder import build_readable_messages +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest + +from .person_info import PersonInfoManager, get_person_info_manager logger = get_logger("relation") @@ -54,7 +57,7 @@ class RelationshipManager: # person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar # ) - async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]): + async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: list[dict[str, Any]]): """更新用户印象 Args: diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index ae66a9803..9a3bb85d6 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -5,33 
+5,49 @@ MaiBot 插件系统 """ # 导出主要的公共接口 +from .apis import ( + chat_api, + component_manage_api, + config_api, + database_api, + emoji_api, + generator_api, + get_logger, + llm_api, + message_api, + person_api, + plugin_manage_api, + register_plugin, + send_api, + tool_api, +) from .base import ( - BasePlugin, + ActionActivationType, + ActionInfo, BaseAction, BaseCommand, - BaseTool, - ConfigField, - ComponentType, - ActionActivationType, - ChatMode, - ComponentInfo, - ActionInfo, - CommandInfo, - PlusCommandInfo, - PluginInfo, - ToolInfo, - PythonDependency, BaseEventHandler, + BasePlugin, + BaseTool, + ChatMode, + ChatType, + CommandArgs, + CommandInfo, + ComponentInfo, + ComponentType, + ConfigField, EventHandlerInfo, EventType, MaiMessages, - ToolParamType, + PluginInfo, # 新增的增强命令系统 PlusCommand, - CommandArgs, PlusCommandAdapter, + PlusCommandInfo, + PythonDependency, + ToolInfo, + ToolParamType, create_plus_command_adapter, - ChatType, ) # 导入工具模块 @@ -41,28 +57,10 @@ from .utils import ( # validate_plugin_manifest, # generate_plugin_manifest, ) +from .utils.dependency_config import configure_dependency_settings, get_dependency_config # 导入依赖管理模块 -from .utils.dependency_manager import get_dependency_manager, configure_dependency_manager -from .utils.dependency_config import get_dependency_config, configure_dependency_settings - -from .apis import ( - chat_api, - tool_api, - component_manage_api, - config_api, - database_api, - emoji_api, - generator_api, - llm_api, - message_api, - person_api, - plugin_manage_api, - send_api, - register_plugin, - get_logger, -) - +from .utils.dependency_manager import configure_dependency_manager, get_dependency_manager __version__ = "2.0.0" diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py index c80c5942c..cc67b9348 100644 --- a/src/plugin_system/apis/__init__.py +++ b/src/plugin_system/apis/__init__.py @@ -14,14 +14,15 @@ from src.plugin_system.apis import ( generator_api, llm_api, message_api, + permission_api, person_api, plugin_manage_api, + schedule_api, send_api, tool_api, - permission_api, - schedule_api, ) from src.plugin_system.apis.chat_api import ChatManager as context_api + from .logging_api import get_logger from .plugin_register_api import register_plugin @@ -30,18 +31,18 @@ __all__ = [ "chat_api", "component_manage_api", "config_api", + "context_api", "database_api", "emoji_api", "generator_api", + "get_logger", "llm_api", "message_api", + "permission_api", "person_api", "plugin_manage_api", - "send_api", - "get_logger", "register_plugin", - "tool_api", - "permission_api", - "context_api", "schedule_api", + "send_api", + "tool_api", ] diff --git a/src/plugin_system/apis/chat_api.py b/src/plugin_system/apis/chat_api.py index 9e995d36f..47cecd2d5 100644 --- a/src/plugin_system/apis/chat_api.py +++ b/src/plugin_system/apis/chat_api.py @@ -12,11 +12,11 @@ streams = chat.get_all_group_streams() """ -from typing import List, Dict, Any, Optional from enum import Enum +from typing import Any -from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.common.logger import get_logger logger = get_logger("chat_api") @@ -31,7 +31,7 @@ class ChatManager: """聊天管理器 - 专门负责聊天信息的查询和管理""" @staticmethod - def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: + def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: # sourcery skip: for-append-to-extend """获取所有聊天流 @@ -57,7 +57,7 @@ class 
ChatManager: return streams @staticmethod - def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: + def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: # sourcery skip: for-append-to-extend """获取所有群聊聊天流 @@ -80,7 +80,7 @@ class ChatManager: return streams @staticmethod - def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: + def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: # sourcery skip: for-append-to-extend """获取所有私聊聊天流 @@ -107,8 +107,8 @@ class ChatManager: @staticmethod def get_group_stream_by_group_id( - group_id: str, platform: Optional[str] | SpecialTypes = "qq" - ) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast + group_id: str, platform: str | None | SpecialTypes = "qq" + ) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast """根据群ID获取聊天流 Args: @@ -144,8 +144,8 @@ class ChatManager: @staticmethod def get_private_stream_by_user_id( - user_id: str, platform: Optional[str] | SpecialTypes = "qq" - ) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast + user_id: str, platform: str | None | SpecialTypes = "qq" + ) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast """根据用户ID获取私聊流 Args: @@ -203,7 +203,7 @@ class ChatManager: return "unknown" @staticmethod - def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]: + def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]: """获取聊天流详细信息 Args: @@ -222,7 +222,7 @@ class ChatManager: raise TypeError("chat_stream 必须是 ChatStream 类型") try: - info: Dict[str, Any] = { + info: dict[str, Any] = { "stream_id": chat_stream.stream_id, "platform": chat_stream.platform, "type": ChatManager.get_stream_type(chat_stream), @@ -250,7 +250,7 @@ class ChatManager: return {} @staticmethod - def get_streams_summary() -> Dict[str, int]: + def get_streams_summary() -> dict[str, int]: """获取聊天流统计摘要 Returns: @@ -285,27 +285,27 @@ class ChatManager: # ============================================================================= -def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: +def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: """获取所有聊天流的便捷函数""" return ChatManager.get_all_streams(platform) -def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: +def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: """获取群聊聊天流的便捷函数""" return ChatManager.get_group_streams(platform) -def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: +def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: """获取私聊聊天流的便捷函数""" return ChatManager.get_private_streams(platform) -def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]: +def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None: """根据群ID获取聊天流的便捷函数""" return ChatManager.get_group_stream_by_group_id(group_id, platform) -def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]: +def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None: """根据用户ID获取私聊流的便捷函数""" return ChatManager.get_private_stream_by_user_id(user_id, platform) @@ -315,11 +315,11 @@ def get_stream_type(chat_stream: ChatStream) -> str: return 
ChatManager.get_stream_type(chat_stream) -def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]: +def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]: """获取聊天流信息的便捷函数""" return ChatManager.get_stream_info(chat_stream) -def get_streams_summary() -> Dict[str, int]: +def get_streams_summary() -> dict[str, int]: """获取聊天流统计摘要的便捷函数""" return ChatManager.get_streams_summary() diff --git a/src/plugin_system/apis/component_manage_api.py b/src/plugin_system/apis/component_manage_api.py index 1ffa0833e..490237188 100644 --- a/src/plugin_system/apis/component_manage_api.py +++ b/src/plugin_system/apis/component_manage_api.py @@ -1,16 +1,15 @@ -from typing import Optional, Union, Dict from src.plugin_system.base.component_types import ( - CommandInfo, ActionInfo, + CommandInfo, + ComponentType, EventHandlerInfo, PluginInfo, - ComponentType, ToolInfo, ) # === 插件信息查询 === -def get_all_plugin_info() -> Dict[str, PluginInfo]: +def get_all_plugin_info() -> dict[str, PluginInfo]: """ 获取所有插件的信息。 @@ -22,7 +21,7 @@ def get_all_plugin_info() -> Dict[str, PluginInfo]: return component_registry.get_all_plugins() -def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]: +def get_plugin_info(plugin_name: str) -> PluginInfo | None: """ 获取指定插件的信息。 @@ -40,7 +39,7 @@ def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]: # === 组件查询方法 === def get_component_info( component_name: str, component_type: ComponentType -) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]: +) -> CommandInfo | ActionInfo | EventHandlerInfo | None: """ 获取指定组件的信息。 @@ -57,7 +56,7 @@ def get_component_info( def get_components_info_by_type( component_type: ComponentType, -) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]: +) -> dict[str, CommandInfo | ActionInfo | EventHandlerInfo]: """ 获取指定类型的所有组件信息。 @@ -74,7 +73,7 @@ def get_components_info_by_type( def get_enabled_components_info_by_type( component_type: ComponentType, -) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]: +) -> dict[str, CommandInfo | ActionInfo | EventHandlerInfo]: """ 获取指定类型的所有启用的组件信息。 @@ -90,7 +89,7 @@ def get_enabled_components_info_by_type( # === Action 查询方法 === -def get_registered_action_info(action_name: str) -> Optional[ActionInfo]: +def get_registered_action_info(action_name: str) -> ActionInfo | None: """ 获取指定 Action 的注册信息。 @@ -105,7 +104,7 @@ def get_registered_action_info(action_name: str) -> Optional[ActionInfo]: return component_registry.get_registered_action_info(action_name) -def get_registered_command_info(command_name: str) -> Optional[CommandInfo]: +def get_registered_command_info(command_name: str) -> CommandInfo | None: """ 获取指定 Command 的注册信息。 @@ -120,7 +119,7 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]: return component_registry.get_registered_command_info(command_name) -def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]: +def get_registered_tool_info(tool_name: str) -> ToolInfo | None: """ 获取指定 Tool 的注册信息。 @@ -138,7 +137,7 @@ def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]: # === EventHandler 特定查询方法 === def get_registered_event_handler_info( event_handler_name: str, -) -> Optional[EventHandlerInfo]: +) -> EventHandlerInfo | None: """ 获取指定 EventHandler 的注册信息。 diff --git a/src/plugin_system/apis/config_api.py b/src/plugin_system/apis/config_api.py index 05556414e..3ec8694b2 100644 --- a/src/plugin_system/apis/config_api.py +++ b/src/plugin_system/apis/config_api.py @@ -8,6 +8,7 @@ """ from typing import Any + from 
src.common.logger import get_logger from src.config.config import global_config diff --git a/src/plugin_system/apis/cross_context_api.py b/src/plugin_system/apis/cross_context_api.py index 76bd45bde..3e84cc26b 100644 --- a/src/plugin_system/apis/cross_context_api.py +++ b/src/plugin_system/apis/cross_context_api.py @@ -3,20 +3,20 @@ """ import time -from typing import Dict, Any, Optional, List +from typing import Any +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.utils.chat_message_builder import ( + build_readable_messages_with_id, + get_raw_msg_before_timestamp_with_chat, +) from src.common.logger import get_logger from src.config.config import global_config -from src.chat.utils.chat_message_builder import ( - get_raw_msg_before_timestamp_with_chat, - build_readable_messages_with_id, -) -from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream logger = get_logger("cross_context_api") -def get_context_groups(chat_id: str) -> Optional[List[List[str]]]: +def get_context_groups(chat_id: str) -> list[list[str]] | None: """ 获取当前聊天所在的共享组的其他聊天ID """ @@ -41,7 +41,7 @@ def get_context_groups(chat_id: str) -> Optional[List[List[str]]]: return None -async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: List[List[str]]) -> str: +async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: list[list[str]]) -> str: """ 构建跨群聊/私聊上下文 (Normal模式) """ @@ -74,8 +74,8 @@ async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: async def build_cross_context_s4u( chat_stream: ChatStream, - other_chat_infos: List[List[str]], - target_user_info: Optional[Dict[str, Any]], + other_chat_infos: list[list[str]], + target_user_info: dict[str, Any] | None, ) -> str: """ 构建跨群聊/私聊上下文 (S4U模式) diff --git a/src/plugin_system/apis/database_api.py b/src/plugin_system/apis/database_api.py index c3195bab4..aa6714655 100644 --- a/src/plugin_system/apis/database_api.py +++ b/src/plugin_system/apis/database_api.py @@ -9,7 +9,7 @@ 注意:此模块现在使用SQLAlchemy实现,提供更好的连接管理和错误处理 """ -from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get, store_action_info, MODEL_MAPPING +from src.common.database.sqlalchemy_database_api import MODEL_MAPPING, db_get, db_query, db_save, store_action_info # 保持向后兼容性 -__all__ = ["db_query", "db_save", "db_get", "store_action_info", "MODEL_MAPPING"] +__all__ = ["MODEL_MAPPING", "db_get", "db_query", "db_save", "store_action_info"] diff --git a/src/plugin_system/apis/emoji_api.py b/src/plugin_system/apis/emoji_api.py index 4fbadb98f..a62977d66 100644 --- a/src/plugin_system/apis/emoji_api.py +++ b/src/plugin_system/apis/emoji_api.py @@ -10,10 +10,9 @@ import random -from typing import Optional, Tuple, List -from src.common.logger import get_logger from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.utils.utils_image import image_path_to_base64 +from src.common.logger import get_logger logger = get_logger("emoji_api") @@ -23,7 +22,7 @@ logger = get_logger("emoji_api") # ============================================================================= -async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]: +async def get_by_description(description: str) -> tuple[str, str, str] | None: """根据描述选择表情包 Args: @@ -65,7 +64,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]] return None -async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: +async def 
get_random(count: int | None = 1) -> list[tuple[str, str, str]]: """随机获取指定数量的表情包 Args: @@ -137,7 +136,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: return [] -async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]: +async def get_by_emotion(emotion: str) -> tuple[str, str, str] | None: """根据情感标签获取表情包 Args: @@ -227,7 +226,7 @@ def get_info(): return {"current_count": 0, "max_count": 0, "available_emojis": 0} -def get_emotions() -> List[str]: +def get_emotions() -> list[str]: """获取所有可用的情感标签 Returns: @@ -247,7 +246,7 @@ def get_emotions() -> List[str]: return [] -def get_descriptions() -> List[str]: +def get_descriptions() -> list[str]: """获取所有表情包描述 Returns: diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index 2a907c60b..21bc6fdde 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -9,13 +9,15 @@ """ import traceback -from typing import Tuple, Any, Dict, List, Optional +from typing import Any + from rich.traceback import install -from src.common.logger import get_logger -from src.chat.replyer.default_generator import DefaultReplyer + from src.chat.message_receive.chat_stream import ChatStream -from src.chat.utils.utils import process_llm_response +from src.chat.replyer.default_generator import DefaultReplyer from src.chat.replyer.replyer_manager import replyer_manager +from src.chat.utils.utils import process_llm_response +from src.common.logger import get_logger from src.plugin_system.base.component_types import ActionInfo install(extra_lines=3) @@ -30,10 +32,10 @@ logger = get_logger("generator_api") def get_replyer( - chat_stream: Optional[ChatStream] = None, - chat_id: Optional[str] = None, + chat_stream: ChatStream | None = None, + chat_id: str | None = None, request_type: str = "replyer", -) -> Optional[DefaultReplyer]: +) -> DefaultReplyer | None: """获取回复器对象 优先使用chat_stream,如果没有则使用chat_id直接查找。 @@ -71,13 +73,13 @@ def get_replyer( async def generate_reply( - chat_stream: Optional[ChatStream] = None, - chat_id: Optional[str] = None, - action_data: Optional[Dict[str, Any]] = None, + chat_stream: ChatStream | None = None, + chat_id: str | None = None, + action_data: dict[str, Any] | None = None, reply_to: str = "", - reply_message: Optional[Dict[str, Any]] = None, + reply_message: dict[str, Any] | None = None, extra_info: str = "", - available_actions: Optional[Dict[str, ActionInfo]] = None, + available_actions: dict[str, ActionInfo] | None = None, enable_tool: bool = False, enable_splitter: bool = True, enable_chinese_typo: bool = True, @@ -85,7 +87,7 @@ async def generate_reply( request_type: str = "generator_api", from_plugin: bool = True, read_mark: float = 0.0, -) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: +) -> tuple[bool, list[tuple[str, Any]], str | None]: """生成回复 Args: @@ -168,9 +170,9 @@ async def generate_reply( async def rewrite_reply( - chat_stream: Optional[ChatStream] = None, - reply_data: Optional[Dict[str, Any]] = None, - chat_id: Optional[str] = None, + chat_stream: ChatStream | None = None, + reply_data: dict[str, Any] | None = None, + chat_id: str | None = None, enable_splitter: bool = True, enable_chinese_typo: bool = True, raw_reply: str = "", @@ -178,7 +180,7 @@ async def rewrite_reply( reply_to: str = "", return_prompt: bool = False, request_type: str = "generator_api", -) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: +) -> tuple[bool, list[tuple[str, Any]], str | None]: """重写回复 Args: @@ -237,7 
+239,7 @@ async def rewrite_reply( return False, [], None -def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]: +def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> list[tuple[str, Any]]: """将文本处理为更拟人化的文本 Args: @@ -266,11 +268,11 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: async def generate_response_custom( - chat_stream: Optional[ChatStream] = None, - chat_id: Optional[str] = None, + chat_stream: ChatStream | None = None, + chat_id: str | None = None, request_type: str = "generator_api", prompt: str = "", -) -> Optional[str]: +) -> str | None: """ 使用自定义提示生成回复 diff --git a/src/plugin_system/apis/llm_api.py b/src/plugin_system/apis/llm_api.py index debb67d7e..e868d40a2 100644 --- a/src/plugin_system/apis/llm_api.py +++ b/src/plugin_system/apis/llm_api.py @@ -7,12 +7,13 @@ success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config) """ -from typing import Tuple, Dict, List, Any, Optional +from typing import Any + from src.common.logger import get_logger +from src.config.api_ada_configs import TaskConfig +from src.config.config import model_config from src.llm_models.payload_content.tool_option import ToolCall from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config -from src.config.api_ada_configs import TaskConfig logger = get_logger("llm_api") @@ -21,7 +22,7 @@ logger = get_logger("llm_api") # ============================================================================= -def get_available_models() -> Dict[str, TaskConfig]: +def get_available_models() -> dict[str, TaskConfig]: """获取所有可用的模型配置 Returns: @@ -31,7 +32,7 @@ def get_available_models() -> Dict[str, TaskConfig]: # 自动获取所有属性并转换为字典形式 models = model_config.model_task_config attrs = dir(models) - rets: Dict[str, TaskConfig] = {} + rets: dict[str, TaskConfig] = {} for attr in attrs: if not attr.startswith("__"): try: @@ -52,9 +53,9 @@ async def generate_with_model( prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, -) -> Tuple[bool, str, str, str]: + temperature: float | None = None, + max_tokens: int | None = None, +) -> tuple[bool, str, str, str]: """使用指定模型生成内容 Args: @@ -78,7 +79,7 @@ async def generate_with_model( return True, response, reasoning_content, model_name except Exception as e: - error_msg = f"生成内容时出错: {str(e)}" + error_msg = f"生成内容时出错: {e!s}" logger.error(f"[LLMAPI] {error_msg}") return False, error_msg, "", "" @@ -86,11 +87,11 @@ async def generate_with_model( async def generate_with_model_with_tools( prompt: str, model_config: TaskConfig, - tool_options: List[Dict[str, Any]] | None = None, + tool_options: list[dict[str, Any]] | None = None, request_type: str = "plugin.generate", - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, -) -> Tuple[bool, str, str, str, List[ToolCall] | None]: + temperature: float | None = None, + max_tokens: int | None = None, +) -> tuple[bool, str, str, str, list[ToolCall] | None]: """使用指定模型和工具生成内容 Args: @@ -117,6 +118,6 @@ async def generate_with_model_with_tools( return True, response, reasoning_content, model_name, tool_call except Exception as e: - error_msg = f"生成内容时出错: {str(e)}" + error_msg = f"生成内容时出错: {e!s}" logger.error(f"[LLMAPI] {error_msg}") return False, error_msg, "", "", None diff --git a/src/plugin_system/apis/message_api.py 
b/src/plugin_system/apis/message_api.py index baf6418dd..4a9610ca2 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -8,26 +8,26 @@ readable_text = message_api.build_readable_messages(messages) """ -from typing import List, Dict, Any, Tuple, Optional -from src.config.config import global_config import time +from typing import Any + from src.chat.utils.chat_message_builder import ( - get_raw_msg_by_timestamp, - get_raw_msg_by_timestamp_with_chat, - get_raw_msg_by_timestamp_with_chat_inclusive, - get_raw_msg_by_timestamp_with_chat_users, - get_raw_msg_by_timestamp_random, - get_raw_msg_by_timestamp_with_users, - get_raw_msg_before_timestamp, - get_raw_msg_before_timestamp_with_chat, - get_raw_msg_before_timestamp_with_users, - num_new_messages_since, - num_new_messages_since_with_users, build_readable_messages, build_readable_messages_with_list, get_person_id_list, + get_raw_msg_before_timestamp, + get_raw_msg_before_timestamp_with_chat, + get_raw_msg_before_timestamp_with_users, + get_raw_msg_by_timestamp, + get_raw_msg_by_timestamp_random, + get_raw_msg_by_timestamp_with_chat, + get_raw_msg_by_timestamp_with_chat_inclusive, + get_raw_msg_by_timestamp_with_chat_users, + get_raw_msg_by_timestamp_with_users, + num_new_messages_since, + num_new_messages_since_with_users, ) - +from src.config.config import global_config # ============================================================================= # 消息查询API函数 @@ -36,7 +36,7 @@ from src.chat.utils.chat_message_builder import ( async def get_messages_by_time( start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 获取指定时间范围内的消息 @@ -70,7 +70,7 @@ async def get_messages_by_time_in_chat( limit_mode: str = "latest", filter_mai: bool = False, filter_command: bool = False, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 获取指定聊天中指定时间范围内的消息 @@ -111,7 +111,7 @@ async def get_messages_by_time_in_chat_inclusive( limit_mode: str = "latest", filter_mai: bool = False, filter_command: bool = False, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 获取指定聊天中指定时间范围内的消息(包含边界) @@ -152,10 +152,10 @@ async def get_messages_by_time_in_chat_for_users( chat_id: str, start_time: float, end_time: float, - person_ids: List[str], + person_ids: list[str], limit: int = 0, limit_mode: str = "latest", -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 获取指定聊天中指定用户在指定时间范围内的消息 @@ -186,7 +186,7 @@ async def get_messages_by_time_in_chat_for_users( async def get_random_chat_messages( start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 随机选择一个聊天,返回该聊天在指定时间范围内的消息 @@ -213,8 +213,8 @@ async def get_random_chat_messages( async def get_messages_by_time_for_users( - start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" -) -> List[Dict[str, Any]]: + start_time: float, end_time: float, person_ids: list[str], limit: int = 0, limit_mode: str = "latest" +) -> list[dict[str, Any]]: """ 获取指定用户在所有聊天中指定时间范围内的消息 @@ -238,7 +238,7 @@ async def get_messages_by_time_for_users( return await get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode) -async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]: +async def get_messages_before_time(timestamp: float, limit: int = 0, 
filter_mai: bool = False) -> list[dict[str, Any]]: """ 获取指定时间戳之前的消息 @@ -294,8 +294,8 @@ async def get_messages_before_time_in_chat( async def get_messages_before_time_for_users( - timestamp: float, person_ids: List[str], limit: int = 0 -) -> List[Dict[str, Any]]: + timestamp: float, person_ids: list[str], limit: int = 0 +) -> list[dict[str, Any]]: """ 获取指定用户在指定时间戳之前的消息 @@ -319,7 +319,7 @@ async def get_messages_before_time_for_users( async def get_recent_messages( chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 获取指定聊天中最近一段时间的消息 @@ -358,7 +358,7 @@ async def get_recent_messages( # ============================================================================= -async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int: +async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: float | None = None) -> int: """ 计算指定聊天中从开始时间到结束时间的新消息数量 @@ -382,7 +382,7 @@ async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Op return await num_new_messages_since(chat_id, start_time, end_time) -async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int: +async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: list[str]) -> int: """ 计算指定聊天中指定用户从开始时间到结束时间的新消息数量 @@ -413,7 +413,7 @@ async def count_new_messages_for_users(chat_id: str, start_time: float, end_time async def build_readable_messages_to_str( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, timestamp_mode: str = "relative", @@ -442,12 +442,12 @@ async def build_readable_messages_to_str( async def build_readable_messages_with_details( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, timestamp_mode: str = "relative", truncate: bool = False, -) -> Tuple[str, List[Tuple[float, str, str]]]: +) -> tuple[str, list[tuple[float, str, str]]]: """ 将消息列表构建成可读的字符串,并返回详细信息 @@ -464,7 +464,7 @@ async def build_readable_messages_with_details( return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate) -async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]: +async def get_person_ids_from_messages(messages: list[dict[str, Any]]) -> list[str]: """ 从消息列表中提取不重复的用户ID列表 @@ -482,7 +482,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s # ============================================================================= -async def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +async def filter_mai_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """ 从消息列表中移除麦麦的消息 Args: diff --git a/src/plugin_system/apis/permission_api.py b/src/plugin_system/apis/permission_api.py index 61b4ca40f..3c42f9eab 100644 --- a/src/plugin_system/apis/permission_api.py +++ b/src/plugin_system/apis/permission_api.py @@ -1,9 +1,9 @@ """纯异步权限API定义。所有外部调用方必须使用 await。""" -from typing import Optional, List, Dict, Any +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from abc import ABC, abstractmethod +from typing import Any from src.common.logger import get_logger @@ -48,18 +48,18 @@ class IPermissionManager(ABC): async def 
revoke_permission(self, user: UserInfo, permission_node: str) -> bool: ... @abstractmethod - async def get_user_permissions(self, user: UserInfo) -> List[str]: ... + async def get_user_permissions(self, user: UserInfo) -> list[str]: ... @abstractmethod - async def get_all_permission_nodes(self) -> List[PermissionNode]: ... + async def get_all_permission_nodes(self) -> list[PermissionNode]: ... @abstractmethod - async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: ... + async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]: ... class PermissionAPI: def __init__(self): - self._permission_manager: Optional[IPermissionManager] = None + self._permission_manager: IPermissionManager | None = None # 需要保留的前缀(视为绝对节点名,不再自动加 plugins.. 前缀) self.RESERVED_PREFIXES: tuple[str, ...] = "system." # 系统节点列表 (name, description, default_granted) @@ -147,11 +147,11 @@ class PermissionAPI: self._ensure_manager() return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node) - async def get_user_permissions(self, platform: str, user_id: str) -> List[str]: + async def get_user_permissions(self, platform: str, user_id: str) -> list[str]: self._ensure_manager() return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id)) - async def get_all_permission_nodes(self) -> List[Dict[str, Any]]: + async def get_all_permission_nodes(self) -> list[dict[str, Any]]: self._ensure_manager() nodes = await self._permission_manager.get_all_permission_nodes() return [ @@ -164,7 +164,7 @@ class PermissionAPI: for n in nodes ] - async def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]: + async def get_plugin_permission_nodes(self, plugin_name: str) -> list[dict[str, Any]]: self._ensure_manager() nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name) return [ diff --git a/src/plugin_system/apis/person_api.py b/src/plugin_system/apis/person_api.py index e3f7be714..5c3427dff 100644 --- a/src/plugin_system/apis/person_api.py +++ b/src/plugin_system/apis/person_api.py @@ -7,9 +7,10 @@ value = await person_api.get_person_value(person_id, "nickname") """ -from typing import Any, Optional +from typing import Any + from src.common.logger import get_logger -from src.person_info.person_info import get_person_info_manager, PersonInfoManager +from src.person_info.person_info import PersonInfoManager, get_person_info_manager logger = get_logger("person_api") @@ -63,7 +64,7 @@ async def get_person_value(person_id: str, field_name: str, default: Any = None) return default -async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict: +async def get_person_values(person_id: str, field_names: list, default_dict: dict | None = None) -> dict: """批量获取用户信息字段值 Args: diff --git a/src/plugin_system/apis/plugin_manage_api.py b/src/plugin_system/apis/plugin_manage_api.py index d428eb282..d7a802b8c 100644 --- a/src/plugin_system/apis/plugin_manage_api.py +++ b/src/plugin_system/apis/plugin_manage_api.py @@ -1,7 +1,4 @@ -from typing import Tuple, List - - -def list_loaded_plugins() -> List[str]: +def list_loaded_plugins() -> list[str]: """ 列出所有当前加载的插件。 @@ -13,7 +10,7 @@ def list_loaded_plugins() -> List[str]: return plugin_manager.list_loaded_plugins() -def list_registered_plugins() -> List[str]: +def list_registered_plugins() -> list[str]: """ 列出所有已注册的插件。 @@ -80,7 +77,7 @@ async def reload_plugin(plugin_name: str) -> bool: return await 
plugin_manager.reload_registered_plugin(plugin_name) -def load_plugin(plugin_name: str) -> Tuple[bool, int]: +def load_plugin(plugin_name: str) -> tuple[bool, int]: """ 加载指定的插件。 @@ -109,7 +106,7 @@ def add_plugin_directory(plugin_directory: str) -> bool: return plugin_manager.add_plugin_directory(plugin_directory) -def rescan_plugin_directory() -> Tuple[int, int]: +def rescan_plugin_directory() -> tuple[int, int]: """ 重新扫描插件目录,加载新插件。 Returns: diff --git a/src/plugin_system/apis/plugin_register_api.py b/src/plugin_system/apis/plugin_register_api.py index 2e14b0c84..6741c7ea9 100644 --- a/src/plugin_system/apis/plugin_register_api.py +++ b/src/plugin_system/apis/plugin_register_api.py @@ -6,8 +6,8 @@ logger = get_logger("plugin_manager") # 复用plugin_manager名称 def register_plugin(cls): - from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.base.base_plugin import BasePlugin + from src.plugin_system.core.plugin_manager import plugin_manager """插件注册装饰器 diff --git a/src/plugin_system/apis/schedule_api.py b/src/plugin_system/apis/schedule_api.py index e3e759968..61c5d13f4 100644 --- a/src/plugin_system/apis/schedule_api.py +++ b/src/plugin_system/apis/schedule_api.py @@ -30,7 +30,7 @@ """ from datetime import datetime -from typing import List, Dict, Any, Optional +from typing import Any from src.common.database.sqlalchemy_models import MonthlyPlan from src.common.logger import get_logger @@ -44,7 +44,7 @@ class ScheduleAPI: """日程表与月度计划API - 负责日程和计划信息的查询与管理""" @staticmethod - async def get_today_schedule() -> Optional[List[Dict[str, Any]]]: + async def get_today_schedule() -> list[dict[str, Any]] | None: """(异步) 获取今天的日程安排 Returns: @@ -58,7 +58,7 @@ class ScheduleAPI: return None @staticmethod - async def get_current_activity() -> Optional[str]: + async def get_current_activity() -> str | None: """(异步) 获取当前正在进行的活动 Returns: @@ -87,7 +87,7 @@ class ScheduleAPI: return False @staticmethod - async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]: + async def get_monthly_plans(target_month: str | None = None) -> list[MonthlyPlan]: """(异步) 获取指定月份的有效月度计划 Args: @@ -106,7 +106,7 @@ class ScheduleAPI: return [] @staticmethod - async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool: + async def ensure_monthly_plans(target_month: str | None = None) -> bool: """(异步) 确保指定月份存在月度计划,如果不存在则触发生成 Args: @@ -125,7 +125,7 @@ class ScheduleAPI: return False @staticmethod - async def archive_monthly_plans(target_month: Optional[str] = None) -> bool: + async def archive_monthly_plans(target_month: str | None = None) -> bool: """(异步) 归档指定月份的月度计划 Args: @@ -150,12 +150,12 @@ class ScheduleAPI: # ============================================================================= -async def get_today_schedule() -> Optional[List[Dict[str, Any]]]: +async def get_today_schedule() -> list[dict[str, Any]] | None: """(异步) 获取今天的日程安排的便捷函数""" return await ScheduleAPI.get_today_schedule() -async def get_current_activity() -> Optional[str]: +async def get_current_activity() -> str | None: """(异步) 获取当前正在进行的活动的便捷函数""" return await ScheduleAPI.get_current_activity() @@ -165,16 +165,16 @@ async def regenerate_schedule() -> bool: return await ScheduleAPI.regenerate_schedule() -async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]: +async def get_monthly_plans(target_month: str | None = None) -> list[MonthlyPlan]: """(异步) 获取指定月份的有效月度计划的便捷函数""" return await ScheduleAPI.get_monthly_plans(target_month) -async def 
ensure_monthly_plans(target_month: Optional[str] = None) -> bool: +async def ensure_monthly_plans(target_month: str | None = None) -> bool: """(异步) 确保指定月份存在月度计划的便捷函数""" return await ScheduleAPI.ensure_monthly_plans(target_month) -async def archive_monthly_plans(target_month: Optional[str] = None) -> bool: +async def archive_monthly_plans(target_month: str | None = None) -> bool: """(异步) 归档指定月份的月度计划的便捷函数""" return await ScheduleAPI.archive_monthly_plans(target_month) diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index c770db78b..d05e50355 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -28,29 +28,28 @@ """ -import traceback -import time import asyncio -from typing import Optional, Union, Dict, Any -from src.common.logger import get_logger +import time +import traceback +from typing import Any + +from maim_message import Seg, UserInfo # 导入依赖 -from src.chat.message_receive.chat_stream import get_chat_manager -from maim_message import UserInfo -from src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.message import MessageRecv, MessageSending from src.chat.message_receive.uni_message_sender import HeartFCSender -from src.chat.message_receive.message import MessageSending, MessageRecv -from maim_message import Seg +from src.common.logger import get_logger from src.config.config import global_config # 日志记录器 logger = get_logger("send_api") # 适配器命令响应等待池 -_adapter_response_pool: Dict[str, asyncio.Future] = {} +_adapter_response_pool: dict[str, asyncio.Future] = {} -def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[MessageRecv]: +def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None: """查找要回复的消息 Args: @@ -134,13 +133,13 @@ async def wait_adapter_response(request_id: str, timeout: float = 30.0) -> dict: async def _send_to_target( message_type: str, - content: Union[str, dict], + content: str | dict, stream_id: str, display_message: str = "", typing: bool = False, reply_to: str = "", set_reply: bool = False, - reply_to_message: Optional[Dict[str, Any]] = None, + reply_to_message: dict[str, Any] | None = None, storage_message: bool = True, show_log: bool = True, ) -> bool: @@ -247,7 +246,7 @@ async def text_to_stream( stream_id: str, typing: bool = False, reply_to: str = "", - reply_to_message: Optional[Dict[str, Any]] = None, + reply_to_message: dict[str, Any] | None = None, set_reply: bool = True, storage_message: bool = True, ) -> bool: @@ -313,7 +312,7 @@ async def image_to_stream( async def command_to_stream( - command: Union[str, dict], + command: str | dict, stream_id: str, storage_message: bool = True, display_message: str = "", @@ -341,7 +340,7 @@ async def custom_to_stream( display_message: str = "", typing: bool = False, reply_to: str = "", - reply_to_message: Optional[Dict[str, Any]] = None, + reply_to_message: dict[str, Any] | None = None, set_reply: bool = True, storage_message: bool = True, show_log: bool = True, @@ -377,8 +376,8 @@ async def custom_to_stream( async def adapter_command_to_stream( action: str, params: dict, - platform: Optional[str] = "qq", - stream_id: Optional[str] = None, + platform: str | None = "qq", + stream_id: str | None = None, timeout: float = 30.0, storage_message: bool = False, ) -> dict: @@ -497,4 +496,4 @@ async def adapter_command_to_stream( except Exception as e: logger.error(f"[SendAPI] 发送适配器命令时出错: {e}") 
traceback.print_exc() - return {"status": "error", "message": f"发送适配器命令时出错: {str(e)}"} + return {"status": "error", "message": f"发送适配器命令时出错: {e!s}"} diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py index c3472243a..6b949b2e5 100644 --- a/src/plugin_system/apis/tool_api.py +++ b/src/plugin_system/apis/tool_api.py @@ -1,13 +1,11 @@ -from typing import Optional, Type +from src.common.logger import get_logger from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.component_types import ComponentType -from src.common.logger import get_logger - logger = get_logger("tool_api") -def get_tool_instance(tool_name: str) -> Optional[BaseTool]: +def get_tool_instance(tool_name: str) -> BaseTool | None: """获取公开工具实例""" from src.plugin_system.core import component_registry @@ -18,7 +16,7 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]: else: plugin_config = None - tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore + tool_class: type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore return tool_class(plugin_config) if tool_class else None diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py index 83debab01..87f004ff5 100644 --- a/src/plugin_system/base/__init__.py +++ b/src/plugin_system/base/__init__.py @@ -4,31 +4,31 @@ 提供插件开发的基础类和类型定义 """ -from .base_plugin import BasePlugin from .base_action import BaseAction -from .base_tool import BaseTool from .base_command import BaseCommand from .base_events_handler import BaseEventHandler +from .base_plugin import BasePlugin +from .base_tool import BaseTool +from .command_args import CommandArgs from .component_types import ( - ComponentType, ActionActivationType, + ActionInfo, ChatMode, ChatType, - ComponentInfo, - ActionInfo, CommandInfo, - PlusCommandInfo, - ToolInfo, - PluginInfo, - PythonDependency, + ComponentInfo, + ComponentType, EventHandlerInfo, EventType, MaiMessages, + PluginInfo, + PlusCommandInfo, + PythonDependency, + ToolInfo, ToolParamType, ) from .config_types import ConfigField from .plus_command import PlusCommand, PlusCommandAdapter, create_plus_command_adapter -from .command_args import CommandArgs __all__ = [ "BasePlugin", diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index d3f012be5..37711794b 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -1,14 +1,11 @@ -import time import asyncio - +import time from abc import ABC, abstractmethod -from typing import Tuple, Optional, List, Dict -from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream -from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType, ChatType -from src.plugin_system.apis import send_api, database_api, message_api - +from src.common.logger import get_logger +from src.plugin_system.apis import database_api, message_api, send_api +from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ChatMode, ChatType, ComponentType logger = get_logger("base_action") @@ -39,7 +36,7 @@ class BaseAction(ABC): """是否为二步Action。如果为True,Action将分两步执行:第一步选择操作,第二步执行具体操作""" step_one_description: str = "" """第一步的描述,用于向LLM展示Action的基本功能""" - sub_actions: List[Tuple[str, str, Dict[str, str]]] = [] + sub_actions: list[tuple[str, str, dict[str, str]]] = [] 
"""子Action列表,格式为[(子Action名, 子Action描述, 子Action参数)]。仅在二步Action中使用""" def __init__( @@ -50,8 +47,8 @@ class BaseAction(ABC): thinking_id: str, chat_stream: ChatStream, log_prefix: str = "", - plugin_config: Optional[dict] = None, - action_message: Optional[dict] = None, + plugin_config: dict | None = None, + action_message: dict | None = None, **kwargs, ): # sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs @@ -109,8 +106,8 @@ class BaseAction(ABC): # 二步Action相关实例属性 self.is_two_step_action: bool = getattr(self.__class__, "is_two_step_action", False) self.step_one_description: str = getattr(self.__class__, "step_one_description", "") - self.sub_actions: List[Tuple[str, str, Dict[str, str]]] = getattr(self.__class__, "sub_actions", []).copy() - self._selected_sub_action: Optional[str] = None + self.sub_actions: list[tuple[str, str, dict[str, str]]] = getattr(self.__class__, "sub_actions", []).copy() + self._selected_sub_action: str | None = None """当前选择的子Action名称,用于二步Action的状态管理""" # ============================================================================= @@ -200,7 +197,7 @@ class BaseAction(ABC): """ return self._validate_chat_type() - async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]: + async def wait_for_new_message(self, timeout: int = 1200) -> tuple[bool, str]: """等待新消息或超时 在loop_start_time之后等待新消息,如果没有新消息且没有超时,就一直等待。 @@ -232,7 +229,7 @@ class BaseAction(ABC): # 检查新消息 current_time = time.time() - new_message_count = message_api.count_new_messages( + new_message_count = await message_api.count_new_messages( chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time ) @@ -258,7 +255,7 @@ class BaseAction(ABC): return False, "" except Exception as e: logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}") - return False, f"等待新消息失败: {str(e)}" + return False, f"等待新消息失败: {e!s}" async def send_text(self, content: str, reply_to: str = "", typing: bool = False) -> bool: """发送文本消息 @@ -359,7 +356,7 @@ class BaseAction(ABC): ) async def send_command( - self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True + self, command_name: str, args: dict | None = None, display_message: str = "", storage_message: bool = True ) -> bool: """发送命令消息 @@ -400,7 +397,7 @@ class BaseAction(ABC): logger.error(f"{self.log_prefix} 发送命令时出错: {e}") return False - async def call_action(self, action_name: str, action_data: Optional[dict] = None) -> Tuple[bool, str]: + async def call_action(self, action_name: str, action_data: dict | None = None) -> tuple[bool, str]: """ 在当前Action中调用另一个Action。 @@ -514,7 +511,7 @@ class BaseAction(ABC): sub_actions=getattr(cls, "sub_actions", []).copy(), ) - async def handle_step_one(self) -> Tuple[bool, str]: + async def handle_step_one(self) -> tuple[bool, str]: """处理二步Action的第一步 Returns: @@ -546,7 +543,7 @@ class BaseAction(ABC): # 调用第二步执行 return await self.execute_step_two(selected_action) - async def execute_step_two(self, sub_action_name: str) -> Tuple[bool, str]: + async def execute_step_two(self, sub_action_name: str) -> tuple[bool, str]: """执行二步Action的第二步 Args: @@ -562,7 +559,7 @@ class BaseAction(ABC): return False, f"二步Action必须实现execute_step_two方法来处理操作: {sub_action_name}" @abstractmethod - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """执行Action的抽象方法,子类必须实现 对于二步Action,会自动处理第一步逻辑 @@ -577,7 +574,7 @@ class BaseAction(ABC): # 普通Action由子类实现 pass - async def 
handle_action(self) -> Tuple[bool, str]: + async def handle_action(self) -> tuple[bool, str]: """兼容旧系统的handle_action接口,委托给execute方法 为了保持向后兼容性,旧系统的代码可能会调用handle_action方法。 diff --git a/src/plugin_system/base/base_chatter.py b/src/plugin_system/base/base_chatter.py index 1dd225252..b8a1288af 100644 --- a/src/plugin_system/base/base_chatter.py +++ b/src/plugin_system/base/base_chatter.py @@ -1,9 +1,11 @@ from abc import ABC, abstractmethod -from typing import List, TYPE_CHECKING +from typing import TYPE_CHECKING + from src.common.data_models.message_manager_data_model import StreamContext -from .component_types import ChatType from src.plugin_system.base.component_types import ChatterInfo, ComponentType +from .component_types import ChatType + if TYPE_CHECKING: from src.chat.planner_actions.action_manager import ChatterActionManager @@ -13,7 +15,7 @@ class BaseChatter(ABC): """Chatter组件的名称""" chatter_description: str = "" """Chatter组件的描述""" - chat_types: List[ChatType] = [ChatType.PRIVATE, ChatType.GROUP] + chat_types: list[ChatType] = [ChatType.PRIVATE, ChatType.GROUP] def __init__(self, stream_id: str, action_manager: "ChatterActionManager"): """ diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 212634d5d..9cb41ed04 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from typing import Dict, Tuple, Optional -from src.common.logger import get_logger -from src.plugin_system.base.component_types import CommandInfo, ComponentType, ChatType + from src.chat.message_receive.message import MessageRecv +from src.common.logger import get_logger from src.plugin_system.apis import send_api +from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType logger = get_logger("base_command") @@ -29,7 +29,7 @@ class BaseCommand(ABC): chat_type_allow: ChatType = ChatType.ALL """允许的聊天类型,默认为所有类型""" - def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): + def __init__(self, message: MessageRecv, plugin_config: dict | None = None): """初始化Command组件 Args: @@ -37,7 +37,7 @@ class BaseCommand(ABC): plugin_config: 插件配置字典 """ self.message = message - self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组 + self.matched_groups: dict[str, str] = {} # 存储正则表达式匹配的命名组 self.plugin_config = plugin_config or {} # 直接存储插件配置字典 self.log_prefix = "[Command]" @@ -55,7 +55,7 @@ class BaseCommand(ABC): f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}" ) - def set_matched_groups(self, groups: Dict[str, str]) -> None: + def set_matched_groups(self, groups: dict[str, str]) -> None: """设置正则表达式匹配的命名组 Args: @@ -93,7 +93,7 @@ class BaseCommand(ABC): return self._validate_chat_type() @abstractmethod - async def execute(self) -> Tuple[bool, Optional[str], bool]: + async def execute(self) -> tuple[bool, str | None, bool]: """执行Command的抽象方法,子类必须实现 Returns: @@ -175,7 +175,7 @@ class BaseCommand(ABC): ) async def send_command( - self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True + self, command_name: str, args: dict | None = None, display_message: str = "", storage_message: bool = True ) -> bool: """发送命令消息 diff --git a/src/plugin_system/base/base_event.py b/src/plugin_system/base/base_event.py index c7dd09a58..f8c45e54d 100644 --- a/src/plugin_system/base/base_event.py +++ b/src/plugin_system/base/base_event.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Dict, 
Any, Optional +from typing import Any from src.common.logger import get_logger @@ -25,22 +25,22 @@ class HandlerResult: class HandlerResultsCollection: """HandlerResult集合,提供便捷的查询方法""" - def __init__(self, results: List[HandlerResult]): + def __init__(self, results: list[HandlerResult]): self.results = results def all_continue_process(self) -> bool: """检查是否所有handler的continue_process都为True""" return all(result.continue_process for result in self.results) - def get_all_results(self) -> List[HandlerResult]: + def get_all_results(self) -> list[HandlerResult]: """获取所有HandlerResult""" return self.results - def get_failed_handlers(self) -> List[HandlerResult]: + def get_failed_handlers(self) -> list[HandlerResult]: """获取执行失败的handler结果""" return [result for result in self.results if not result.success] - def get_stopped_handlers(self) -> List[HandlerResult]: + def get_stopped_handlers(self) -> list[HandlerResult]: """获取continue_process为False的handler结果""" return [result for result in self.results if not result.continue_process] @@ -57,7 +57,7 @@ class HandlerResultsCollection: else: return {result.handler_name: result.message for result in self.results} - def get_handler_result(self, handler_name: str) -> Optional[HandlerResult]: + def get_handler_result(self, handler_name: str) -> HandlerResult | None: """获取指定handler的结果""" for result in self.results: if result.handler_name == handler_name: @@ -72,7 +72,7 @@ class HandlerResultsCollection: """获取执行失败的handler数量""" return sum(1 for result in self.results if not result.success) - def get_summary(self) -> Dict[str, Any]: + def get_summary(self) -> dict[str, Any]: """获取执行摘要""" return { "total_handlers": len(self.results), @@ -85,13 +85,13 @@ class HandlerResultsCollection: class BaseEvent: - def __init__(self, name: str, allowed_subscribers: List[str] = None, allowed_triggers: List[str] = None): + def __init__(self, name: str, allowed_subscribers: list[str] = None, allowed_triggers: list[str] = None): self.name = name self.enabled = True self.allowed_subscribers = allowed_subscribers # 记录事件处理器名 self.allowed_triggers = allowed_triggers # 记录插件名 - self.subscribers: List["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表 + self.subscribers: list["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表 self.event_handle_lock = asyncio.Lock() diff --git a/src/plugin_system/base/base_events_handler.py b/src/plugin_system/base/base_events_handler.py index 517de92c2..fa73dccc8 100644 --- a/src/plugin_system/base/base_events_handler.py +++ b/src/plugin_system/base/base_events_handler.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod -from typing import Tuple, Optional, List, Union from src.common.logger import get_logger -from .component_types import EventType, EventHandlerInfo, ComponentType + +from .component_types import ComponentType, EventHandlerInfo, EventType logger = get_logger("base_event_handler") @@ -21,7 +21,7 @@ class BaseEventHandler(ABC): """处理器权重,越大权重越高""" intercept_message: bool = False """是否拦截消息,默认为否""" - init_subscribe: List[Union[EventType, str]] = [EventType.UNKNOWN] + init_subscribe: list[EventType | str] = [EventType.UNKNOWN] """初始化时订阅的事件名称""" plugin_name = None @@ -44,7 +44,7 @@ class BaseEventHandler(ABC): self.plugin_config = getattr(self.__class__, "plugin_config", {}) @abstractmethod - async def execute(self, kwargs: dict | None) -> Tuple[bool, bool, Optional[str]]: + async def execute(self, kwargs: dict | None) -> tuple[bool, bool, str | None]: """执行事件处理的抽象方法,子类必须实现 Args: kwargs (dict | None): 事件消息对象,当你注册的事件为ON_START和ON_STOP时message为None diff --git 
a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index 8916fadfd..232365bce 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -1,13 +1,13 @@ from abc import abstractmethod -from typing import List, Type, Tuple, Union -from .plugin_base import PluginBase from src.common.logger import get_logger -from src.plugin_system.base.component_types import ActionInfo, CommandInfo, PlusCommandInfo, EventHandlerInfo, ToolInfo +from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, PlusCommandInfo, ToolInfo + from .base_action import BaseAction from .base_command import BaseCommand from .base_events_handler import BaseEventHandler from .base_tool import BaseTool +from .plugin_base import PluginBase from .plus_command import PlusCommand logger = get_logger("base_plugin") @@ -28,14 +28,12 @@ class BasePlugin(PluginBase): @abstractmethod def get_plugin_components( self, - ) -> List[ - Union[ - Tuple[ActionInfo, Type[BaseAction]], - Tuple[CommandInfo, Type[BaseCommand]], - Tuple[PlusCommandInfo, Type[PlusCommand]], - Tuple[EventHandlerInfo, Type[BaseEventHandler]], - Tuple[ToolInfo, Type[BaseTool]], - ] + ) -> list[ + tuple[ActionInfo, type[BaseAction]] + | tuple[CommandInfo, type[BaseCommand]] + | tuple[PlusCommandInfo, type[PlusCommand]] + | tuple[EventHandlerInfo, type[BaseEventHandler]] + | tuple[ToolInfo, type[BaseTool]] ]: """获取插件包含的组件列表 diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py index 229cadb63..5cd04b485 100644 --- a/src/plugin_system/base/base_tool.py +++ b/src/plugin_system/base/base_tool.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Any, List, Optional, Tuple +from typing import Any + from rich.traceback import install from src.common.logger import get_logger @@ -17,7 +18,7 @@ class BaseTool(ABC): """工具的名称""" description: str = "" """工具的描述""" - parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = [] + parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = [] """工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式 param_name: 参数名称 param_type: 参数类型 @@ -35,7 +36,7 @@ class BaseTool(ABC): """是否为该工具启用缓存""" cache_ttl: int = 3600 """缓存的TTL值(秒),默认为3600秒(1小时)""" - semantic_cache_query_key: Optional[str] = None + semantic_cache_query_key: str | None = None """用于语义缓存的查询参数键名。如果设置,将使用此参数的值进行语义相似度搜索""" # 二步工具调用相关属性 @@ -43,10 +44,10 @@ class BaseTool(ABC): """是否为二步工具。如果为True,工具将分两步调用:第一步展示工具信息,第二步执行具体操作""" step_one_description: str = "" """第一步的描述,用于向LLM展示工具的基本功能""" - sub_tools: List[Tuple[str, str, List[Tuple[str, ToolParamType, str, bool, List[str] | None]]]] = [] + sub_tools: list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] = [] """子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用""" - def __init__(self, plugin_config: Optional[dict] = None): + def __init__(self, plugin_config: dict | None = None): self.plugin_config = plugin_config or {} # 直接存储插件配置字典 @classmethod @@ -101,7 +102,7 @@ class BaseTool(ABC): raise ValueError(f"未找到子工具: {sub_tool_name}") @classmethod - def get_all_sub_tool_definitions(cls) -> List[dict[str, Any]]: + def get_all_sub_tool_definitions(cls) -> list[dict[str, Any]]: """获取所有子工具的定义 Returns: diff --git a/src/plugin_system/base/command_args.py b/src/plugin_system/base/command_args.py index 980eb958f..72d55dd6b 100644 --- a/src/plugin_system/base/command_args.py +++ b/src/plugin_system/base/command_args.py @@ 
-3,7 +3,6 @@ 提供简单易用的命令参数解析功能 """ -from typing import List, Optional import shlex @@ -20,7 +19,7 @@ class CommandArgs: raw_args: 原始参数字符串 """ self._raw_args = raw_args.strip() - self._parsed_args: Optional[List[str]] = None + self._parsed_args: list[str] | None = None def get_raw(self) -> str: """获取完整的参数字符串 @@ -30,7 +29,7 @@ class CommandArgs: """ return self._raw_args - def get_args(self) -> List[str]: + def get_args(self) -> list[str]: """获取解析后的参数列表 将参数按空格分割,支持引号包围的参数 diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 2b1122b9f..9ae921466 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -1,10 +1,11 @@ -from enum import Enum -from typing import Dict, Any, List, Optional, Tuple from dataclasses import dataclass, field +from enum import Enum +from typing import Any + from maim_message import Seg -from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType from src.llm_models.payload_content.tool_option import ToolCall as ToolCall +from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType # 组件类型枚举 @@ -114,7 +115,7 @@ class ComponentInfo: enabled: bool = True # 是否启用 plugin_name: str = "" # 所属插件名称 is_built_in: bool = False # 是否为内置组件 - metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据 + metadata: dict[str, Any] = field(default_factory=dict) # 额外元数据 def __post_init__(self): if self.metadata is None: @@ -125,18 +126,18 @@ class ComponentInfo: class ActionInfo(ComponentInfo): """动作组件信息""" - action_parameters: Dict[str, str] = field( + action_parameters: dict[str, str] = field( default_factory=dict ) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"} - action_require: List[str] = field(default_factory=list) # 动作需求说明 - associated_types: List[str] = field(default_factory=list) # 关联的消息类型 + action_require: list[str] = field(default_factory=list) # 动作需求说明 + associated_types: list[str] = field(default_factory=list) # 关联的消息类型 # 激活类型相关 focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS activation_type: ActionActivationType = ActionActivationType.ALWAYS random_activation_probability: float = 0.0 llm_judge_prompt: str = "" - activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表 + activation_keywords: list[str] = field(default_factory=list) # 激活关键词列表 keyword_case_sensitive: bool = False # 模式和并行设置 mode_enable: ChatMode = ChatMode.ALL @@ -145,7 +146,7 @@ class ActionInfo(ComponentInfo): # 二步Action相关属性 is_two_step_action: bool = False # 是否为二步Action step_one_description: str = "" # 第一步的描述 - sub_actions: List[Tuple[str, str, Dict[str, str]]] = field(default_factory=list) # 子Action列表 + sub_actions: list[tuple[str, str, dict[str, str]]] = field(default_factory=list) # 子Action列表 def __post_init__(self): super().__post_init__() @@ -178,7 +179,7 @@ class CommandInfo(ComponentInfo): class PlusCommandInfo(ComponentInfo): """增强命令组件信息""" - command_aliases: List[str] = field(default_factory=list) # 命令别名列表 + command_aliases: list[str] = field(default_factory=list) # 命令别名列表 priority: int = 0 # 命令优先级 chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型 intercept_message: bool = False # 是否拦截消息 @@ -194,7 +195,7 @@ class PlusCommandInfo(ComponentInfo): class ToolInfo(ComponentInfo): """工具组件信息""" - tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field( + tool_parameters: list[tuple[str, ToolParamType, str, bool, 
list[str] | None]] = field( default_factory=list ) # 工具参数定义 tool_description: str = "" # 工具描述 @@ -248,18 +249,18 @@ class PluginInfo: author: str = "" # 插件作者 enabled: bool = True # 是否启用 is_built_in: bool = False # 是否为内置插件 - components: List[ComponentInfo] = field(default_factory=list) # 包含的组件列表 - dependencies: List[str] = field(default_factory=list) # 依赖的其他插件 - python_dependencies: List[PythonDependency] = field(default_factory=list) # Python包依赖 + components: list[ComponentInfo] = field(default_factory=list) # 包含的组件列表 + dependencies: list[str] = field(default_factory=list) # 依赖的其他插件 + python_dependencies: list[PythonDependency] = field(default_factory=list) # Python包依赖 config_file: str = "" # 配置文件路径 - metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据 + metadata: dict[str, Any] = field(default_factory=dict) # 额外元数据 # 新增:manifest相关信息 - manifest_data: Dict[str, Any] = field(default_factory=dict) # manifest文件数据 + manifest_data: dict[str, Any] = field(default_factory=dict) # manifest文件数据 license: str = "" # 插件许可证 homepage_url: str = "" # 插件主页 repository_url: str = "" # 插件仓库地址 - keywords: List[str] = field(default_factory=list) # 插件关键词 - categories: List[str] = field(default_factory=list) # 插件分类 + keywords: list[str] = field(default_factory=list) # 插件关键词 + categories: list[str] = field(default_factory=list) # 插件分类 min_host_version: str = "" # 最低主机版本要求 max_host_version: str = "" # 最高主机版本要求 @@ -279,7 +280,7 @@ class PluginInfo: if self.categories is None: self.categories = [] - def get_missing_packages(self) -> List[PythonDependency]: + def get_missing_packages(self) -> list[PythonDependency]: """检查缺失的Python包""" missing = [] for dep in self.python_dependencies: @@ -290,7 +291,7 @@ class PluginInfo: missing.append(dep) return missing - def get_pip_requirements(self) -> List[str]: + def get_pip_requirements(self) -> list[str]: """获取所有pip安装格式的依赖""" return [dep.get_pip_requirement() for dep in self.python_dependencies] @@ -299,16 +300,16 @@ class PluginInfo: class MaiMessages: """MaiM插件消息""" - message_segments: List[Seg] = field(default_factory=list) + message_segments: list[Seg] = field(default_factory=list) """消息段列表,支持多段消息""" - message_base_info: Dict[str, Any] = field(default_factory=dict) + message_base_info: dict[str, Any] = field(default_factory=dict) """消息基本信息,包含平台,用户信息等数据""" plain_text: str = "" """纯文本消息内容""" - raw_message: Optional[str] = None + raw_message: str | None = None """原始消息内容""" is_group_message: bool = False @@ -317,28 +318,28 @@ class MaiMessages: is_private_message: bool = False """是否为私聊消息""" - stream_id: Optional[str] = None + stream_id: str | None = None """流ID,用于标识消息流""" - llm_prompt: Optional[str] = None + llm_prompt: str | None = None """LLM提示词""" - llm_response_content: Optional[str] = None + llm_response_content: str | None = None """LLM响应内容""" - llm_response_reasoning: Optional[str] = None + llm_response_reasoning: str | None = None """LLM响应推理内容""" - llm_response_model: Optional[str] = None + llm_response_model: str | None = None """LLM响应模型名称""" - llm_response_tool_call: Optional[List[ToolCall]] = None + llm_response_tool_call: list[ToolCall] | None = None """LLM使用的工具调用""" - action_usage: Optional[List[str]] = None + action_usage: list[str] | None = None """使用的Action""" - additional_data: Dict[Any, Any] = field(default_factory=dict) + additional_data: dict[Any, Any] = field(default_factory=dict) """附加数据,可以存储额外信息""" def __post_init__(self): diff --git a/src/plugin_system/base/config_types.py b/src/plugin_system/base/config_types.py index 752b33453..9dc9b58eb 
100644 --- a/src/plugin_system/base/config_types.py +++ b/src/plugin_system/base/config_types.py @@ -2,8 +2,8 @@ 插件系统配置类型定义 """ -from typing import Any, Optional, List from dataclasses import dataclass, field +from typing import Any @dataclass @@ -13,6 +13,6 @@ class ConfigField: type: type # 字段类型 default: Any # 默认值 description: str # 字段描述 - example: Optional[str] = None # 示例值 + example: str | None = None # 示例值 required: bool = False # 是否必需 - choices: Optional[List[Any]] = field(default_factory=list) # 可选值列表 + choices: list[Any] | None = field(default_factory=list) # 可选值列表 diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index a61b8e04c..8cc3312db 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -1,11 +1,12 @@ -from abc import ABC, abstractmethod -from typing import Dict, List, Any, Union -import os -import toml -import orjson -import shutil import datetime +import os +import shutil +from abc import ABC, abstractmethod from pathlib import Path +from typing import Any + +import orjson +import toml from src.common.logger import get_logger from src.config.config import CONFIG_DIR @@ -38,12 +39,12 @@ class PluginBase(ABC): @property @abstractmethod - def dependencies(self) -> List[str]: + def dependencies(self) -> list[str]: return [] # 依赖的其他插件 @property @abstractmethod - def python_dependencies(self) -> List[Union[str, PythonDependency]]: + def python_dependencies(self) -> list[str | PythonDependency]: return [] # Python包依赖,支持字符串列表或PythonDependency对象列表 @property @@ -53,15 +54,15 @@ class PluginBase(ABC): # manifest文件相关 manifest_file_name: str = "_manifest.json" # manifest文件名 - manifest_data: Dict[str, Any] = {} # manifest数据 + manifest_data: dict[str, Any] = {} # manifest数据 # 配置定义 @property @abstractmethod - def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]: + def config_schema(self) -> dict[str, dict[str, ConfigField] | str]: return {} - config_section_descriptions: Dict[str, str] = {} + config_section_descriptions: dict[str, str] = {} def __init__(self, plugin_dir: str): """初始化插件 @@ -69,7 +70,7 @@ class PluginBase(ABC): Args: plugin_dir: 插件目录路径,由插件管理器传递 """ - self.config: Dict[str, Any] = {} # 插件配置 + self.config: dict[str, Any] = {} # 插件配置 self.plugin_dir = plugin_dir # 插件目录路径 self.log_prefix = f"[Plugin:{self.plugin_name}]" self._is_enabled = self.enable_plugin # 从插件定义中获取默认启用状态 @@ -144,7 +145,7 @@ class PluginBase(ABC): raise FileNotFoundError(error_msg) try: - with open(manifest_path, "r", encoding="utf-8") as f: + with open(manifest_path, encoding="utf-8") as f: self.manifest_data = orjson.loads(f.read()) logger.debug(f"{self.log_prefix} 成功加载manifest文件: {manifest_path}") @@ -155,8 +156,8 @@ class PluginBase(ABC): except orjson.JSONDecodeError as e: error_msg = f"{self.log_prefix} manifest文件格式错误: {e}" logger.error(error_msg) - raise ValueError(error_msg) # noqa - except IOError as e: + raise ValueError(error_msg) + except OSError as e: error_msg = f"{self.log_prefix} 读取manifest文件失败: {e}" logger.error(error_msg) raise IOError(error_msg) # noqa @@ -266,7 +267,7 @@ class PluginBase(ABC): with open(config_file_path, "w", encoding="utf-8") as f: f.write(toml_str) logger.info(f"{self.log_prefix} 已生成默认配置文件: {config_file_path}") - except IOError as e: + except OSError as e: logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True) def _backup_config_file(self, config_file_path: str) -> str: @@ -288,13 +289,13 @@ class PluginBase(ABC): return "" def _synchronize_config( - self, 
schema_config: Dict[str, Any], user_config: Dict[str, Any] - ) -> tuple[Dict[str, Any], bool]: + self, schema_config: dict[str, Any], user_config: dict[str, Any] + ) -> tuple[dict[str, Any], bool]: """递归地将用户配置与 schema 同步,返回同步后的配置和是否发生变化的标志""" changed = False # 内部递归函数 - def _sync_dicts(schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, Any]: + def _sync_dicts(schema_dict: dict[str, Any], user_dict: dict[str, Any], parent_key: str = "") -> dict[str, Any]: nonlocal changed synced_dict = schema_dict.copy() @@ -326,7 +327,7 @@ class PluginBase(ABC): final_config = _sync_dicts(schema_config, user_config) return final_config, changed - def _generate_config_from_schema(self) -> Dict[str, Any]: + def _generate_config_from_schema(self) -> dict[str, Any]: # sourcery skip: dict-comprehension """根据schema生成配置数据结构(不写入文件)""" if not self.config_schema: @@ -348,7 +349,7 @@ class PluginBase(ABC): return config_data - def _save_config_to_file(self, config_data: Dict[str, Any], config_file_path: str): + def _save_config_to_file(self, config_data: dict[str, Any], config_file_path: str): """将配置数据保存为TOML文件(包含注释)""" if not self.config_schema: logger.debug(f"{self.log_prefix} 插件未定义config_schema,不生成配置文件") @@ -410,7 +411,7 @@ class PluginBase(ABC): with open(config_file_path, "w", encoding="utf-8") as f: f.write(toml_str) logger.info(f"{self.log_prefix} 配置文件已保存: {config_file_path}") - except IOError as e: + except OSError as e: logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True) def _load_plugin_config(self): # sourcery skip: extract-method @@ -456,7 +457,7 @@ class PluginBase(ABC): return try: - with open(user_config_path, "r", encoding="utf-8") as f: + with open(user_config_path, encoding="utf-8") as f: user_config = toml.load(f) or {} except Exception as e: logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True) @@ -520,7 +521,7 @@ class PluginBase(ABC): return current - def _normalize_python_dependencies(self, dependencies: Any) -> List[PythonDependency]: + def _normalize_python_dependencies(self, dependencies: Any) -> list[PythonDependency]: """将依赖列表标准化为PythonDependency对象""" from packaging.requirements import Requirement @@ -549,7 +550,7 @@ class PluginBase(ABC): return normalized - def _check_python_dependencies(self, dependencies: List[PythonDependency]) -> bool: + def _check_python_dependencies(self, dependencies: list[PythonDependency]) -> bool: """检查Python依赖并尝试自动安装""" if not dependencies: logger.info(f"{self.log_prefix} 无Python依赖需要检查") diff --git a/src/plugin_system/base/plus_command.py b/src/plugin_system/base/plus_command.py index a64866806..1319560b6 100644 --- a/src/plugin_system/base/plus_command.py +++ b/src/plugin_system/base/plus_command.py @@ -3,17 +3,16 @@ 提供更简单易用的命令处理方式,无需手写正则表达式 """ -from abc import ABC, abstractmethod -from typing import Tuple, Optional, List import re +from abc import ABC, abstractmethod -from src.common.logger import get_logger -from src.plugin_system.base.component_types import PlusCommandInfo, ComponentType, ChatType from src.chat.message_receive.message import MessageRecv -from src.plugin_system.apis import send_api -from src.plugin_system.base.command_args import CommandArgs -from src.plugin_system.base.base_command import BaseCommand +from src.common.logger import get_logger from src.config.config import global_config +from src.plugin_system.apis import send_api +from src.plugin_system.base.base_command import BaseCommand +from src.plugin_system.base.command_args import CommandArgs +from 
src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo logger = get_logger("plus_command") @@ -39,7 +38,7 @@ class PlusCommand(ABC): command_description: str = "" """命令描述""" - command_aliases: List[str] = [] + command_aliases: list[str] = [] """命令别名列表,如 ['say', 'repeat']""" priority: int = 0 @@ -51,7 +50,7 @@ class PlusCommand(ABC): intercept_message: bool = False """是否拦截消息,不进行后续处理""" - def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): + def __init__(self, message: MessageRecv, plugin_config: dict | None = None): """初始化命令组件 Args: @@ -172,7 +171,7 @@ class PlusCommand(ABC): return False @abstractmethod - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: """执行命令的抽象方法,子类必须实现 Args: @@ -341,7 +340,7 @@ class PlusCommandAdapter(BaseCommand): 将PlusCommand适配到现有的插件系统,继承BaseCommand """ - def __init__(self, plus_command_class, message: MessageRecv, plugin_config: Optional[dict] = None): + def __init__(self, plus_command_class, message: MessageRecv, plugin_config: dict | None = None): """初始化适配器 Args: @@ -363,7 +362,7 @@ class PlusCommandAdapter(BaseCommand): # 创建PlusCommand实例 self.plus_command = plus_command_class(message, plugin_config) - async def execute(self) -> Tuple[bool, Optional[str], bool]: + async def execute(self) -> tuple[bool, str | None, bool]: """执行命令 Returns: @@ -382,7 +381,7 @@ class PlusCommandAdapter(BaseCommand): return await self.plus_command.execute(self.plus_command.args) except Exception as e: logger.error(f"执行命令时出错: {e}", exc_info=True) - return False, f"命令执行出错: {str(e)}", self.intercept_message + return False, f"命令执行出错: {e!s}", self.intercept_message def create_plus_command_adapter(plus_command_class): @@ -401,13 +400,13 @@ def create_plus_command_adapter(plus_command_class): command_pattern = plus_command_class._generate_command_pattern() chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL) - def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): + def __init__(self, message: MessageRecv, plugin_config: dict | None = None): super().__init__(message, plugin_config) self.plus_command = plus_command_class(message, plugin_config) self.priority = getattr(plus_command_class, "priority", 0) self.intercept_message = getattr(plus_command_class, "intercept_message", False) - async def execute(self) -> Tuple[bool, Optional[str], bool]: + async def execute(self) -> tuple[bool, str | None, bool]: """执行命令""" # 从BaseCommand的正则匹配结果中提取参数 args_text = "" @@ -429,7 +428,7 @@ def create_plus_command_adapter(plus_command_class): return await self.plus_command.execute(command_args) except Exception as e: logger.error(f"执行命令时出错: {e}", exc_info=True) - return False, f"命令执行出错: {str(e)}", self.intercept_message + return False, f"命令执行出错: {e!s}", self.intercept_message return AdapterClass diff --git a/src/plugin_system/core/__init__.py b/src/plugin_system/core/__init__.py index 46aa5a96c..4e43fba11 100644 --- a/src/plugin_system/core/__init__.py +++ b/src/plugin_system/core/__init__.py @@ -4,14 +4,14 @@ 提供插件的加载、注册和管理功能 """ -from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager +from src.plugin_system.core.plugin_manager import plugin_manager __all__ = 
[ - "plugin_manager", "component_registry", "event_manager", "global_announcement_manager", + "plugin_manager", ] diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 9c82553f8..878b6c465 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -1,27 +1,26 @@ -from pathlib import Path import re - -from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type +from pathlib import Path +from re import Pattern +from typing import Any, Optional, Union from src.common.logger import get_logger -from src.plugin_system.base.component_types import ( - ComponentInfo, - ActionInfo, - ToolInfo, - CommandInfo, - PlusCommandInfo, - EventHandlerInfo, - ChatterInfo, - PluginInfo, - ComponentType, -) - -from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_action import BaseAction -from src.plugin_system.base.base_tool import BaseTool -from src.plugin_system.base.base_events_handler import BaseEventHandler -from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.base.base_chatter import BaseChatter +from src.plugin_system.base.base_command import BaseCommand +from src.plugin_system.base.base_events_handler import BaseEventHandler +from src.plugin_system.base.base_tool import BaseTool +from src.plugin_system.base.component_types import ( + ActionInfo, + ChatterInfo, + CommandInfo, + ComponentInfo, + ComponentType, + EventHandlerInfo, + PluginInfo, + PlusCommandInfo, + ToolInfo, +) +from src.plugin_system.base.plus_command import PlusCommand logger = get_logger("component_registry") @@ -34,46 +33,46 @@ class ComponentRegistry: def __init__(self): # 命名空间式组件名构成法 f"{component_type}.{component_name}" - self._components: Dict[str, "ComponentInfo"] = {} + self._components: dict[str, "ComponentInfo"] = {} """组件注册表 命名空间式组件名 -> 组件信息""" - self._components_by_type: Dict["ComponentType", Dict[str, "ComponentInfo"]] = { + self._components_by_type: dict["ComponentType", dict[str, "ComponentInfo"]] = { types: {} for types in ComponentType } """类型 -> 组件原名称 -> 组件信息""" - self._components_classes: Dict[ - str, Type[Union["BaseCommand", "BaseAction", "BaseTool", "BaseEventHandler", "PlusCommand", "BaseChatter"]] + self._components_classes: dict[ + str, type["BaseCommand" | "BaseAction" | "BaseTool" | "BaseEventHandler" | "PlusCommand" | "BaseChatter"] ] = {} """命名空间式组件名 -> 组件类""" # 插件注册表 - self._plugins: Dict[str, "PluginInfo"] = {} + self._plugins: dict[str, "PluginInfo"] = {} """插件名 -> 插件信息""" # Action特定注册表 - self._action_registry: Dict[str, Type["BaseAction"]] = {} + self._action_registry: dict[str, type["BaseAction"]] = {} """Action注册表 action名 -> action类""" - self._default_actions: Dict[str, "ActionInfo"] = {} + self._default_actions: dict[str, "ActionInfo"] = {} """默认动作集,即启用的Action集,用于重置ActionManager状态""" # Command特定注册表 - self._command_registry: Dict[str, Type["BaseCommand"]] = {} + self._command_registry: dict[str, type["BaseCommand"]] = {} """Command类注册表 command名 -> command类""" - self._command_patterns: Dict[Pattern, str] = {} + self._command_patterns: dict[Pattern, str] = {} """编译后的正则 -> command名""" # 工具特定注册表 - self._tool_registry: Dict[str, Type["BaseTool"]] = {} # 工具名 -> 工具类 - self._llm_available_tools: Dict[str, Type["BaseTool"]] = {} # llm可用的工具名 -> 工具类 + self._tool_registry: dict[str, type["BaseTool"]] = {} # 工具名 -> 工具类 + self._llm_available_tools: dict[str, type["BaseTool"]] = {} # llm可用的工具名 -> 工具类 # EventHandler特定注册表 - 
self._event_handler_registry: Dict[str, Type["BaseEventHandler"]] = {} + self._event_handler_registry: dict[str, type["BaseEventHandler"]] = {} """event_handler名 -> event_handler类""" - self._enabled_event_handlers: Dict[str, Type["BaseEventHandler"]] = {} + self._enabled_event_handlers: dict[str, type["BaseEventHandler"]] = {} """启用的事件处理器 event_handler名 -> event_handler类""" - self._chatter_registry: Dict[str, Type["BaseChatter"]] = {} + self._chatter_registry: dict[str, type["BaseChatter"]] = {} """chatter名 -> chatter类""" - self._enabled_chatter_registry: Dict[str, Type["BaseChatter"]] = {} + self._enabled_chatter_registry: dict[str, type["BaseChatter"]] = {} """启用的chatter名 -> chatter类""" logger.info("组件注册中心初始化完成") @@ -101,7 +100,7 @@ class ComponentRegistry: def register_component( self, component_info: ComponentInfo, - component_class: Type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]], + component_class: type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]], ) -> bool: """注册组件 @@ -174,7 +173,7 @@ class ComponentRegistry: ) return True - def _register_action_component(self, action_info: "ActionInfo", action_class: Type["BaseAction"]) -> bool: + def _register_action_component(self, action_info: "ActionInfo", action_class: type["BaseAction"]) -> bool: """注册Action组件到Action特定注册表""" if not (action_name := action_info.name): logger.error(f"Action组件 {action_class.__name__} 必须指定名称") @@ -194,7 +193,7 @@ class ComponentRegistry: return True - def _register_command_component(self, command_info: "CommandInfo", command_class: Type["BaseCommand"]) -> bool: + def _register_command_component(self, command_info: "CommandInfo", command_class: type["BaseCommand"]) -> bool: """注册Command组件到Command特定注册表""" if not (command_name := command_info.name): logger.error(f"Command组件 {command_class.__name__} 必须指定名称") @@ -221,7 +220,7 @@ class ComponentRegistry: return True def _register_plus_command_component( - self, plus_command_info: "PlusCommandInfo", plus_command_class: Type["PlusCommand"] + self, plus_command_info: "PlusCommandInfo", plus_command_class: type["PlusCommand"] ) -> bool: """注册PlusCommand组件到特定注册表""" plus_command_name = plus_command_info.name @@ -235,7 +234,7 @@ class ComponentRegistry: # 创建专门的PlusCommand注册表(如果还没有) if not hasattr(self, "_plus_command_registry"): - self._plus_command_registry: Dict[str, Type["PlusCommand"]] = {} + self._plus_command_registry: dict[str, type["PlusCommand"]] = {} plus_command_class.plugin_name = plus_command_info.plugin_name # 设置插件配置 @@ -245,7 +244,7 @@ class ComponentRegistry: logger.debug(f"已注册PlusCommand组件: {plus_command_name}") return True - def _register_tool_component(self, tool_info: "ToolInfo", tool_class: Type["BaseTool"]) -> bool: + def _register_tool_component(self, tool_info: "ToolInfo", tool_class: type["BaseTool"]) -> bool: """注册Tool组件到Tool特定注册表""" tool_name = tool_info.name @@ -261,7 +260,7 @@ class ComponentRegistry: return True def _register_event_handler_component( - self, handler_info: "EventHandlerInfo", handler_class: Type["BaseEventHandler"] + self, handler_info: "EventHandlerInfo", handler_class: type["BaseEventHandler"] ) -> bool: if not (handler_name := handler_info.name): logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称") @@ -287,7 +286,7 @@ class ComponentRegistry: handler_class, self.get_plugin_config(handler_info.plugin_name) or {} ) - def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: Type["BaseChatter"]) -> bool: + def 
_register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: type["BaseChatter"]) -> bool: """注册Chatter组件到Chatter特定注册表""" chatter_name = chatter_info.name @@ -532,7 +531,7 @@ class ComponentRegistry: self, component_name: str, component_type: Optional["ComponentType"] = None, - ) -> Optional[Union[Type["BaseCommand"], Type["BaseAction"], Type["BaseEventHandler"], Type["BaseTool"]]]: + ) -> type["BaseCommand"] | type["BaseAction"] | type["BaseEventHandler"] | type["BaseTool"] | None: """获取组件类,支持自动命名空间解析 Args: @@ -574,18 +573,18 @@ class ComponentRegistry: # 4. 都没找到 return None - def get_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]: + def get_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]: """获取指定类型的所有组件""" return self._components_by_type.get(component_type, {}).copy() - def get_enabled_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]: + def get_enabled_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]: """获取指定类型的所有启用组件""" components = self.get_components_by_type(component_type) return {name: info for name, info in components.items() if info.enabled} # === Action特定查询方法 === - def get_action_registry(self) -> Dict[str, Type["BaseAction"]]: + def get_action_registry(self) -> dict[str, type["BaseAction"]]: """获取Action注册表""" return self._action_registry.copy() @@ -594,13 +593,13 @@ class ComponentRegistry: info = self.get_component_info(action_name, ComponentType.ACTION) return info if isinstance(info, ActionInfo) else None - def get_default_actions(self) -> Dict[str, ActionInfo]: + def get_default_actions(self) -> dict[str, ActionInfo]: """获取默认动作集""" return self._default_actions.copy() # === Command特定查询方法 === - def get_command_registry(self) -> Dict[str, Type["BaseCommand"]]: + def get_command_registry(self) -> dict[str, type["BaseCommand"]]: """获取Command注册表""" return self._command_registry.copy() @@ -609,11 +608,11 @@ class ComponentRegistry: info = self.get_component_info(command_name, ComponentType.COMMAND) return info if isinstance(info, CommandInfo) else None - def get_command_patterns(self) -> Dict[Pattern, str]: + def get_command_patterns(self) -> dict[Pattern, str]: """获取Command模式注册表""" return self._command_patterns.copy() - def find_command_by_text(self, text: str) -> Optional[Tuple[Type["BaseCommand"], dict, "CommandInfo"]]: + def find_command_by_text(self, text: str) -> tuple[type["BaseCommand"], dict, "CommandInfo"] | None: # sourcery skip: use-named-expression, use-next """根据文本查找匹配的命令 @@ -640,11 +639,11 @@ class ComponentRegistry: return None # === Tool 特定查询方法 === - def get_tool_registry(self) -> Dict[str, Type["BaseTool"]]: + def get_tool_registry(self) -> dict[str, type["BaseTool"]]: """获取Tool注册表""" return self._tool_registry.copy() - def get_llm_available_tools(self) -> Dict[str, Type["BaseTool"]]: + def get_llm_available_tools(self) -> dict[str, type["BaseTool"]]: """获取LLM可用的Tool列表""" return self._llm_available_tools.copy() @@ -661,10 +660,10 @@ class ComponentRegistry: return info if isinstance(info, ToolInfo) else None # === PlusCommand 特定查询方法 === - def get_plus_command_registry(self) -> Dict[str, Type["PlusCommand"]]: + def get_plus_command_registry(self) -> dict[str, type["PlusCommand"]]: """获取PlusCommand注册表""" if not hasattr(self, "_plus_command_registry"): - self._plus_command_registry: Dict[str, Type[PlusCommand]] = {} + self._plus_command_registry: dict[str, type[PlusCommand]] = {} return 
self._plus_command_registry.copy() def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]: @@ -681,7 +680,7 @@ class ComponentRegistry: # === EventHandler 特定查询方法 === - def get_event_handler_registry(self) -> Dict[str, Type["BaseEventHandler"]]: + def get_event_handler_registry(self) -> dict[str, type["BaseEventHandler"]]: """获取事件处理器注册表""" return self._event_handler_registry.copy() @@ -690,21 +689,21 @@ class ComponentRegistry: info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER) return info if isinstance(info, EventHandlerInfo) else None - def get_enabled_event_handlers(self) -> Dict[str, Type["BaseEventHandler"]]: + def get_enabled_event_handlers(self) -> dict[str, type["BaseEventHandler"]]: """获取启用的事件处理器""" return self._enabled_event_handlers.copy() # === Chatter 特定查询方法 === - def get_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]: + def get_chatter_registry(self) -> dict[str, type["BaseChatter"]]: """获取Chatter注册表""" if not hasattr(self, "_chatter_registry"): - self._chatter_registry: Dict[str, Type[BaseChatter]] = {} + self._chatter_registry: dict[str, type[BaseChatter]] = {} return self._chatter_registry.copy() - def get_enabled_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]: + def get_enabled_chatter_registry(self) -> dict[str, type["BaseChatter"]]: """获取启用的Chatter注册表""" if not hasattr(self, "_enabled_chatter_registry"): - self._enabled_chatter_registry: Dict[str, Type[BaseChatter]] = {} + self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {} return self._enabled_chatter_registry.copy() def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]: @@ -718,7 +717,7 @@ class ComponentRegistry: """获取插件信息""" return self._plugins.get(plugin_name) - def get_all_plugins(self) -> Dict[str, "PluginInfo"]: + def get_all_plugins(self) -> dict[str, "PluginInfo"]: """获取所有插件""" return self._plugins.copy() @@ -726,7 +725,7 @@ class ComponentRegistry: # """获取所有启用的插件""" # return {name: info for name, info in self._plugins.items() if info.enabled} - def get_plugin_components(self, plugin_name: str) -> List["ComponentInfo"]: + def get_plugin_components(self, plugin_name: str) -> list["ComponentInfo"]: """获取插件的所有组件""" plugin_info = self.get_plugin_info(plugin_name) return plugin_info.components if plugin_info else [] @@ -753,7 +752,7 @@ class ComponentRegistry: config_path = Path("config") / "plugins" / plugin_name / "config.toml" if config_path.exists(): - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: config_data = toml.load(f) logger.debug(f"从配置文件读取插件 {plugin_name} 的配置") return config_data @@ -762,7 +761,7 @@ class ComponentRegistry: return {} - def get_registry_stats(self) -> Dict[str, Any]: + def get_registry_stats(self) -> dict[str, Any]: """获取注册中心统计信息""" action_components: int = 0 command_components: int = 0 diff --git a/src/plugin_system/core/event_manager.py b/src/plugin_system/core/event_manager.py index dac75b88f..8a7c7d66c 100644 --- a/src/plugin_system/core/event_manager.py +++ b/src/plugin_system/core/event_manager.py @@ -3,8 +3,8 @@ 提供统一的事件注册、管理和触发接口 """ -from typing import Dict, Type, List, Optional, Any, Union from threading import Lock +from typing import Any, Optional from src.common.logger import get_logger from src.plugin_system import BaseEventHandler @@ -37,17 +37,17 @@ class EventManager: if self._initialized: return - self._events: Dict[str, BaseEvent] = {} - self._event_handlers: Dict[str, Type[BaseEventHandler]] = 
{} - self._pending_subscriptions: Dict[str, List[str]] = {} # 缓存失败的订阅 + self._events: dict[str, BaseEvent] = {} + self._event_handlers: dict[str, type[BaseEventHandler]] = {} + self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅 self._initialized = True logger.info("EventManager 单例初始化完成") def register_event( self, - event_name: Union[EventType, str], - allowed_subscribers: List[str] = None, - allowed_triggers: List[str] = None, + event_name: EventType | str, + allowed_subscribers: list[str] = None, + allowed_triggers: list[str] = None, ) -> bool: """注册一个新的事件 @@ -75,7 +75,7 @@ class EventManager: return True - def get_event(self, event_name: Union[EventType, str]) -> Optional[BaseEvent]: + def get_event(self, event_name: EventType | str) -> BaseEvent | None: """获取指定事件实例 Args: @@ -86,7 +86,7 @@ class EventManager: """ return self._events.get(event_name) - def get_all_events(self) -> Dict[str, BaseEvent]: + def get_all_events(self) -> dict[str, BaseEvent]: """获取所有已注册的事件 Returns: @@ -94,7 +94,7 @@ class EventManager: """ return self._events.copy() - def get_enabled_events(self) -> Dict[str, BaseEvent]: + def get_enabled_events(self) -> dict[str, BaseEvent]: """获取所有已启用的事件 Returns: @@ -102,7 +102,7 @@ class EventManager: """ return {name: event for name, event in self._events.items() if event.enabled} - def get_disabled_events(self) -> Dict[str, BaseEvent]: + def get_disabled_events(self) -> dict[str, BaseEvent]: """获取所有已禁用的事件 Returns: @@ -110,7 +110,7 @@ class EventManager: """ return {name: event for name, event in self._events.items() if not event.enabled} - def enable_event(self, event_name: Union[EventType, str]) -> bool: + def enable_event(self, event_name: EventType | str) -> bool: """启用指定事件 Args: @@ -128,7 +128,7 @@ class EventManager: logger.info(f"事件 {event_name} 已启用") return True - def disable_event(self, event_name: Union[EventType, str]) -> bool: + def disable_event(self, event_name: EventType | str) -> bool: """禁用指定事件 Args: @@ -146,9 +146,7 @@ class EventManager: logger.info(f"事件 {event_name} 已禁用") return True - def register_event_handler( - self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None - ) -> bool: + def register_event_handler(self, handler_class: type[BaseEventHandler], plugin_config: dict | None = None) -> bool: """注册事件处理器 Args: @@ -190,7 +188,7 @@ class EventManager: logger.info(f"事件处理器 {handler_name} 注册成功") return True - def get_event_handler(self, handler_name: str) -> Optional[Type[BaseEventHandler]]: + def get_event_handler(self, handler_name: str) -> type[BaseEventHandler] | None: """获取指定事件处理器实例 Args: @@ -209,7 +207,7 @@ class EventManager: """ return self._event_handlers.copy() - def subscribe_handler_to_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool: + def subscribe_handler_to_event(self, handler_name: str, event_name: EventType | str) -> bool: """订阅事件处理器到指定事件 Args: @@ -246,7 +244,7 @@ class EventManager: logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成") return True - def unsubscribe_handler_from_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool: + def unsubscribe_handler_from_event(self, handler_name: str, event_name: EventType | str) -> bool: """从指定事件取消订阅事件处理器 Args: @@ -276,7 +274,7 @@ class EventManager: return removed - def get_event_subscribers(self, event_name: Union[EventType, str]) -> Dict[str, BaseEventHandler]: + def get_event_subscribers(self, event_name: EventType | str) -> dict[str, BaseEventHandler]: """获取订阅指定事件的所有事件处理器 Args: @@ -292,8 +290,8 
@@ class EventManager: return {handler.handler_name: handler for handler in event.subscribers} async def trigger_event( - self, event_name: Union[EventType, str], permission_group: Optional[str] = "", **kwargs - ) -> Optional[HandlerResultsCollection]: + self, event_name: EventType | str, permission_group: str | None = "", **kwargs + ) -> HandlerResultsCollection | None: """触发指定事件 Args: @@ -345,7 +343,7 @@ class EventManager: self._event_handlers.clear() logger.info("所有事件和处理器已清除") - def get_event_summary(self) -> Dict[str, Any]: + def get_event_summary(self) -> dict[str, Any]: """获取事件系统摘要 Returns: @@ -364,7 +362,7 @@ class EventManager: "pending_subscriptions": len(self._pending_subscriptions), } - def _process_pending_subscriptions(self, event_name: Union[EventType, str]) -> None: + def _process_pending_subscriptions(self, event_name: EventType | str) -> None: """处理指定事件的缓存订阅 Args: diff --git a/src/plugin_system/core/global_announcement_manager.py b/src/plugin_system/core/global_announcement_manager.py index 05abf0b79..1dca4a53a 100644 --- a/src/plugin_system/core/global_announcement_manager.py +++ b/src/plugin_system/core/global_announcement_manager.py @@ -1,5 +1,3 @@ -from typing import List, Dict - from src.common.logger import get_logger logger = get_logger("global_announcement_manager") @@ -8,13 +6,13 @@ logger = get_logger("global_announcement_manager") class GlobalAnnouncementManager: def __init__(self) -> None: # 用户禁用的动作,chat_id -> [action_name] - self._user_disabled_actions: Dict[str, List[str]] = {} + self._user_disabled_actions: dict[str, list[str]] = {} # 用户禁用的命令,chat_id -> [command_name] - self._user_disabled_commands: Dict[str, List[str]] = {} + self._user_disabled_commands: dict[str, list[str]] = {} # 用户禁用的事件处理器,chat_id -> [handler_name] - self._user_disabled_event_handlers: Dict[str, List[str]] = {} + self._user_disabled_event_handlers: dict[str, list[str]] = {} # 用户禁用的工具,chat_id -> [tool_name] - self._user_disabled_tools: Dict[str, List[str]] = {} + self._user_disabled_tools: dict[str, list[str]] = {} def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool: """禁用特定聊天的某个动作""" @@ -100,19 +98,19 @@ class GlobalAnnouncementManager: return False return False - def get_disabled_chat_actions(self, chat_id: str) -> List[str]: + def get_disabled_chat_actions(self, chat_id: str) -> list[str]: """获取特定聊天禁用的所有动作""" return self._user_disabled_actions.get(chat_id, []).copy() - def get_disabled_chat_commands(self, chat_id: str) -> List[str]: + def get_disabled_chat_commands(self, chat_id: str) -> list[str]: """获取特定聊天禁用的所有命令""" return self._user_disabled_commands.get(chat_id, []).copy() - def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]: + def get_disabled_chat_event_handlers(self, chat_id: str) -> list[str]: """获取特定聊天禁用的所有事件处理器""" return self._user_disabled_event_handlers.get(chat_id, []).copy() - def get_disabled_chat_tools(self, chat_id: str) -> List[str]: + def get_disabled_chat_tools(self, chat_id: str) -> list[str]: """获取特定聊天禁用的所有工具""" return self._user_disabled_tools.get(chat_id, []).copy() diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index 0bb22afdf..99f00340c 100644 --- a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -4,16 +4,16 @@ 这个模块提供了权限系统的核心实现,包括权限检查、权限节点管理、用户权限管理等功能。 """ -from typing import List, Set, Tuple -from sqlalchemy.ext.asyncio import async_sessionmaker -from sqlalchemy.exc import IntegrityError, SQLAlchemyError 
from datetime import datetime -from sqlalchemy import select, delete +from sqlalchemy import delete, select +from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.ext.asyncio import async_sessionmaker + +from src.common.database.sqlalchemy_models import PermissionNodes, UserPermissions, get_engine from src.common.logger import get_logger -from src.common.database.sqlalchemy_models import get_engine, PermissionNodes, UserPermissions -from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo from src.config.config import global_config +from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo logger = get_logger(__name__) @@ -24,7 +24,7 @@ class PermissionManager(IPermissionManager): def __init__(self): self.engine = None self.SessionLocal = None - self._master_users: Set[Tuple[str, str]] = set() + self._master_users: set[tuple[str, str]] = set() self._load_master_users() async def initialize(self): @@ -276,7 +276,7 @@ class PermissionManager(IPermissionManager): logger.error(f"撤销权限时发生未知错误: {e}") return False - async def get_user_permissions(self, user: UserInfo) -> List[str]: + async def get_user_permissions(self, user: UserInfo) -> list[str]: """ 获取用户拥有的所有权限节点 @@ -328,7 +328,7 @@ class PermissionManager(IPermissionManager): logger.error(f"获取用户权限时发生未知错误: {e}") return [] - async def get_all_permission_nodes(self) -> List[PermissionNode]: + async def get_all_permission_nodes(self) -> list[PermissionNode]: """ 获取所有已注册的权限节点 @@ -356,7 +356,7 @@ class PermissionManager(IPermissionManager): logger.error(f"获取所有权限节点时发生未知错误: {e}") return [] - async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: + async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]: """ 获取指定插件的所有权限节点 @@ -431,7 +431,7 @@ class PermissionManager(IPermissionManager): logger.error(f"删除插件权限时发生未知错误: {e}") return False - async def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]: + async def get_users_with_permission(self, permission_node: str) -> list[tuple[str, str]]: """ 获取拥有指定权限的所有用户 diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 2950101a9..046c05b4f 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -1,19 +1,17 @@ import asyncio +import importlib import os import traceback -import importlib - -from typing import Dict, List, Optional, Tuple, Type, Any -from importlib.util import spec_from_file_location, module_from_spec +from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path - +from typing import Any, Optional from src.common.logger import get_logger -from src.plugin_system.base.plugin_base import PluginBase from src.plugin_system.base.component_types import ComponentType +from src.plugin_system.base.plugin_base import PluginBase from src.plugin_system.utils.manifest_utils import VersionComparator -from .component_registry import component_registry +from .component_registry import component_registry logger = get_logger("plugin_manager") @@ -26,12 +24,12 @@ class PluginManager: """ def __init__(self): - self.plugin_directories: List[str] = [] # 插件根目录列表 - self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类 - self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径 + self.plugin_directories: list[str] = [] # 插件根目录列表 + self.plugin_classes: dict[str, type[PluginBase]] = {} 
# 全局插件类注册表,插件名 -> 插件类 + self.plugin_paths: dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径 - self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例 - self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息 + self.loaded_plugins: dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例 + self.failed_plugins: dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息 # 确保插件目录存在 self._ensure_plugin_directories() @@ -54,7 +52,7 @@ class PluginManager: # === 插件加载管理 === - def load_all_plugins(self) -> Tuple[int, int]: + def load_all_plugins(self) -> tuple[int, int]: """加载所有插件 Returns: @@ -87,7 +85,7 @@ class PluginManager: return total_registered, total_failed_registration - def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]: + def load_registered_plugin_classes(self, plugin_name: str) -> tuple[bool, int]: # sourcery skip: extract-duplicate-method, extract-method """ 加载已经注册的插件类 @@ -142,7 +140,7 @@ class PluginManager: except FileNotFoundError as e: # manifest文件缺失 - error_msg = f"缺少manifest文件: {str(e)}" + error_msg = f"缺少manifest文件: {e!s}" self.failed_plugins[plugin_name] = error_msg logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") return False, 1 @@ -150,14 +148,14 @@ class PluginManager: except ValueError as e: # manifest文件格式错误或验证失败 traceback.print_exc() - error_msg = f"manifest验证失败: {str(e)}" + error_msg = f"manifest验证失败: {e!s}" self.failed_plugins[plugin_name] = error_msg logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") return False, 1 except Exception as e: # 其他错误 - error_msg = f"未知错误: {str(e)}" + error_msg = f"未知错误: {e!s}" self.failed_plugins[plugin_name] = error_msg logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") logger.debug("详细错误信息: ", exc_info=True) @@ -192,7 +190,7 @@ class PluginManager: logger.debug(f"插件 {plugin_name} 重载成功") return True - def rescan_plugin_directory(self) -> Tuple[int, int]: + def rescan_plugin_directory(self) -> tuple[int, int]: """ 重新扫描插件根目录 """ @@ -220,7 +218,7 @@ class PluginManager: return self.loaded_plugins.get(plugin_name) # === 查询方法 === - def list_loaded_plugins(self) -> List[str]: + def list_loaded_plugins(self) -> list[str]: """ 列出所有当前加载的插件。 @@ -229,7 +227,7 @@ class PluginManager: """ return list(self.loaded_plugins.keys()) - def list_registered_plugins(self) -> List[str]: + def list_registered_plugins(self) -> list[str]: """ 列出所有已注册的插件类。 @@ -238,7 +236,7 @@ class PluginManager: """ return list(self.plugin_classes.keys()) - def get_plugin_path(self, plugin_name: str) -> Optional[str]: + def get_plugin_path(self, plugin_name: str) -> str | None: """ 获取指定插件的路径。 @@ -329,7 +327,7 @@ class PluginManager: # == 兼容性检查 == @staticmethod - def _check_plugin_version_compatibility(plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: + def _check_plugin_version_compatibility(plugin_name: str, manifest_data: dict[str, Any]) -> tuple[bool, str]: """检查插件版本兼容性 Args: @@ -569,7 +567,7 @@ class PluginManager: return True except Exception as e: - logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}", exc_info=True) + logger.error(f"❌ 插件卸载失败: {plugin_name} - {e!s}", exc_info=True) return False def reload_plugin(self, plugin_name: str) -> bool: @@ -606,7 +604,7 @@ class PluginManager: return False except Exception as e: - logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}", exc_info=True) + logger.error(f"❌ 插件重载失败: {plugin_name} - {e!s}", exc_info=True) return False def force_reload_plugin(self, plugin_name: str) -> bool: diff --git a/src/plugin_system/core/tool_use.py 
b/src/plugin_system/core/tool_use.py index e666e32d4..17fe46ddf 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -1,16 +1,17 @@ +import inspect import time -from typing import List, Dict, Tuple, Optional, Any +from typing import Any + +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.cache_manager import tool_cache +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.payload_content import ToolCall +from src.llm_models.utils_model import LLMRequest from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.core.global_announcement_manager import global_announcement_manager -from src.llm_models.utils_model import LLMRequest -from src.llm_models.payload_content import ToolCall -from src.config.config import global_config, model_config -from src.chat.utils.prompt import Prompt, global_prompt_manager -import inspect -from src.chat.message_receive.chat_stream import get_chat_manager -from src.common.logger import get_logger -from src.common.cache_manager import tool_cache logger = get_logger("tool_use") @@ -56,14 +57,14 @@ class ToolExecutor: self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") # 二步工具调用状态管理 - self._pending_step_two_tools: Dict[str, Dict[str, Any]] = {} + self._pending_step_two_tools: dict[str, dict[str, Any]] = {} """待处理的第二步工具调用,格式为 {tool_name: step_two_definition}""" logger.info(f"{self.log_prefix}工具执行器初始化完成") async def execute_from_chat_message( self, target_message: str, chat_history: str, sender: str, return_details: bool = False - ) -> Tuple[List[Dict[str, Any]], List[str], str]: + ) -> tuple[list[dict[str, Any]], list[str], str]: """从聊天消息执行工具 Args: @@ -113,7 +114,7 @@ class ToolExecutor: else: return tool_results, [], "" - def _get_tool_definitions(self) -> List[Dict[str, Any]]: + def _get_tool_definitions(self) -> list[dict[str, Any]]: all_tools = get_llm_available_tool_definitions() user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) @@ -129,7 +130,7 @@ class ToolExecutor: return tool_definitions - async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]: + async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]: """执行工具调用 Args: @@ -138,7 +139,7 @@ class ToolExecutor: Returns: Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表) """ - tool_results: List[Dict[str, Any]] = [] + tool_results: list[dict[str, Any]] = [] used_tools = [] if not tool_calls: @@ -192,7 +193,7 @@ class ToolExecutor: error_info = { "type": "tool_error", "id": f"tool_error_{time.time()}", - "content": f"工具{tool_name}执行失败: {str(e)}", + "content": f"工具{tool_name}执行失败: {e!s}", "tool_name": tool_name, "timestamp": time.time(), } @@ -201,8 +202,8 @@ class ToolExecutor: return tool_results, used_tools async def execute_tool_call( - self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None - ) -> Optional[Dict[str, Any]]: + self, tool_call: ToolCall, tool_instance: BaseTool | None = None + ) -> dict[str, Any] | None: """执行单个工具调用,并处理缓存""" function_args = tool_call.args or {} @@ -256,8 +257,8 @@ class ToolExecutor: return result async def _original_execute_tool_call( - self, 
tool_call: ToolCall, tool_instance: Optional[BaseTool] = None - ) -> Optional[Dict[str, Any]]: + self, tool_call: ToolCall, tool_instance: BaseTool | None = None + ) -> dict[str, Any] | None: """执行单个工具调用的原始逻辑""" try: function_name = tool_call.func_name @@ -323,10 +324,10 @@ class ToolExecutor: logger.warning(f"{self.log_prefix}工具 {function_name} 返回空结果") return None except Exception as e: - logger.error(f"执行工具调用时发生错误: {str(e)}") + logger.error(f"执行工具调用时发生错误: {e!s}") raise e - async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]: + async def execute_specific_tool_simple(self, tool_name: str, tool_args: dict) -> dict | None: """直接执行指定工具 Args: diff --git a/src/plugin_system/utils/dependency_alias.py b/src/plugin_system/utils/dependency_alias.py index 7a2aa1d80..a7e478d76 100644 --- a/src/plugin_system/utils/dependency_alias.py +++ b/src/plugin_system/utils/dependency_alias.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 本模块包含一个从Python包的“安装名”到其“导入名”的映射。 diff --git a/src/plugin_system/utils/dependency_config.py b/src/plugin_system/utils/dependency_config.py index b14f88b46..081d0216c 100644 --- a/src/plugin_system/utils/dependency_config.py +++ b/src/plugin_system/utils/dependency_config.py @@ -1,4 +1,3 @@ -from typing import Optional from src.common.logger import get_logger logger = get_logger("dependency_config") @@ -66,7 +65,7 @@ class DependencyConfig: # 全局配置实例 -_global_dependency_config: Optional[DependencyConfig] = None +_global_dependency_config: DependencyConfig | None = None def get_dependency_config() -> DependencyConfig: diff --git a/src/plugin_system/utils/dependency_manager.py b/src/plugin_system/utils/dependency_manager.py index 980f538cc..4d5e48a9d 100644 --- a/src/plugin_system/utils/dependency_manager.py +++ b/src/plugin_system/utils/dependency_manager.py @@ -1,8 +1,9 @@ -import subprocess -import sys import importlib import importlib.util -from typing import List, Tuple, Optional, Any +import subprocess +import sys +from typing import Any + from packaging import version from packaging.requirements import Requirement @@ -19,7 +20,7 @@ class DependencyManager: 负责检查和自动安装插件的Python包依赖 """ - def __init__(self, auto_install: bool = True, use_mirror: bool = False, mirror_url: Optional[str] = None): + def __init__(self, auto_install: bool = True, use_mirror: bool = False, mirror_url: str | None = None): """初始化依赖管理器 Args: @@ -46,7 +47,7 @@ class DependencyManager: self.mirror_url = mirror_url or "" self.install_timeout = 300 - def check_dependencies(self, dependencies: Any, plugin_name: str = "") -> Tuple[bool, List[str], List[str]]: + def check_dependencies(self, dependencies: Any, plugin_name: str = "") -> tuple[bool, list[str], list[str]]: """检查依赖包是否满足要求 Args: @@ -69,7 +70,7 @@ class DependencyManager: logger.info(f"{log_prefix}缺少依赖包: {dep.get_pip_requirement()}") missing_packages.append(dep.get_pip_requirement()) except Exception as e: - error_msg = f"检查依赖 {dep.package_name} 时发生错误: {str(e)}" + error_msg = f"检查依赖 {dep.package_name} 时发生错误: {e!s}" error_messages.append(error_msg) logger.error(f"{log_prefix}{error_msg}") @@ -84,7 +85,7 @@ class DependencyManager: return all_satisfied, missing_packages, error_messages - def install_dependencies(self, packages: List[str], plugin_name: str = "") -> Tuple[bool, List[str]]: + def install_dependencies(self, packages: list[str], plugin_name: str = "") -> tuple[bool, list[str]]: """自动安装缺失的依赖包 Args: @@ -115,7 +116,7 @@ class DependencyManager: logger.error(f"{log_prefix}❌ 安装失败: {package}") except 
Exception as e: failed_packages.append(package) - logger.error(f"{log_prefix}❌ 安装 {package} 时发生异常: {str(e)}") + logger.error(f"{log_prefix}❌ 安装 {package} 时发生异常: {e!s}") success = len(failed_packages) == 0 if success: @@ -125,7 +126,7 @@ class DependencyManager: return success, failed_packages - def check_and_install_dependencies(self, dependencies: Any, plugin_name: str = "") -> Tuple[bool, List[str]]: + def check_and_install_dependencies(self, dependencies: Any, plugin_name: str = "") -> tuple[bool, list[str]]: """检查并自动安装依赖(组合操作) Args: @@ -163,7 +164,7 @@ class DependencyManager: return False, all_errors @staticmethod - def _normalize_dependencies(dependencies: Any) -> List[PythonDependency]: + def _normalize_dependencies(dependencies: Any) -> list[PythonDependency]: """将依赖列表标准化为PythonDependency对象""" normalized = [] @@ -277,7 +278,7 @@ class DependencyManager: # 全局依赖管理器实例 -_global_dependency_manager: Optional[DependencyManager] = None +_global_dependency_manager: DependencyManager | None = None def get_dependency_manager() -> DependencyManager: @@ -288,7 +289,7 @@ def get_dependency_manager() -> DependencyManager: return _global_dependency_manager -def configure_dependency_manager(auto_install: bool = True, use_mirror: bool = False, mirror_url: Optional[str] = None): +def configure_dependency_manager(auto_install: bool = True, use_mirror: bool = False, mirror_url: str | None = None): """配置全局依赖管理器""" global _global_dependency_manager _global_dependency_manager = DependencyManager( diff --git a/src/plugin_system/utils/manifest_utils.py b/src/plugin_system/utils/manifest_utils.py index b714aefd7..21025127f 100644 --- a/src/plugin_system/utils/manifest_utils.py +++ b/src/plugin_system/utils/manifest_utils.py @@ -5,7 +5,8 @@ """ import re -from typing import Dict, Any, Tuple +from typing import Any + from src.common.logger import get_logger from src.config.config import MMC_VERSION @@ -70,7 +71,7 @@ class VersionComparator: return normalized @staticmethod - def parse_version(version: str) -> Tuple[int, int, int]: + def parse_version(version: str) -> tuple[int, int, int]: """解析版本号为元组 Args: @@ -109,7 +110,7 @@ class VersionComparator: return 0 @staticmethod - def check_forward_compatibility(current_version: str, max_version: str) -> Tuple[bool, str]: + def check_forward_compatibility(current_version: str, max_version: str) -> tuple[bool, str]: """检查向前兼容性(仅使用兼容性映射表) Args: @@ -131,7 +132,7 @@ class VersionComparator: return False, "" @staticmethod - def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]: + def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> tuple[bool, str]: """检查版本是否在指定范围内,支持兼容性检查 Args: @@ -195,7 +196,7 @@ class VersionComparator: logger.info(f"添加兼容性映射:{base_normalized} -> {compatible_versions}") @staticmethod - def get_compatibility_info() -> Dict[str, list]: + def get_compatibility_info() -> dict[str, list]: """获取当前的兼容性映射表 Returns: @@ -232,7 +233,7 @@ class ManifestValidator: self.validation_errors = [] self.validation_warnings = [] - def validate_manifest(self, manifest_data: Dict[str, Any]) -> bool: + def validate_manifest(self, manifest_data: dict[str, Any]) -> bool: """验证manifest数据 Args: @@ -266,7 +267,7 @@ class ManifestValidator: if "name" not in author or not author["name"]: self.validation_errors.append("作者信息缺少name字段或为空") # url字段是可选的 - if "url" in author and author["url"]: + if author.get("url"): url = author["url"] if not (url.startswith("http://") or url.startswith("https://")): 
self.validation_warnings.append("作者URL建议使用完整的URL格式") @@ -305,7 +306,7 @@ class ManifestValidator: # 检查URL格式(可选字段) for url_field in ["homepage_url", "repository_url"]: - if url_field in manifest_data and manifest_data[url_field]: + if manifest_data.get(url_field): url: str = manifest_data[url_field] if not (url.startswith("http://") or url.startswith("https://")): self.validation_warnings.append(f"{url_field}建议使用完整的URL格式") diff --git a/src/plugin_system/utils/permission_decorators.py b/src/plugin_system/utils/permission_decorators.py index 278ab2068..7629e608c 100644 --- a/src/plugin_system/utils/permission_decorators.py +++ b/src/plugin_system/utils/permission_decorators.py @@ -4,19 +4,19 @@ 提供方便的权限检查装饰器,用于插件命令和其他需要权限验证的地方。 """ +from collections.abc import Callable from functools import wraps -from typing import Callable, Optional from inspect import iscoroutinefunction +from src.chat.message_receive.chat_stream import ChatStream +from src.plugin_system.apis.logging_api import get_logger from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.apis.send_api import text_to_stream -from src.plugin_system.apis.logging_api import get_logger -from src.chat.message_receive.chat_stream import ChatStream logger = get_logger(__name__) -def require_permission(permission_node: str, deny_message: Optional[str] = None): +def require_permission(permission_node: str, deny_message: str | None = None): """ 权限检查装饰器 @@ -90,7 +90,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None) return decorator -def require_master(deny_message: Optional[str] = None): +def require_master(deny_message: str | None = None): """ Master权限检查装饰器 @@ -186,9 +186,7 @@ class PermissionChecker: return permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id) @staticmethod - async def ensure_permission( - chat_stream: ChatStream, permission_node: str, deny_message: Optional[str] = None - ) -> bool: + async def ensure_permission(chat_stream: ChatStream, permission_node: str, deny_message: str | None = None) -> bool: """ 确保用户拥有指定权限,如果没有权限会发送消息并返回False @@ -209,7 +207,7 @@ class PermissionChecker: return has_permission @staticmethod - async def ensure_master(chat_stream: ChatStream, deny_message: Optional[str] = None) -> bool: + async def ensure_master(chat_stream: ChatStream, deny_message: str | None = None) -> bool: """ 确保用户为Master用户,如果不是会发送消息并返回False diff --git a/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py b/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py index 3a652e2f4..25cdb1fa0 100644 --- a/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py +++ b/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py @@ -7,15 +7,15 @@ import asyncio import time import traceback from datetime import datetime -from typing import Dict, Any +from typing import Any +from src.chat.express.expression_learner import expression_learner_manager +from src.chat.planner_actions.action_manager import ChatterActionManager +from src.common.data_models.message_manager_data_model import StreamContext +from src.common.logger import get_logger from src.plugin_system.base.base_chatter import BaseChatter from src.plugin_system.base.component_types import ChatType -from src.common.data_models.message_manager_data_model import StreamContext from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner -from src.chat.planner_actions.action_manager import ChatterActionManager -from src.common.logger import get_logger 
-from src.chat.express.expression_learner import expression_learner_manager logger = get_logger("affinity_chatter") @@ -113,7 +113,7 @@ class AffinityChatter(BaseChatter): "executed_count": 0, } - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: """ 获取处理器统计信息 @@ -122,7 +122,7 @@ class AffinityChatter(BaseChatter): """ return self.stats.copy() - def get_planner_stats(self) -> Dict[str, Any]: + def get_planner_stats(self) -> dict[str, Any]: """ 获取规划器统计信息 @@ -131,7 +131,7 @@ class AffinityChatter(BaseChatter): """ return self.planner.get_planner_stats() - def get_interest_scoring_stats(self) -> Dict[str, Any]: + def get_interest_scoring_stats(self) -> dict[str, Any]: """ 获取兴趣度评分统计信息 @@ -140,7 +140,7 @@ class AffinityChatter(BaseChatter): """ return self.planner.get_interest_scoring_stats() - def get_relationship_stats(self) -> Dict[str, Any]: + def get_relationship_stats(self) -> dict[str, Any]: """ 获取用户关系统计信息 @@ -158,7 +158,7 @@ class AffinityChatter(BaseChatter): """ return self.planner.get_current_mood_state() - def get_mood_stats(self) -> Dict[str, Any]: + def get_mood_stats(self) -> dict[str, Any]: """ 获取情绪状态统计信息 diff --git a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py index 1bb60146b..6892b0916 100644 --- a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py +++ b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py @@ -5,11 +5,11 @@ """ import traceback -from typing import Dict, List, Any +from typing import Any +from src.chat.interest_system import bot_interest_manager from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.info_data_model import InterestScore -from src.chat.interest_system import bot_interest_manager from src.common.logger import get_logger from src.config.config import global_config @@ -47,11 +47,11 @@ class ChatterInterestScoringSystem: ) # 每次不回复增加的概率 # 用户关系数据 - self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score + self.user_relationships: dict[str, float] = {} # user_id -> relationship_score async def calculate_interest_scores( - self, messages: List[DatabaseMessages], bot_nickname: str - ) -> List[InterestScore]: + self, messages: list[DatabaseMessages], bot_nickname: str + ) -> list[InterestScore]: """计算消息的兴趣度评分""" user_messages = [msg for msg in messages if str(msg.user_info.user_id) != str(global_config.bot.qq_account)] if not user_messages: @@ -97,7 +97,7 @@ class ChatterInterestScoringSystem: details=details, ) - async def _calculate_interest_match_score(self, content: str, keywords: List[str] = None) -> float: + async def _calculate_interest_match_score(self, content: str, keywords: list[str] = None) -> float: """计算兴趣匹配度 - 使用智能embedding匹配""" if not content: return 0.0 @@ -109,7 +109,7 @@ class ChatterInterestScoringSystem: # 智能匹配未初始化,返回默认分数 return 0.3 - async def _calculate_smart_interest_match(self, content: str, keywords: List[str] = None) -> float: + async def _calculate_smart_interest_match(self, content: str, keywords: list[str] = None) -> float: """使用embedding计算智能兴趣匹配""" try: # 如果没有传入关键词,则提取 @@ -134,7 +134,7 @@ class ChatterInterestScoringSystem: logger.error(f"智能兴趣匹配计算失败: {e}") return 0.0 - def _extract_keywords_from_database(self, message: DatabaseMessages) -> List[str]: + def _extract_keywords_from_database(self, message: DatabaseMessages) -> list[str]: """从数据库消息中提取关键词""" keywords = [] @@ -166,7 +166,7 @@ class ChatterInterestScoringSystem: 
return keywords[:15] # 返回前15个关键词 - def _extract_keywords_from_content(self, content: str) -> List[str]: + def _extract_keywords_from_content(self, content: str) -> list[str]: """从内容中提取关键词(降级方案)""" import re @@ -287,7 +287,7 @@ class ChatterInterestScoringSystem: """获取用户关系分""" return self.user_relationships.get(user_id, 0.3) - def get_scoring_stats(self) -> Dict: + def get_scoring_stats(self) -> dict: """获取评分系统统计""" return { "no_reply_count": self.no_reply_count, @@ -318,7 +318,7 @@ class ChatterInterestScoringSystem: logger.error(f"初始化智能兴趣系统失败: {e}") traceback.print_exc() - def get_matching_config(self) -> Dict[str, Any]: + def get_matching_config(self) -> dict[str, Any]: """获取匹配配置信息""" return { "use_smart_matching": self.use_smart_matching, diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py index 8d322c880..b68591100 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py @@ -5,12 +5,11 @@ PlanExecutor: 接收 Plan 对象并执行其中的所有动作。 import asyncio import time -from typing import Dict, List -from src.config.config import global_config from src.chat.planner_actions.action_manager import ChatterActionManager -from src.common.data_models.info_data_model import Plan, ActionPlannerInfo +from src.common.data_models.info_data_model import ActionPlannerInfo, Plan from src.common.logger import get_logger +from src.config.config import global_config logger = get_logger("plan_executor") @@ -52,7 +51,7 @@ class ChatterPlanExecutor: """设置关系追踪器""" self.relationship_tracker = relationship_tracker - async def execute(self, plan: Plan) -> Dict[str, any]: + async def execute(self, plan: Plan) -> dict[str, any]: """ 遍历并执行Plan对象中`decided_actions`列表里的所有动作。 @@ -110,7 +109,7 @@ class ChatterPlanExecutor: "results": execution_results, } - async def _execute_reply_actions(self, reply_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]: + async def _execute_reply_actions(self, reply_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, any]: """串行执行所有回复动作,增加去重逻辑,避免对同一消息多次回复""" results = [] @@ -162,7 +161,7 @@ class ChatterPlanExecutor: async def _execute_single_reply_action( self, action_info: ActionPlannerInfo, plan: Plan, clear_unread: bool = True - ) -> Dict[str, any]: + ) -> dict[str, any]: """执行单个回复动作""" start_time = time.time() success = False @@ -240,7 +239,7 @@ class ChatterPlanExecutor: else reply_content, } - async def _execute_other_actions(self, other_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]: + async def _execute_other_actions(self, other_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, any]: """执行其他动作""" results = [] @@ -269,7 +268,7 @@ class ChatterPlanExecutor: return {"results": results} - async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]: + async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> dict[str, any]: """执行单个其他动作""" start_time = time.time() success = False @@ -378,7 +377,7 @@ class ChatterPlanExecutor: logger.debug(f"action_message类型: {type(action_info.action_message)}") logger.debug(f"action_message内容: {action_info.action_message}") - def get_execution_stats(self) -> Dict[str, any]: + def get_execution_stats(self) -> dict[str, any]: """获取执行统计信息""" stats = self.execution_stats.copy() @@ -409,7 +408,7 @@ class ChatterPlanExecutor: "execution_times": [], } - def get_recent_performance(self, 
limit: int = 10) -> List[Dict[str, any]]: + def get_recent_performance(self, limit: int = 10) -> list[dict[str, any]]: """获取最近的执行性能""" recent_times = self.execution_stats["execution_times"][-limit:] if not recent_times: diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py index 1bc153fad..92b299219 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py @@ -2,13 +2,13 @@ PlanFilter: 接收 Plan 对象,根据不同模式的逻辑进行筛选,决定最终要执行的动作。 """ -import orjson +import re import time import traceback -import re from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any +import orjson from json_repair import repair_json # 旧的Hippocampus系统已被移除,现在使用增强记忆系统 @@ -39,7 +39,7 @@ class ChatterPlanFilter: 根据 Plan 中的模式和信息,筛选并决定最终的动作。 """ - def __init__(self, chat_id: str, available_actions: List[str]): + def __init__(self, chat_id: str, available_actions: list[str]): """ 初始化动作计划筛选器。 @@ -316,8 +316,8 @@ class ChatterPlanFilter: """构建已读/未读历史消息块""" try: # 从message_manager获取真实的已读/未读消息 - from src.chat.utils.utils import assign_message_ids from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat + from src.chat.utils.utils import assign_message_ids # 获取聊天流的上下文 from src.plugin_system.apis.chat_api import get_chat_manager @@ -392,14 +392,15 @@ class ChatterPlanFilter: logger.error(f"构建已读/未读历史消息块时出错: {e}") return "构建已读历史消息时出错", "构建未读历史消息时出错", [] - async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]: + async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]: """为消息获取兴趣度评分""" interest_scores = {} try: - from .interest_scoring import chatter_interest_scoring_system from src.common.data_models.database_data_model import DatabaseMessages + from .interest_scoring import chatter_interest_scoring_system + # 使用插件内部的兴趣度评分系统计算评分 for msg_dict in messages: try: @@ -450,7 +451,7 @@ class ChatterPlanFilter: async def _parse_single_action( self, action_json: dict, message_id_list: list, plan: Plan - ) -> List[ActionPlannerInfo]: + ) -> list[ActionPlannerInfo]: parsed_actions = [] try: # 从新的actions结构中获取动作信息 @@ -599,7 +600,7 @@ class ChatterPlanFilter: ) return parsed_actions - def _filter_no_actions(self, action_list: List[ActionPlannerInfo]) -> List[ActionPlannerInfo]: + def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]: non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]] if non_no_actions: return non_no_actions @@ -652,7 +653,7 @@ class ChatterPlanFilter: logger.error(f"获取长期记忆时出错: {e}") return "回忆时出现了一些问题。" - async def _build_action_options(self, current_available_actions: Dict[str, ActionInfo]) -> str: + async def _build_action_options(self, current_available_actions: dict[str, ActionInfo]) -> str: action_options_block = "" for action_name, action_info in current_available_actions.items(): # 构建参数的JSON示例 @@ -723,7 +724,7 @@ class ChatterPlanFilter: ) return action_options_block - def _find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]: + def _find_message_by_id(self, message_id: str, message_id_list: list) -> dict[str, Any] | None: """ 增强的消息查找函数,支持多种格式和模糊匹配 兼容大模型可能返回的各种格式变体 @@ -828,12 +829,12 @@ class ChatterPlanFilter: logger.warning(f"未找到任何匹配的消息: {original_id} (候选: {candidate_ids})") return None - def 
_get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]: + def _get_latest_message(self, message_id_list: list) -> dict[str, Any] | None: if not message_id_list: return None return message_id_list[-1].get("message") - def _find_poke_notice(self, message_id_list: list) -> Optional[Dict[str, Any]]: + def _find_poke_notice(self, message_id_list: list) -> dict[str, Any] | None: """在消息列表中寻找戳一戳的通知消息""" for item in reversed(message_id_list): message = item.get("message") diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_generator.py b/src/plugins/built_in/affinity_flow_chatter/plan_generator.py index 86539ac01..d946934d5 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_generator.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_generator.py @@ -3,7 +3,6 @@ PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个 """ import time -from typing import Dict from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat from src.chat.utils.utils import get_chat_type_and_target_info @@ -85,7 +84,7 @@ class ChatterPlanGenerator: chat_history=[], ) - async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> Dict[str, ActionInfo]: + async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> dict[str, ActionInfo]: """ 获取当前可用的动作列表。 @@ -152,7 +151,7 @@ class ChatterPlanGenerator: # 如果获取失败,返回空列表 return [] - def get_generator_stats(self) -> Dict: + def get_generator_stats(self) -> dict: """ 获取生成器统计信息。 diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py index b6be512b9..e2321aab1 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner.py @@ -4,22 +4,20 @@ """ from dataclasses import asdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple - -from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor -from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter -from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator -from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system -from src.mood.mood_manager import mood_manager - +from typing import TYPE_CHECKING, Any from src.common.logger import get_logger from src.config.config import global_config +from src.mood.mood_manager import mood_manager +from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system +from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor +from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter +from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator if TYPE_CHECKING: - from src.common.data_models.message_manager_data_model import StreamContext - from src.common.data_models.info_data_model import Plan from src.chat.planner_actions.action_manager import ChatterActionManager + from src.common.data_models.info_data_model import Plan + from src.common.data_models.message_manager_data_model import StreamContext # 导入提示词模块以确保其被初始化 from src.plugins.built_in.affinity_flow_chatter import planner_prompts # noqa @@ -62,7 +60,7 @@ class ChatterActionPlanner: "other_actions_executed": 0, } - async def plan(self, context: "StreamContext" = None) -> Tuple[List[Dict], Optional[Dict]]: + async def plan(self, context: "StreamContext" = None) 
-> tuple[list[dict], dict | None]: """ 执行完整的增强版规划流程。 @@ -84,7 +82,7 @@ class ChatterActionPlanner: self.planner_stats["failed_plans"] += 1 return [], None - async def _enhanced_plan_flow(self, context: "StreamContext") -> Tuple[List[Dict], Optional[Dict]]: + async def _enhanced_plan_flow(self, context: "StreamContext") -> tuple[list[dict], dict | None]: """执行增强版规划流程""" try: # 在规划前,先进行动作修改 @@ -104,7 +102,7 @@ class ChatterActionPlanner: score = 0.0 should_reply = False reply_not_available = False - interest_updates: List[Dict[str, Any]] = [] + interest_updates: list[dict[str, Any]] = [] if unread_messages: # 为每条消息计算兴趣度,并延迟提交数据库更新 @@ -193,7 +191,7 @@ class ChatterActionPlanner: self.planner_stats["failed_plans"] += 1 return [], None - async def _commit_interest_updates(self, updates: List[Dict[str, Any]]) -> None: + async def _commit_interest_updates(self, updates: list[dict[str, Any]]) -> None: """统一更新消息兴趣度,减少数据库写入次数""" if not updates: return @@ -220,7 +218,7 @@ class ChatterActionPlanner: except Exception as e: logger.warning(f"批量更新数据库兴趣度失败: {e}") - def _update_stats_from_execution_result(self, execution_result: Dict[str, any]): + def _update_stats_from_execution_result(self, execution_result: dict[str, any]): """根据执行结果更新规划器统计""" if not execution_result: return @@ -244,7 +242,7 @@ class ChatterActionPlanner: self.planner_stats["replies_generated"] += reply_count self.planner_stats["other_actions_executed"] += other_count - def _build_return_result(self, plan: "Plan") -> Tuple[List[Dict], Optional[Dict]]: + def _build_return_result(self, plan: "Plan") -> tuple[list[dict], dict | None]: """构建返回结果""" final_actions = plan.decided_actions or [] final_target_message = next((act.action_message for act in final_actions if act.action_message), None) @@ -261,7 +259,7 @@ class ChatterActionPlanner: return final_actions_dict, final_target_message_dict - def get_planner_stats(self) -> Dict[str, any]: + def get_planner_stats(self) -> dict[str, any]: """获取规划器统计""" return self.planner_stats.copy() @@ -270,7 +268,7 @@ class ChatterActionPlanner: chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id) return chat_mood.mood_state - def get_mood_stats(self) -> Dict[str, any]: + def get_mood_stats(self) -> dict[str, any]: """获取情绪状态统计""" chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id) return { diff --git a/src/plugins/built_in/affinity_flow_chatter/plugin.py b/src/plugins/built_in/affinity_flow_chatter/plugin.py index 7c86d13fe..32d869e67 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plugin.py +++ b/src/plugins/built_in/affinity_flow_chatter/plugin.py @@ -2,12 +2,10 @@ 亲和力聊天处理器插件 """ -from typing import List, Tuple, Type - +from src.common.logger import get_logger from src.plugin_system.apis.plugin_register_api import register_plugin from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system.base.component_types import ComponentInfo -from src.common.logger import get_logger logger = get_logger("affinity_chatter_plugin") @@ -29,7 +27,7 @@ class AffinityChatterPlugin(BasePlugin): # 简单的 config_schema 占位(如果将来需要配置可扩展) config_schema = {} - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: """返回插件包含的组件列表(ChatterInfo, AffinityChatter) 这里采用延迟导入 AffinityChatter 来避免循环依赖和启动顺序问题。 diff --git a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py index 2320670a0..e3dcb9791 100644 --- 
a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py +++ b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py @@ -5,15 +5,15 @@ """ import time -from typing import Dict, List, Optional -from src.common.logger import get_logger -from src.config.config import model_config, global_config -from src.llm_models.utils_model import LLMRequest -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import UserRelationships, Messages -from sqlalchemy import select, desc +from sqlalchemy import desc, select + from src.common.data_models.database_data_model import DatabaseMessages +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import Messages, UserRelationships +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("chatter_relationship_tracker") @@ -22,15 +22,15 @@ class ChatterRelationshipTracker: """用户关系追踪器""" def __init__(self, interest_scoring_system=None): - self.tracking_users: Dict[str, Dict] = {} # user_id -> interaction_data + self.tracking_users: dict[str, dict] = {} # user_id -> interaction_data self.max_tracking_users = 3 self.update_interval_minutes = 30 self.last_update_time = time.time() - self.relationship_history: List[Dict] = [] + self.relationship_history: list[dict] = [] self.interest_scoring_system = interest_scoring_system # 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float}) - self.user_relationship_cache: Dict[str, Dict] = {} + self.user_relationship_cache: dict[str, dict] = {} self.cache_expiry_hours = 1 # 缓存过期时间(小时) # 关系更新LLM @@ -91,7 +91,7 @@ class ChatterRelationshipTracker: logger.debug(f"添加用户交互追踪: {user_id}") - async def check_and_update_relationships(self) -> List[Dict]: + async def check_and_update_relationships(self) -> list[dict]: """检查并更新用户关系""" current_time = time.time() if current_time - self.last_update_time < self.update_interval_minutes * 60: @@ -108,7 +108,7 @@ class ChatterRelationshipTracker: self.last_update_time = current_time return updates - async def _update_user_relationship(self, interaction: Dict) -> Optional[Dict]: + async def _update_user_relationship(self, interaction: dict) -> dict | None: """更新单个用户的关系""" try: # 获取bot人设信息 @@ -201,11 +201,11 @@ class ChatterRelationshipTracker: return None - def get_tracking_users(self) -> Dict[str, Dict]: + def get_tracking_users(self) -> dict[str, dict]: """获取正在追踪的用户""" return self.tracking_users.copy() - def get_user_interaction(self, user_id: str) -> Optional[Dict]: + def get_user_interaction(self, user_id: str) -> dict | None: """获取特定用户的交互记录""" return self.tracking_users.get(user_id) @@ -220,11 +220,11 @@ class ChatterRelationshipTracker: self.tracking_users.clear() logger.info("清空所有用户追踪") - def get_relationship_history(self) -> List[Dict]: + def get_relationship_history(self) -> list[dict]: """获取关系历史记录""" return self.relationship_history.copy() - def add_to_history(self, relationship_update: Dict): + def add_to_history(self, relationship_update: dict): """添加到关系历史""" self.relationship_history.append({**relationship_update, "update_time": time.time()}) @@ -232,7 +232,7 @@ class ChatterRelationshipTracker: if len(self.relationship_history) > 100: self.relationship_history = self.relationship_history[-100:] - def get_tracker_stats(self) -> Dict: + def get_tracker_stats(self) -> dict: 
"""获取追踪器统计""" return { "tracking_users": len(self.tracking_users), @@ -268,7 +268,7 @@ class ChatterRelationshipTracker: self.add_to_history(update_info) logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}") - def get_user_summary(self, user_id: str) -> Dict: + def get_user_summary(self, user_id: str) -> dict: """获取用户交互总结""" if user_id not in self.tracking_users: return {} @@ -313,7 +313,7 @@ class ChatterRelationshipTracker: # 数据库中也没有,返回默认值 return global_config.affinity_flow.base_relationship_score - async def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]: + async def _get_user_relationship_from_db(self, user_id: str) -> dict | None: """从数据库获取用户关系数据""" try: async with get_db_session() as session: @@ -431,7 +431,7 @@ class ChatterRelationshipTracker: return 0 - async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]: + async def _get_last_bot_reply_to_user(self, user_id: str) -> DatabaseMessages | None: """获取上次bot回复该用户的消息""" try: async with get_db_session() as session: @@ -455,7 +455,7 @@ class ChatterRelationshipTracker: return None - async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]: + async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> list[DatabaseMessages]: """获取用户在bot回复后的反应消息""" try: async with get_db_session() as session: @@ -511,7 +511,7 @@ class ChatterRelationshipTracker: user_id: str, user_name: str, last_bot_reply: DatabaseMessages, - user_reactions: List[DatabaseMessages], + user_reactions: list[DatabaseMessages], current_text: str, current_score: float, current_reply: str, diff --git a/src/plugins/built_in/core_actions/anti_injector_manager.py b/src/plugins/built_in/core_actions/anti_injector_manager.py index 3291ba8cf..3b207ab63 100644 --- a/src/plugins/built_in/core_actions/anti_injector_manager.py +++ b/src/plugins/built_in/core_actions/anti_injector_manager.py @@ -8,9 +8,9 @@ - 测试功能 """ -from src.plugin_system.base import BaseCommand from src.chat.antipromptinjector import get_anti_injector from src.common.logger import get_logger +from src.plugin_system.base import BaseCommand logger = get_logger("anti_injector.commands") @@ -56,5 +56,5 @@ class AntiInjectorStatusCommand(BaseCommand): except Exception as e: logger.error(f"获取反注入系统状态失败: {e}") - await self.send_text(f"获取状态失败: {str(e)}") - return False, f"获取状态失败: {str(e)}", True + await self.send_text(f"获取状态失败: {e!s}") + return False, f"获取状态失败: {e!s}", True diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index a477fdf0a..0dab1f88c 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -1,19 +1,18 @@ import random -from typing import Tuple -# 导入新插件系统 -from src.plugin_system import BaseAction, ActionActivationType, ChatMode +from src.chat.emoji_system.emoji_history import add_emoji_to_history, get_recent_emojis +from src.chat.emoji_system.emoji_manager import MaiEmoji, get_emoji_manager +from src.chat.utils.utils_image import image_path_to_base64 # 导入依赖的系统组件 from src.common.logger import get_logger +from src.config.config import global_config + +# 导入新插件系统 +from src.plugin_system import ActionActivationType, BaseAction, ChatMode # 导入API模块 - 标准Python包方式 from src.plugin_system.apis import llm_api, message_api -from src.chat.emoji_system.emoji_manager import get_emoji_manager, MaiEmoji -from src.chat.utils.utils_image import image_path_to_base64 -from src.config.config import 
global_config -from src.chat.emoji_system.emoji_history import get_recent_emojis, add_emoji_to_history - logger = get_logger("emoji") @@ -59,7 +58,7 @@ class EmojiAction(BaseAction): # 关联类型 associated_types = ["emoji"] - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """执行表情动作""" logger.info(f"{self.log_prefix} 决定发送表情") @@ -286,4 +285,4 @@ class EmojiAction(BaseAction): except Exception as e: logger.error(f"{self.log_prefix} 表情动作执行失败: {e}", exc_info=True) - return False, f"表情发送失败: {str(e)}" + return False, f"表情发送失败: {e!s}" diff --git a/src/plugins/built_in/core_actions/plugin.py b/src/plugins/built_in/core_actions/plugin.py index 473005a22..91a7e8d5e 100644 --- a/src/plugins/built_in/core_actions/plugin.py +++ b/src/plugins/built_in/core_actions/plugin.py @@ -5,19 +5,16 @@ 这是系统的内置插件,提供基础的聊天交互功能 """ -from typing import List, Tuple, Type - -# 导入新插件系统 -from src.plugin_system import BasePlugin, register_plugin, ComponentInfo -from src.plugin_system.base.config_types import ConfigField - - # 导入依赖的系统组件 from src.common.logger import get_logger +# 导入新插件系统 +from src.plugin_system import BasePlugin, ComponentInfo, register_plugin +from src.plugin_system.base.config_types import ConfigField +from src.plugins.built_in.core_actions.anti_injector_manager import AntiInjectorStatusCommand + # 导入API模块 - 标准Python包方式 from src.plugins.built_in.core_actions.emoji import EmojiAction -from src.plugins.built_in.core_actions.anti_injector_manager import AntiInjectorStatusCommand logger = get_logger("core_actions") @@ -62,7 +59,7 @@ class CoreActionsPlugin(BasePlugin): }, } - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: """返回插件包含的组件列表""" # --- 根据配置注册组件 --- diff --git a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py index 194a2c5ef..38da7e013 100644 --- a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py +++ b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py @@ -1,8 +1,8 @@ -from typing import Dict, Any +from typing import Any +from src.chat.knowledge.knowledge_lib import qa_manager from src.common.logger import get_logger from src.config.config import global_config -from src.chat.knowledge.knowledge_lib import qa_manager from src.plugin_system import BaseTool, ToolParamType logger = get_logger("lpmm_get_knowledge_tool") @@ -19,7 +19,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool): ] available_for_llm = global_config.lpmm_knowledge.enable - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行知识库搜索 Args: @@ -56,7 +56,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool): return {"type": "lpmm_knowledge", "id": query, "content": content} except Exception as e: # 捕获异常并记录错误 - logger.error(f"知识库搜索工具执行失败: {str(e)}") + logger.error(f"知识库搜索工具执行失败: {e!s}") # 在其他异常情况下,确保 id 仍然是 query (如果它被定义了) query_id = query if "query" in locals() else "unknown_query" - return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {str(e)}"} + return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {e!s}"} diff --git a/src/plugins/built_in/maizone_refactored/__init__.py b/src/plugins/built_in/maizone_refactored/__init__.py index 56a019c4b..dd094256f 100644 --- a/src/plugins/built_in/maizone_refactored/__init__.py +++ b/src/plugins/built_in/maizone_refactored/__init__.py @@ -1,9 +1,8 @@ -# -*- coding: utf-8 -*- """ 
让框架能够发现并加载子目录中的组件。 """ -from .plugin import MaiZoneRefactoredPlugin as MaiZoneRefactoredPlugin -from .actions.send_feed_action import SendFeedAction as SendFeedAction from .actions.read_feed_action import ReadFeedAction as ReadFeedAction +from .actions.send_feed_action import SendFeedAction as SendFeedAction from .commands.send_feed_command import SendFeedCommand as SendFeedCommand +from .plugin import MaiZoneRefactoredPlugin as MaiZoneRefactoredPlugin diff --git a/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py b/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py index ee5a1b73a..6abef2141 100644 --- a/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py +++ b/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py @@ -1,14 +1,12 @@ -# -*- coding: utf-8 -*- """ 阅读说说动作组件 """ -from typing import Tuple - from src.common.logger import get_logger -from src.plugin_system import BaseAction, ActionActivationType, ChatMode +from src.plugin_system import ActionActivationType, BaseAction, ChatMode from src.plugin_system.apis import generator_api from src.plugin_system.apis.permission_api import permission_api + from ..services.manager import get_qzone_service logger = get_logger("MaiZone.ReadFeedAction") @@ -41,7 +39,7 @@ class ReadFeedAction(BaseAction): # 使用权限API检查用户是否有阅读说说的权限 return await permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed") - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """ 执行动作的核心逻辑。 """ diff --git a/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py b/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py index af8760c06..b242aae70 100644 --- a/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py +++ b/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py @@ -1,14 +1,12 @@ -# -*- coding: utf-8 -*- """ 发送说说动作组件 """ -from typing import Tuple - from src.common.logger import get_logger -from src.plugin_system import BaseAction, ActionActivationType, ChatMode +from src.plugin_system import ActionActivationType, BaseAction, ChatMode from src.plugin_system.apis import generator_api from src.plugin_system.apis.permission_api import permission_api + from ..services.manager import get_qzone_service logger = get_logger("MaiZone.SendFeedAction") @@ -41,7 +39,7 @@ class SendFeedAction(BaseAction): # 使用权限API检查用户是否有发送说说的权限 return await permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed") - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """ 执行动作的核心逻辑。 """ diff --git a/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py b/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py index 631ca430d..062252a99 100644 --- a/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py +++ b/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py @@ -1,15 +1,13 @@ -# -*- coding: utf-8 -*- """ 发送说说命令 await self.send_text(f"收到!正在为你生成关于"{topic or '随机'}"的说说,请稍候...【热重载测试成功】")件 """ -from typing import Tuple - from src.common.logger import get_logger -from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.base.command_args import CommandArgs +from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.utils.permission_decorators import require_permission -from ..services.manager import get_qzone_service, get_config_getter + +from 
..services.manager import get_config_getter, get_qzone_service logger = get_logger("MaiZone.SendFeedCommand") @@ -28,7 +26,7 @@ class SendFeedCommand(PlusCommand): super().__init__(*args, **kwargs) @require_permission("plugin.maizone.send_feed") - async def execute(self, args: CommandArgs) -> Tuple[bool, str, bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str, bool]: """ 执行命令的核心逻辑。 """ diff --git a/src/plugins/built_in/maizone_refactored/plugin.py b/src/plugins/built_in/maizone_refactored/plugin.py index e8259b5cb..4ef92ff9e 100644 --- a/src/plugins/built_in/maizone_refactored/plugin.py +++ b/src/plugins/built_in/maizone_refactored/plugin.py @@ -1,28 +1,26 @@ -# -*- coding: utf-8 -*- """ MaiZone(麦麦空间)- 重构版 """ import asyncio from pathlib import Path -from typing import List, Tuple, Type from src.common.logger import get_logger from src.plugin_system import BasePlugin, ComponentInfo, register_plugin -from src.plugin_system.base.config_types import ConfigField from src.plugin_system.apis.permission_api import permission_api +from src.plugin_system.base.config_types import ConfigField from .actions.read_feed_action import ReadFeedAction from .actions.send_feed_action import SendFeedAction from .commands.send_feed_command import SendFeedCommand from .services.content_service import ContentService -from .services.image_service import ImageService -from .services.qzone_service import QZoneService -from .services.scheduler_service import SchedulerService -from .services.monitor_service import MonitorService from .services.cookie_service import CookieService -from .services.reply_tracker_service import ReplyTrackerService +from .services.image_service import ImageService from .services.manager import register_service +from .services.monitor_service import MonitorService +from .services.qzone_service import QZoneService +from .services.reply_tracker_service import ReplyTrackerService +from .services.scheduler_service import SchedulerService logger = get_logger("MaiZone.Plugin") @@ -35,8 +33,8 @@ class MaiZoneRefactoredPlugin(BasePlugin): plugin_description: str = "重构版的MaiZone插件" config_file_name: str = "config.toml" enable_plugin: bool = True - dependencies: List[str] = [] - python_dependencies: List[str] = [] + dependencies: list[str] = [] + python_dependencies: list[str] = [] config_schema: dict = { "plugin": {"enable": ConfigField(type=bool, default=True, description="是否启用插件")}, @@ -125,7 +123,7 @@ class MaiZoneRefactoredPlugin(BasePlugin): asyncio.create_task(monitor_service.start()) logger.info("MaiZone后台监控和定时任务已启动。") - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: return [ (SendFeedAction.get_action_info(), SendFeedAction), (ReadFeedAction.get_action_info(), ReadFeedAction), diff --git a/src/plugins/built_in/maizone_refactored/services/content_service.py b/src/plugins/built_in/maizone_refactored/services/content_service.py index 27f2a0ee9..553eb2a95 100644 --- a/src/plugins/built_in/maizone_refactored/services/content_service.py +++ b/src/plugins/built_in/maizone_refactored/services/content_service.py @@ -1,23 +1,23 @@ -# -*- coding: utf-8 -*- """ 内容服务模块 负责生成所有与QQ空间相关的文本内容,例如说说、评论等。 """ -from typing import Callable, Optional -import datetime - -import base64 -import aiohttp -from src.common.logger import get_logger -import imghdr import asyncio -from src.plugin_system.apis import llm_api, config_api, generator_api -from src.plugin_system.apis.cross_context_api import 
get_chat_history_by_group_name -from src.chat.message_receive.chat_stream import get_chat_manager +import base64 +import datetime +import imghdr +from collections.abc import Callable + +import aiohttp from maim_message import UserInfo -from src.llm_models.utils_model import LLMRequest + +from src.chat.message_receive.chat_stream import get_chat_manager +from src.common.logger import get_logger from src.config.api_ada_configs import TaskConfig +from src.llm_models.utils_model import LLMRequest +from src.plugin_system.apis import config_api, generator_api, llm_api +from src.plugin_system.apis.cross_context_api import get_chat_history_by_group_name # 导入旧的工具函数,我们稍后会考虑是否也需要重构它 from ..utils.history_utils import get_send_history @@ -38,7 +38,7 @@ class ContentService: """ self.get_config = get_config - async def generate_story(self, topic: str, context: Optional[str] = None) -> str: + async def generate_story(self, topic: str, context: str | None = None) -> str: """ 根据指定主题和可选的上下文生成一条QQ空间说说。 @@ -231,7 +231,7 @@ class ContentService: return "" return "" - async def _describe_image(self, image_url: str) -> Optional[str]: + async def _describe_image(self, image_url: str) -> str | None: """ 使用LLM识别图片内容。 """ diff --git a/src/plugins/built_in/maizone_refactored/services/cookie_service.py b/src/plugins/built_in/maizone_refactored/services/cookie_service.py index 9da05582c..c0a0b7ef9 100644 --- a/src/plugins/built_in/maizone_refactored/services/cookie_service.py +++ b/src/plugins/built_in/maizone_refactored/services/cookie_service.py @@ -1,14 +1,14 @@ -# -*- coding: utf-8 -*- """ Cookie服务模块 负责从多种来源获取、缓存和管理QZone的Cookie。 """ -import orjson +from collections.abc import Callable from pathlib import Path -from typing import Callable, Optional, Dict import aiohttp +import orjson + from src.common.logger import get_logger from src.plugin_system.apis import send_api @@ -29,28 +29,28 @@ class CookieService: """获取指定QQ账号的cookie文件路径""" return self.cookie_dir / f"cookies-{qq_account}.json" - def _save_cookies_to_file(self, qq_account: str, cookies: Dict[str, str]): + def _save_cookies_to_file(self, qq_account: str, cookies: dict[str, str]): """将Cookie保存到本地文件""" cookie_file_path = self._get_cookie_file_path(qq_account) try: with open(cookie_file_path, "w", encoding="utf-8") as f: f.write(orjson.dumps(cookies, option=orjson.OPT_INDENT_2).decode("utf-8")) logger.info(f"Cookie已成功缓存至: {cookie_file_path}") - except IOError as e: + except OSError as e: logger.error(f"无法写入Cookie文件 {cookie_file_path}: {e}") - def _load_cookies_from_file(self, qq_account: str) -> Optional[Dict[str, str]]: + def _load_cookies_from_file(self, qq_account: str) -> dict[str, str] | None: """从本地文件加载Cookie""" cookie_file_path = self._get_cookie_file_path(qq_account) if cookie_file_path.exists(): try: - with open(cookie_file_path, "r", encoding="utf-8") as f: + with open(cookie_file_path, encoding="utf-8") as f: return orjson.loads(f.read()) - except (IOError, orjson.JSONDecodeError) as e: + except (OSError, orjson.JSONDecodeError) as e: logger.error(f"无法读取或解析Cookie文件 {cookie_file_path}: {e}") return None - async def _get_cookies_from_adapter(self, stream_id: Optional[str]) -> Optional[Dict[str, str]]: + async def _get_cookies_from_adapter(self, stream_id: str | None) -> dict[str, str] | None: """通过Adapter API获取Cookie""" try: params = {"domain": "user.qzone.qq.com"} @@ -73,7 +73,7 @@ class CookieService: logger.error(f"通过Adapter获取Cookie时发生异常: {e}") return None - async def _get_cookies_from_http(self) -> Optional[Dict[str, str]]: + async def 
_get_cookies_from_http(self) -> dict[str, str] | None: """通过备用HTTP端点获取Cookie""" host = self.get_config("cookie.http_fallback_host", "172.20.130.55") port = self.get_config("cookie.http_fallback_port", "9999") @@ -110,7 +110,7 @@ class CookieService: logger.error(f"通过HTTP备用地址 {http_url} 获取Cookie失败: {e}") return None - async def get_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]: + async def get_cookies(self, qq_account: str, stream_id: str | None) -> dict[str, str] | None: """ 获取Cookie,按以下顺序尝试: 1. HTTP备用端点 (更稳定) diff --git a/src/plugins/built_in/maizone_refactored/services/image_service.py b/src/plugins/built_in/maizone_refactored/services/image_service.py index cbb411da7..58241ba7b 100644 --- a/src/plugins/built_in/maizone_refactored/services/image_service.py +++ b/src/plugins/built_in/maizone_refactored/services/image_service.py @@ -1,12 +1,11 @@ -# -*- coding: utf-8 -*- """ 图片服务模块 负责处理所有与图片相关的任务,特别是AI生成图片。 """ import base64 +from collections.abc import Callable from pathlib import Path -from typing import Callable import aiohttp diff --git a/src/plugins/built_in/maizone_refactored/services/manager.py b/src/plugins/built_in/maizone_refactored/services/manager.py index 74cbb844a..ec1588bd3 100644 --- a/src/plugins/built_in/maizone_refactored/services/manager.py +++ b/src/plugins/built_in/maizone_refactored/services/manager.py @@ -1,14 +1,15 @@ -# -*- coding: utf-8 -*- """ 服务管理器/定位器 这是一个独立的模块,用于注册和获取插件内的全局服务实例,以避免循环导入。 """ -from typing import Dict, Any, Callable +from collections.abc import Callable +from typing import Any + from .qzone_service import QZoneService # --- 全局服务注册表 --- -_services: Dict[str, Any] = {} +_services: dict[str, Any] = {} def register_service(name: str, instance: Any): diff --git a/src/plugins/built_in/maizone_refactored/services/monitor_service.py b/src/plugins/built_in/maizone_refactored/services/monitor_service.py index 114358ea3..b479f4183 100644 --- a/src/plugins/built_in/maizone_refactored/services/monitor_service.py +++ b/src/plugins/built_in/maizone_refactored/services/monitor_service.py @@ -1,13 +1,13 @@ -# -*- coding: utf-8 -*- """ 好友动态监控服务 """ import asyncio import traceback -from typing import Callable +from collections.abc import Callable from src.common.logger import get_logger + from .qzone_service import QZoneService logger = get_logger("MaiZone.MonitorService") diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index c0e00b80d..6220595cc 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -1,32 +1,33 @@ -# -*- coding: utf-8 -*- """ QQ空间服务模块 封装了所有与QQ空间API的直接交互,是插件的核心业务逻辑层。 """ import asyncio -import orjson +import base64 import os import random import time -import base64 +from collections.abc import Callable from pathlib import Path -from typing import Callable, Optional, Dict, Any, List, Tuple +from typing import Any import aiohttp import bs4 import json5 -from src.common.logger import get_logger -from src.plugin_system.apis import config_api, person_api +import orjson + from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import ( build_readable_messages_with_id, get_raw_msg_by_timestamp_with_chat, ) +from src.common.logger import get_logger +from src.plugin_system.apis import config_api, person_api from .content_service import ContentService -from .image_service 
import ImageService from .cookie_service import CookieService +from .image_service import ImageService from .reply_tracker_service import ReplyTrackerService logger = get_logger("MaiZone.QZoneService") @@ -64,7 +65,7 @@ class QZoneService: # --- Public Methods (High-Level Business Logic) --- - async def send_feed(self, topic: str, stream_id: Optional[str]) -> Dict[str, Any]: + async def send_feed(self, topic: str, stream_id: str | None) -> dict[str, Any]: """发送一条说说""" # --- 获取互通组上下文 --- context = await self._get_intercom_context(stream_id) if stream_id else None @@ -92,7 +93,7 @@ class QZoneService: logger.error(f"发布说说时发生异常: {e}", exc_info=True) return {"success": False, "message": f"发布说说异常: {e}"} - async def send_feed_from_activity(self, activity: str) -> Dict[str, Any]: + async def send_feed_from_activity(self, activity: str) -> dict[str, Any]: """根据日程活动发送一条说说""" story = await self.content_service.generate_story_from_activity(activity) if not story: @@ -118,7 +119,7 @@ class QZoneService: logger.error(f"根据活动发布说说时发生异常: {e}", exc_info=True) return {"success": False, "message": f"发布说说异常: {e}"} - async def read_and_process_feeds(self, target_name: str, stream_id: Optional[str]) -> Dict[str, Any]: + async def read_and_process_feeds(self, target_name: str, stream_id: str | None) -> dict[str, Any]: """读取并处理指定好友的说说""" target_person_id = await person_api.get_person_id_by_name(target_name) if not target_person_id: @@ -147,7 +148,7 @@ class QZoneService: logger.error(f"读取和处理说说时发生异常: {e}", exc_info=True) return {"success": False, "message": f"处理说说异常: {e}"} - async def monitor_feeds(self, stream_id: Optional[str] = None): + async def monitor_feeds(self, stream_id: str | None = None): """监控并处理所有好友的动态,包括回复自己说说的评论""" logger.info("开始执行好友动态监控...") qq_account = config_api.get_global_config("bot.qq_account", "") @@ -189,7 +190,7 @@ class QZoneService: # --- Internal Helper Methods --- - async def _get_intercom_context(self, stream_id: str) -> Optional[str]: + async def _get_intercom_context(self, stream_id: str) -> str | None: """ 根据 stream_id 查找其所属的互通组,并构建该组的聊天上下文。 @@ -247,7 +248,7 @@ class QZoneService: logger.debug(f"Stream ID '{stream_id}' 未在任何互通组中找到。") return None - async def _reply_to_own_feed_comments(self, feed: Dict, api_client: Dict): + async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict): """处理对自己说说的评论并进行回复""" qq_account = config_api.get_global_config("bot.qq_account", "") comments = feed.get("comments", []) @@ -309,7 +310,7 @@ class QZoneService: if comment_key in self.processing_comments: self.processing_comments.remove(comment_key) - async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: List[Dict]): + async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: list[dict]): """验证并清理已删除的回复记录""" # 获取当前记录中该说说的所有已回复评论ID recorded_replied_comments = self.reply_tracker.get_replied_comments(fid) @@ -333,7 +334,7 @@ class QZoneService: self.reply_tracker.remove_reply_record(fid, comment_tid) logger.debug(f"已清理删除的回复记录: feed_id={fid}, comment_id={comment_tid}") - async def _process_single_feed(self, feed: Dict, api_client: Dict, target_qq: str, target_name: str): + async def _process_single_feed(self, feed: dict, api_client: dict, target_qq: str, target_name: str): """处理单条说说,决定是否评论和点赞""" content = feed.get("content", "") fid = feed.get("tid", "") @@ -371,7 +372,7 @@ class QZoneService: if random.random() <= self.get_config("read.like_possibility", 1.0): await api_client["like"](target_qq, fid) - def _load_local_images(self, image_dir: str) -> 
List[bytes]: + def _load_local_images(self, image_dir: str) -> list[bytes]: """随机加载本地图片(不删除文件)""" images = [] if not image_dir or not os.path.exists(image_dir): @@ -432,7 +433,7 @@ class QZoneService: hash_val += (hash_val << 5) + ord(char) return str(hash_val & 2147483647) - async def _renew_and_load_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]: + async def _renew_and_load_cookies(self, qq_account: str, stream_id: str | None) -> dict[str, str] | None: cookie_dir = Path(__file__).resolve().parent.parent / "cookies" cookie_dir.mkdir(exist_ok=True) cookie_file_path = cookie_dir / f"cookies-{qq_account}.json" @@ -480,7 +481,7 @@ class QZoneService: logger.error("所有获取Cookie的方式均失败。") return None - async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> Optional[Dict]: + async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> dict | None: """通过HTTP服务器获取Cookie""" # 从配置中读取主机和端口,如果未提供则使用传入的参数 final_host = self.get_config("cookie.http_fallback_host", host) @@ -515,19 +516,19 @@ class QZoneService: except aiohttp.ClientError as e: if attempt < max_retries - 1: - logger.warning(f"无法连接到Napcat服务(尝试 {attempt + 1}/{max_retries}): {url},错误: {str(e)}") + logger.warning(f"无法连接到Napcat服务(尝试 {attempt + 1}/{max_retries}): {url},错误: {e!s}") await asyncio.sleep(retry_delay) retry_delay *= 2 continue - logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {str(e)}") + logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {e!s}") raise RuntimeError(f"无法连接到Napcat服务: {url}") from e except Exception as e: - logger.error(f"获取cookie异常: {str(e)}") + logger.error(f"获取cookie异常: {e!s}") raise raise RuntimeError(f"无法连接到Napcat服务: 超过最大重试次数({max_retries})") - async def _get_api_client(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict]: + async def _get_api_client(self, qq_account: str, stream_id: str | None) -> dict | None: cookies = await self.cookie_service.get_cookies(qq_account, stream_id) if not cookies: logger.error( @@ -559,7 +560,7 @@ class QZoneService: response.raise_for_status() return await response.text() - async def _publish(content: str, images: List[bytes]) -> Tuple[bool, str]: + async def _publish(content: str, images: list[bytes]) -> tuple[bool, str]: """发布说说""" try: post_data = { @@ -660,7 +661,7 @@ class QZoneService: return picbo, richval - async def _upload_image(image_bytes: bytes, index: int) -> Optional[Dict[str, str]]: + async def _upload_image(image_bytes: bytes, index: int) -> dict[str, str] | None: """上传图片到QQ空间(完全按照原版实现)""" try: upload_url = "https://up.qzone.qq.com/cgi-bin/upload/cgi_upload_image" @@ -745,7 +746,7 @@ class QZoneService: logger.error(f"上传图片 {index + 1} 异常: {e}", exc_info=True) return None - async def _list_feeds(t_qq: str, num: int) -> List[Dict]: + async def _list_feeds(t_qq: str, num: int) -> list[dict]: """获取指定用户说说列表 (统一接口)""" try: # 统一使用 format=json 获取完整评论 @@ -920,7 +921,7 @@ class QZoneService: logger.error(f"回复评论异常: {e}", exc_info=True) return False - async def _monitor_list_feeds(num: int) -> List[Dict]: + async def _monitor_list_feeds(num: int) -> list[dict]: """监控好友动态""" try: params = { diff --git a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py index 0fa7edb99..6baa30d21 100644 --- a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py +++ b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 
评论回复跟踪服务 负责记录和管理已回复过的评论ID,避免重复回复 @@ -7,7 +6,8 @@ import json import time from pathlib import Path -from typing import Set, Dict, Any, Union +from typing import Any + from src.common.logger import get_logger logger = get_logger("MaiZone.ReplyTrackerService") @@ -27,7 +27,7 @@ class ReplyTrackerService: # 内存中的已回复评论记录 # 格式: {feed_id: {comment_id: timestamp, ...}, ...} - self.replied_comments: Dict[str, Dict[str, float]] = {} + self.replied_comments: dict[str, dict[str, float]] = {} # 数据清理配置 self.max_record_days = 30 # 保留30天的记录 @@ -64,7 +64,7 @@ class ReplyTrackerService: try: if self.reply_record_file.exists(): try: - with open(self.reply_record_file, "r", encoding="utf-8") as f: + with open(self.reply_record_file, encoding="utf-8") as f: file_content = f.read().strip() if not file_content: # 文件为空 logger.warning("回复记录文件为空,将创建新的记录") @@ -173,7 +173,7 @@ class ReplyTrackerService: if total_removed > 0: logger.info(f"清理了 {total_removed} 条超过{self.max_record_days}天的过期回复记录") - def has_replied(self, feed_id: str, comment_id: Union[str, int]) -> bool: + def has_replied(self, feed_id: str, comment_id: str | int) -> bool: """ 检查是否已经回复过指定的评论 @@ -190,7 +190,7 @@ class ReplyTrackerService: comment_id_str = str(comment_id) return feed_id in self.replied_comments and comment_id_str in self.replied_comments[feed_id] - def mark_as_replied(self, feed_id: str, comment_id: Union[str, int]): + def mark_as_replied(self, feed_id: str, comment_id: str | int): """ 标记指定评论为已回复 @@ -219,7 +219,7 @@ class ReplyTrackerService: else: logger.error(f"标记评论时数据验证失败: feed_id={feed_id}, comment_id={comment_id}") - def get_replied_comments(self, feed_id: str) -> Set[str]: + def get_replied_comments(self, feed_id: str) -> set[str]: """ 获取指定说说下所有已回复的评论ID @@ -234,7 +234,7 @@ class ReplyTrackerService: return {str(comment_id) for comment_id in self.replied_comments[feed_id].keys()} return set() - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: """ 获取回复记录统计信息 diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index 770ced8e6..7cf0e7c93 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 定时任务服务 根据日程表定时发送说说。 @@ -8,13 +7,14 @@ import asyncio import datetime import random import traceback -from typing import Callable +from collections.abc import Callable +from sqlalchemy import select + +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus from src.common.logger import get_logger from src.schedule.schedule_manager import schedule_manager -from src.common.database.sqlalchemy_database_api import get_db_session -from sqlalchemy import select -from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus from .qzone_service import QZoneService diff --git a/src/plugins/built_in/maizone_refactored/utils/history_utils.py b/src/plugins/built_in/maizone_refactored/utils/history_utils.py index 19b3e7baa..6f51a6c0d 100644 --- a/src/plugins/built_in/maizone_refactored/utils/history_utils.py +++ b/src/plugins/built_in/maizone_refactored/utils/history_utils.py @@ -1,15 +1,15 @@ -# -*- coding: utf-8 -*- """ 历史记录工具模块 提供用于获取QQ空间发送历史的功能。 """ -import orjson import os from pathlib import Path -from typing import Dict, Any, Optional, List +from typing import Any +import orjson 
import requests + from src.common.logger import get_logger logger = get_logger("MaiZone.HistoryUtils") @@ -26,11 +26,11 @@ class _CookieManager: return str(cookie_dir / f"cookies-{uin}.json") @staticmethod - def load_cookies(qq_account: str) -> Optional[Dict[str, str]]: + def load_cookies(qq_account: str) -> dict[str, str] | None: cookie_file = _CookieManager.get_cookie_file_path(qq_account) if os.path.exists(cookie_file): try: - with open(cookie_file, "r", encoding="utf-8") as f: + with open(cookie_file, encoding="utf-8") as f: return orjson.loads(f.read()) except Exception as e: logger.error(f"加载Cookie文件失败: {e}") @@ -42,7 +42,7 @@ class _SimpleQZoneAPI: LIST_URL = "https://user.qzone.qq.com/proxy/domain/taotao.qq.com/cgi-bin/emotion_cgi_msglist_v6" - def __init__(self, cookies_dict: Optional[Dict[str, str]] = None): + def __init__(self, cookies_dict: dict[str, str] | None = None): self.cookies = cookies_dict or {} self.gtk2 = "" p_skey = self.cookies.get("p_skey") or self.cookies.get("p_skey".upper()) @@ -55,7 +55,7 @@ class _SimpleQZoneAPI: hash_val += (hash_val << 5) + ord(char) return str(hash_val & 2147483647) - def get_feed_list(self, target_qq: str, num: int) -> List[Dict[str, Any]]: + def get_feed_list(self, target_qq: str, num: int) -> list[dict[str, Any]]: try: params = { "g_tk": self.gtk2, diff --git a/src/plugins/built_in/permission_management/plugin.py b/src/plugins/built_in/permission_management/plugin.py index fd8612348..d85ca8dd5 100644 --- a/src/plugins/built_in/permission_management/plugin.py +++ b/src/plugins/built_in/permission_management/plugin.py @@ -6,19 +6,17 @@ """ import re -from typing import List, Optional, Tuple, Type +from src.plugin_system.apis.logging_api import get_logger +from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.apis.plugin_register_api import register_plugin from src.plugin_system.base.base_plugin import BasePlugin -from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.base.command_args import CommandArgs -from src.plugin_system.apis.permission_api import permission_api -from src.plugin_system.apis.logging_api import get_logger -from src.plugin_system.base.component_types import PlusCommandInfo, ChatType +from src.plugin_system.base.component_types import ChatType, PlusCommandInfo from src.plugin_system.base.config_types import ConfigField +from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.utils.permission_decorators import require_permission - logger = get_logger("Permission") @@ -44,7 +42,7 @@ class PermissionCommand(PlusCommand): "plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True ) - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: """执行权限管理命令""" if args.is_empty: await self._show_help() @@ -114,7 +112,7 @@ class PermissionCommand(PlusCommand): await self.send_text(help_text) @staticmethod - def _parse_user_mention(mention: str) -> Optional[str]: + def _parse_user_mention(mention: str) -> str | None: """解析用户提及,提取QQ号 支持的格式: @@ -134,7 +132,7 @@ class PermissionCommand(PlusCommand): return None @staticmethod - def parse_user_from_args(args: CommandArgs, index: int = 0) -> Optional[str]: + def parse_user_from_args(args: CommandArgs, index: int = 0) -> str | None: """从CommandArgs中解析用户ID Args: @@ -166,7 +164,7 @@ class PermissionCommand(PlusCommand): return None @require_permission("plugin.permission.manage", "❌ 
你没有权限管理的权限") - async def _grant_permission(self, chat_stream, args: List[str]): + async def _grant_permission(self, chat_stream, args: list[str]): """授权用户权限""" if len(args) < 2: await self.send_text("❌ 用法: /permission grant <@用户|QQ号> <权限节点>") @@ -189,7 +187,7 @@ class PermissionCommand(PlusCommand): await self.send_text("❌ 授权失败,请检查权限节点是否存在") @require_permission("plugin.permission.manage", "❌ 你没有权限管理的权限") - async def _revoke_permission(self, chat_stream, args: List[str]): + async def _revoke_permission(self, chat_stream, args: list[str]): """撤销用户权限""" if len(args) < 2: await self.send_text("❌ 用法: /permission revoke <@用户|QQ号> <权限节点>") @@ -212,7 +210,7 @@ class PermissionCommand(PlusCommand): await self.send_text("❌ 撤销失败,请检查权限节点是否存在") @require_permission("plugin.permission.view", "❌ 你没有查看权限的权限") - async def _list_permissions(self, chat_stream, args: List[str]): + async def _list_permissions(self, chat_stream, args: list[str]): """列出用户权限""" target_user_id = None @@ -244,7 +242,7 @@ class PermissionCommand(PlusCommand): await self.send_text(response) @require_permission("plugin.permission.view", "❌ 你没有查看权限的权限") - async def _check_permission(self, chat_stream, args: List[str]): + async def _check_permission(self, chat_stream, args: list[str]): """检查用户权限""" if len(args) < 2: await self.send_text("❌ 用法: /permission check <@用户|QQ号> <权限节点>") @@ -273,7 +271,7 @@ class PermissionCommand(PlusCommand): await self.send_text(response) @require_permission("plugin.permission.view", "❌ 你没有查看权限的权限") - async def _list_nodes(self, chat_stream, args: List[str]): + async def _list_nodes(self, chat_stream, args: list[str]): """列出权限节点""" plugin_name = args[0] if args else None @@ -388,6 +386,6 @@ class PermissionManagerPlugin(BasePlugin): } } - def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]: + def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type[PlusCommand]]]: """返回插件的PlusCommand组件""" return [(PermissionCommand.get_plus_command_info(), PermissionCommand)] diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index 5061cf496..56199611e 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -1,19 +1,18 @@ import asyncio -from typing import List, Tuple, Type from src.plugin_system import ( BasePlugin, - ConfigField, - register_plugin, - plugin_manage_api, - component_manage_api, ComponentInfo, ComponentType, + ConfigField, + component_manage_api, + plugin_manage_api, + register_plugin, ) -from src.plugin_system.base.plus_command import PlusCommand -from src.plugin_system.base.command_args import CommandArgs -from src.plugin_system.base.component_types import PlusCommandInfo, ChatType from src.plugin_system.apis.permission_api import permission_api +from src.plugin_system.base.command_args import CommandArgs +from src.plugin_system.base.component_types import ChatType, PlusCommandInfo +from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.utils.permission_decorators import require_permission @@ -31,7 +30,7 @@ class ManagementCommand(PlusCommand): super().__init__(*args, **kwargs) @require_permission("plugin.management.admin", "❌ 你没有插件管理的权限") - async def execute(self, args: CommandArgs) -> Tuple[bool, str, bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str, bool]: """执行插件管理命令""" if args.is_empty: await self._show_help("all") @@ -51,7 +50,7 @@ class ManagementCommand(PlusCommand): await 
self.send_text(f"❌ 未知的子命令: {subcommand}\n使用 /pm help 查看帮助") return True, "未知子命令", True - async def _handle_plugin_commands(self, args: List[str]) -> Tuple[bool, str, bool]: + async def _handle_plugin_commands(self, args: list[str]) -> tuple[bool, str, bool]: """处理插件相关命令""" if not args: await self._show_help("plugin") @@ -83,7 +82,7 @@ class ManagementCommand(PlusCommand): return True, "插件命令执行完成", True - async def _handle_component_commands(self, args: List[str]) -> Tuple[bool, str, bool]: + async def _handle_component_commands(self, args: list[str]) -> tuple[bool, str, bool]: """处理组件相关命令""" if not args: await self._show_help("component") @@ -258,7 +257,7 @@ class ManagementCommand(PlusCommand): else: await self.send_text(f"❌ 插件强制重载失败: `{plugin_name}`") except Exception as e: - await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}") + await self.send_text(f"❌ 强制重载过程中发生错误: {e!s}") async def _add_dir(self, dir_path: str): """添加插件目录""" @@ -271,17 +270,17 @@ class ManagementCommand(PlusCommand): await self.send_text(f"❌ 插件目录添加失败: `{dir_path}`") @staticmethod - def _fetch_all_registered_components() -> List[ComponentInfo]: + def _fetch_all_registered_components() -> list[ComponentInfo]: all_plugin_info = component_manage_api.get_all_plugin_info() if not all_plugin_info: return [] - components_info: List[ComponentInfo] = [] + components_info: list[ComponentInfo] = [] for plugin_info in all_plugin_info.values(): components_info.extend(plugin_info.components) return components_info - def _fetch_locally_disabled_components(self) -> List[str]: + def _fetch_locally_disabled_components(self) -> list[str]: """获取本地禁用的组件列表""" stream_id = self.message.chat_stream.stream_id locally_disabled_components_actions = component_manage_api.get_locally_disabled_components( @@ -509,7 +508,7 @@ class PluginManagementPlugin(BasePlugin): False, ) - def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]: + def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type[PlusCommand]]]: """返回插件的PlusCommand组件""" components = [] if self.get_config("plugin.enabled", True): diff --git a/src/plugins/built_in/proactive_thinker/plugin.py b/src/plugins/built_in/proactive_thinker/plugin.py index 5e55e9101..e74c35c8b 100644 --- a/src/plugins/built_in/proactive_thinker/plugin.py +++ b/src/plugins/built_in/proactive_thinker/plugin.py @@ -1,14 +1,13 @@ -from typing import List, Tuple, Type - from src.common.logger import get_logger -from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system import ( + BaseEventHandler, BasePlugin, ConfigField, - register_plugin, EventHandlerInfo, - BaseEventHandler, + register_plugin, ) +from src.plugin_system.base.base_plugin import BasePlugin + from .proacive_thinker_event import ProactiveThinkerEventHandler logger = get_logger(__name__) @@ -33,9 +32,9 @@ class ProactiveThinkerPlugin(BasePlugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def get_plugin_components(self) -> List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]]: + def get_plugin_components(self) -> list[tuple[EventHandlerInfo, type[BaseEventHandler]]]: """返回插件的EventHandler组件""" - components: List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]] = [ + components: list[tuple[EventHandlerInfo, type[BaseEventHandler]]] = [ (ProactiveThinkerEventHandler.get_handler_info(), ProactiveThinkerEventHandler) ] return components diff --git a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py 
index 5ad560243..be818f037 100644 --- a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py +++ b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py @@ -2,17 +2,17 @@ import asyncio import random import time from datetime import datetime -from typing import List, Union from maim_message import UserInfo from src.chat.message_receive.chat_stream import get_chat_manager from src.common.logger import get_logger from src.config.config import global_config -from src.manager.async_task_manager import async_task_manager, AsyncTask -from src.plugin_system import EventType, BaseEventHandler +from src.manager.async_task_manager import AsyncTask, async_task_manager +from src.plugin_system import BaseEventHandler, EventType from src.plugin_system.apis import chat_api, person_api from src.plugin_system.base.base_event import HandlerResult + from .proactive_thinker_executor import ProactiveThinkerExecutor logger = get_logger(__name__) @@ -199,7 +199,7 @@ class ProactiveThinkerEventHandler(BaseEventHandler): handler_name: str = "proactive_thinker_on_start" handler_description: str = "主动思考插件的启动事件处理器" - init_subscribe: List[Union[EventType, str]] = [EventType.ON_START] + init_subscribe: list[EventType | str] = [EventType.ON_START] async def execute(self, kwargs: dict | None) -> "HandlerResult": """在机器人启动时执行,根据配置决定是否启动后台任务。""" diff --git a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py index ab3631450..2accabe5e 100644 --- a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py +++ b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py @@ -1,20 +1,21 @@ -import orjson -from typing import Optional, Dict, Any from datetime import datetime +from typing import Any + +import orjson from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.person_info.person_info import get_person_info_manager from src.plugin_system.apis import ( chat_api, + database_api, + generator_api, + llm_api, + message_api, person_api, schedule_api, send_api, - llm_api, - message_api, - generator_api, - database_api, ) -from src.config.config import global_config, model_config -from src.person_info.person_info import get_person_info_manager logger = get_logger(__name__) @@ -101,7 +102,7 @@ class ProactiveThinkerExecutor: logger.error(f"解析 stream_id ({stream_id}) 或获取 stream 失败: {e}") return None - async def _gather_context(self, stream_id: str) -> Optional[Dict[str, Any]]: + async def _gather_context(self, stream_id: str) -> dict[str, Any] | None: """ 收集构建提示词所需的所有上下文信息 """ @@ -165,7 +166,7 @@ class ProactiveThinkerExecutor: "current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), } - async def _make_decision(self, context: Dict[str, Any], start_mode: str) -> Optional[Dict[str, Any]]: + async def _make_decision(self, context: dict[str, Any], start_mode: str) -> dict[str, Any] | None: """ 决策模块:判断是否应该主动发起对话,以及聊什么话题 """ @@ -234,7 +235,7 @@ class ProactiveThinkerExecutor: logger.error(f"决策LLM返回的JSON格式无效: {response}") return {"should_reply": False, "reason": "决策模型返回格式错误"} - def _build_plan_prompt(self, context: Dict[str, Any], start_mode: str, topic: str, reason: str) -> str: + def _build_plan_prompt(self, context: dict[str, Any], start_mode: str, topic: str, reason: str) -> str: """ 根据启动模式和决策话题,构建最终的规划提示词 """ diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index 
a26879da7..71bb83767 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -1,24 +1,25 @@ -import re -from typing import List, Tuple, Type, Optional - -from src.plugin_system import ( - BasePlugin, - register_plugin, - BaseAction, - ComponentInfo, - ActionActivationType, - ConfigField, -) -from src.common.logger import get_logger -from .qq_emoji_list import qq_face -from src.plugin_system.base.component_types import ChatType -from src.person_info.person_info import get_person_info_manager -from dateutil.parser import parse as parse_datetime -from src.manager.async_task_manager import AsyncTask, async_task_manager -from src.plugin_system.apis import send_api, llm_api, generator_api -from src.chat.message_receive.chat_stream import ChatStream import asyncio import datetime +import re + +from dateutil.parser import parse as parse_datetime + +from src.chat.message_receive.chat_stream import ChatStream +from src.common.logger import get_logger +from src.manager.async_task_manager import AsyncTask, async_task_manager +from src.person_info.person_info import get_person_info_manager +from src.plugin_system import ( + ActionActivationType, + BaseAction, + BasePlugin, + ComponentInfo, + ConfigField, + register_plugin, +) +from src.plugin_system.apis import generator_api, llm_api, send_api +from src.plugin_system.base.component_types import ChatType + +from .qq_emoji_list import qq_face logger = get_logger("set_emoji_like_plugin") @@ -30,7 +31,7 @@ class ReminderTask(AsyncTask): self, delay: float, stream_id: str, - group_id: Optional[str], + group_id: str | None, is_group: bool, target_user_id: str, target_user_name: str, @@ -162,7 +163,7 @@ class PokeAction(BaseAction): """ associated_types = ["text"] - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """执行戳一戳的动作""" user_id = self.action_data.get("user_id") user_name = self.action_data.get("user_name") @@ -242,7 +243,7 @@ class SetEmojiLikeAction(BaseAction): if match: emoji_options.append(match.group(1)) - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """执行设置表情回应的动作""" message_id = None set_like = self.action_data.get("set", True) @@ -360,7 +361,7 @@ class RemindAction(BaseAction): "例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'", ] - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """执行设置提醒的动作""" user_name = self.action_data.get("user_name") remind_time_str = self.action_data.get("remind_time") @@ -386,14 +387,14 @@ class RemindAction(BaseAction): # 优先尝试直接解析 try: target_time = parse_datetime(remind_time_str, fuzzy=True) - except Exception: + except Exception as e: # 如果直接解析失败,调用 LLM 进行转换 logger.info(f"[ReminderPlugin] 直接解析时间 '{remind_time_str}' 失败,尝试使用 LLM 进行转换...") # 获取所有可用的模型配置 available_models = llm_api.get_available_models() if "utils_small" not in available_models: - raise ValueError("未找到 'utils_small' 模型配置,无法解析时间") + raise ValueError("未找到 'utils_small' 模型配置,无法解析时间") from e # 明确使用 'planner' 模型 model_to_use = available_models["utils_small"] @@ -419,7 +420,7 @@ class RemindAction(BaseAction): ) if not success or not response: - raise ValueError(f"LLM未能返回有效的时间字符串: {response}") + raise ValueError(f"LLM未能返回有效的时间字符串: {response}") from e converted_time_str = response.strip() logger.info(f"[ReminderPlugin] LLM 转换结果: '{converted_time_str}'") @@ -533,8 +534,8 @@ class SetEmojiLikePlugin(BasePlugin): # 插件基本信息 plugin_name: str = "social_toolkit_plugin" # 内部标识符 
enable_plugin: bool = True - dependencies: List[str] = [] # 插件依赖列表 - python_dependencies: List[str] = [] # Python包依赖列表,现在使用内置API + dependencies: list[str] = [] # 插件依赖列表 + python_dependencies: list[str] = [] # Python包依赖列表,现在使用内置API config_file_name: str = "config.toml" # 配置文件名 # 配置节描述 @@ -555,7 +556,7 @@ class SetEmojiLikePlugin(BasePlugin): }, } - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: enable_components = [] if self.get_config("components.action_set_emoji_like"): enable_components.append((SetEmojiLikeAction.get_action_info(), SetEmojiLikeAction)) diff --git a/src/plugins/built_in/tts_plugin/plugin.py b/src/plugins/built_in/tts_plugin/plugin.py index fc625c093..8d1327a4f 100644 --- a/src/plugins/built_in/tts_plugin/plugin.py +++ b/src/plugins/built_in/tts_plugin/plugin.py @@ -1,10 +1,9 @@ +from src.common.logger import get_logger from src.plugin_system.apis.plugin_register_api import register_plugin +from src.plugin_system.base.base_action import ActionActivationType, BaseAction, ChatMode from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system.base.component_types import ComponentInfo -from src.common.logger import get_logger -from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode from src.plugin_system.base.config_types import ConfigField -from typing import Tuple, List, Type logger = get_logger("tts") @@ -44,7 +43,7 @@ class TTSAction(BaseAction): # 关联类型 associated_types = ["tts_text"] - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """处理TTS文本转语音动作""" logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}") @@ -140,7 +139,7 @@ class TTSPlugin(BasePlugin): }, } - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: """返回插件包含的组件列表""" # 从配置获取组件启用状态 diff --git a/src/plugins/built_in/web_search_tool/engines/base.py b/src/plugins/built_in/web_search_tool/engines/base.py index 30d20a540..4fd2c452a 100644 --- a/src/plugins/built_in/web_search_tool/engines/base.py +++ b/src/plugins/built_in/web_search_tool/engines/base.py @@ -3,7 +3,7 @@ Base search engine interface """ from abc import ABC, abstractmethod -from typing import Dict, List, Any +from typing import Any class BaseSearchEngine(ABC): @@ -12,7 +12,7 @@ class BaseSearchEngine(ABC): """ @abstractmethod - async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: + async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]: """ 执行搜索 diff --git a/src/plugins/built_in/web_search_tool/engines/bing_engine.py b/src/plugins/built_in/web_search_tool/engines/bing_engine.py index ece747fbd..46431bff1 100644 --- a/src/plugins/built_in/web_search_tool/engines/bing_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/bing_engine.py @@ -6,11 +6,13 @@ import asyncio import functools import random import traceback -from typing import Dict, List, Any +from typing import Any + import requests from bs4 import BeautifulSoup from src.common.logger import get_logger + from .base import BaseSearchEngine logger = get_logger("bing_engine") @@ -68,7 +70,7 @@ class BingSearchEngine(BaseSearchEngine): """检查Bing搜索引擎是否可用""" return True # Bing是免费搜索引擎,总是可用 - async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: + async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]: """执行Bing搜索""" query = args["query"] num_results = 
args.get("num_results", 3) @@ -83,7 +85,7 @@ class BingSearchEngine(BaseSearchEngine): logger.error(f"Bing 搜索失败: {e}") return [] - def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]: + def _search_sync(self, keyword: str, num_results: int, time_range: str) -> list[dict[str, Any]]: """同步执行Bing搜索""" if not keyword: return [] @@ -113,7 +115,7 @@ class BingSearchEngine(BaseSearchEngine): return list_result[:num_results] if len(list_result) > num_results else list_result @staticmethod - def _parse_html(url: str) -> List[Dict[str, Any]]: + def _parse_html(url: str) -> list[dict[str, Any]]: """解析处理结果""" try: logger.debug(f"访问Bing搜索URL: {url}") @@ -141,11 +143,11 @@ class BingSearchEngine(BaseSearchEngine): try: res = session.get(url=url, timeout=(3.05, 6), verify=True, allow_redirects=True) except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e: - logger.warning(f"第一次请求超时,正在重试: {str(e)}") + logger.warning(f"第一次请求超时,正在重试: {e!s}") try: res = session.get(url=url, timeout=(5, 10), verify=False) except Exception as e2: - logger.error(f"第二次请求也失败: {str(e2)}") + logger.error(f"第二次请求也失败: {e2!s}") return [] res.encoding = "utf-8" @@ -175,7 +177,7 @@ class BingSearchEngine(BaseSearchEngine): try: root = BeautifulSoup(res.text, "html.parser") except Exception as e: - logger.error(f"HTML解析失败: {str(e)}") + logger.error(f"HTML解析失败: {e!s}") return [] list_data = [] @@ -262,6 +264,6 @@ class BingSearchEngine(BaseSearchEngine): return list_data except Exception as e: - logger.error(f"解析Bing页面时出错: {str(e)}") + logger.error(f"解析Bing页面时出错: {e!s}") logger.debug(traceback.format_exc()) return [] diff --git a/src/plugins/built_in/web_search_tool/engines/ddg_engine.py b/src/plugins/built_in/web_search_tool/engines/ddg_engine.py index 29f03b31a..eb73f6bcd 100644 --- a/src/plugins/built_in/web_search_tool/engines/ddg_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/ddg_engine.py @@ -2,10 +2,12 @@ DuckDuckGo search engine implementation """ -from typing import Dict, List, Any +from typing import Any + from asyncddgs import aDDGS from src.common.logger import get_logger + from .base import BaseSearchEngine logger = get_logger("ddg_engine") @@ -20,7 +22,7 @@ class DDGSearchEngine(BaseSearchEngine): """检查DuckDuckGo搜索引擎是否可用""" return True # DuckDuckGo不需要API密钥,总是可用 - async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: + async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]: """执行DuckDuckGo搜索""" query = args["query"] num_results = args.get("num_results", 3) diff --git a/src/plugins/built_in/web_search_tool/engines/exa_engine.py b/src/plugins/built_in/web_search_tool/engines/exa_engine.py index 269e32bd1..37655eb53 100644 --- a/src/plugins/built_in/web_search_tool/engines/exa_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/exa_engine.py @@ -5,13 +5,15 @@ Exa search engine implementation import asyncio import functools from datetime import datetime, timedelta -from typing import Dict, List, Any +from typing import Any + from exa_py import Exa from src.common.logger import get_logger from src.plugin_system.apis import config_api -from .base import BaseSearchEngine + from ..utils.api_key_manager import create_api_key_manager_from_config +from .base import BaseSearchEngine logger = get_logger("exa_engine") @@ -36,7 +38,7 @@ class ExaSearchEngine(BaseSearchEngine): """检查Exa搜索引擎是否可用""" return self.api_manager.is_available() - async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: + async def 
search(self, args: dict[str, Any]) -> list[dict[str, Any]]: """执行Exa搜索""" if not self.is_available(): return [] diff --git a/src/plugins/built_in/web_search_tool/engines/tavily_engine.py b/src/plugins/built_in/web_search_tool/engines/tavily_engine.py index 2f929284f..acbe23d81 100644 --- a/src/plugins/built_in/web_search_tool/engines/tavily_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/tavily_engine.py @@ -4,13 +4,15 @@ Tavily search engine implementation import asyncio import functools -from typing import Dict, List, Any +from typing import Any + from tavily import TavilyClient from src.common.logger import get_logger from src.plugin_system.apis import config_api -from .base import BaseSearchEngine + from ..utils.api_key_manager import create_api_key_manager_from_config +from .base import BaseSearchEngine logger = get_logger("tavily_engine") @@ -37,7 +39,7 @@ class TavilySearchEngine(BaseSearchEngine): """检查Tavily搜索引擎是否可用""" return self.api_manager.is_available() - async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: + async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]: """执行Tavily搜索""" if not self.is_available(): return [] diff --git a/src/plugins/built_in/web_search_tool/plugin.py b/src/plugins/built_in/web_search_tool/plugin.py index fadc02a88..2b85104bc 100644 --- a/src/plugins/built_in/web_search_tool/plugin.py +++ b/src/plugins/built_in/web_search_tool/plugin.py @@ -4,14 +4,12 @@ Web Search Tool Plugin 一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。 """ -from typing import List, Tuple, Type - -from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency -from src.plugin_system.apis import config_api from src.common.logger import get_logger +from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin +from src.plugin_system.apis import config_api -from .tools.web_search import WebSurfingTool from .tools.url_parser import URLParserTool +from .tools.web_search import WebSurfingTool logger = get_logger("web_search_plugin") @@ -31,7 +29,7 @@ class WEBSEARCHPLUGIN(BasePlugin): # 插件基本信息 plugin_name: str = "web_search_tool" # 内部标识符 enable_plugin: bool = True - dependencies: List[str] = [] # 插件依赖列表 + dependencies: list[str] = [] # 插件依赖列表 def __init__(self, *args, **kwargs): """初始化插件,立即加载所有搜索引擎""" @@ -40,10 +38,10 @@ class WEBSEARCHPLUGIN(BasePlugin): # 立即初始化所有搜索引擎,触发API密钥管理器的日志输出 logger.info("🚀 正在初始化所有搜索引擎...") try: + from .engines.bing_engine import BingSearchEngine + from .engines.ddg_engine import DDGSearchEngine from .engines.exa_engine import ExaSearchEngine from .engines.tavily_engine import TavilySearchEngine - from .engines.ddg_engine import DDGSearchEngine - from .engines.bing_engine import BingSearchEngine # 实例化所有搜索引擎,这会触发API密钥管理器的初始化 exa_engine = ExaSearchEngine() @@ -71,7 +69,7 @@ class WEBSEARCHPLUGIN(BasePlugin): logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True) # Python包依赖列表 - python_dependencies: List[PythonDependency] = [ + python_dependencies: list[PythonDependency] = [ PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False), PythonDependency( package_name="exa_py", @@ -119,7 +117,7 @@ class WEBSEARCHPLUGIN(BasePlugin): }, } - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: """ 获取插件组件列表 diff --git a/src/plugins/built_in/web_search_tool/tools/url_parser.py b/src/plugins/built_in/web_search_tool/tools/url_parser.py index 
25338c35c..6e9bf5a03 100644 --- a/src/plugins/built_in/web_search_tool/tools/url_parser.py +++ b/src/plugins/built_in/web_search_tool/tools/url_parser.py @@ -4,19 +4,20 @@ URL parser tool implementation import asyncio import functools -from typing import Any, Dict -from exa_py import Exa +from typing import Any + import httpx from bs4 import BeautifulSoup +from exa_py import Exa +from src.common.cache_manager import tool_cache from src.common.logger import get_logger from src.plugin_system import BaseTool, ToolParamType, llm_api from src.plugin_system.apis import config_api -from src.common.cache_manager import tool_cache +from ..utils.api_key_manager import create_api_key_manager_from_config from ..utils.formatters import format_url_parse_results from ..utils.url_utils import parse_urls_from_input, validate_urls -from ..utils.api_key_manager import create_api_key_manager_from_config logger = get_logger("url_parser_tool") @@ -50,7 +51,7 @@ class URLParserTool(BaseTool): exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser" ) - async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]: + async def _local_parse_and_summarize(self, url: str) -> dict[str, Any]: """ 使用本地库(httpx, BeautifulSoup)解析URL,并调用LLM进行总结。 """ @@ -124,9 +125,9 @@ class URLParserTool(BaseTool): return {"error": f"请求失败,状态码: {e.response.status_code}"} except Exception as e: logger.error(f"本地解析或总结URL '{url}' 时发生未知异常: {e}", exc_info=True) - return {"error": f"发生未知错误: {str(e)}"} + return {"error": f"发生未知错误: {e!s}"} - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """ 执行URL内容提取和总结。优先使用Exa,失败后尝试本地解析。 """ diff --git a/src/plugins/built_in/web_search_tool/tools/web_search.py b/src/plugins/built_in/web_search_tool/tools/web_search.py index 3e4039cb8..9dcafc9a5 100644 --- a/src/plugins/built_in/web_search_tool/tools/web_search.py +++ b/src/plugins/built_in/web_search_tool/tools/web_search.py @@ -3,18 +3,18 @@ Web search tool implementation """ import asyncio -from typing import Any, Dict, List +from typing import Any +from src.common.cache_manager import tool_cache from src.common.logger import get_logger from src.plugin_system import BaseTool, ToolParamType from src.plugin_system.apis import config_api -from src.common.cache_manager import tool_cache +from ..engines.bing_engine import BingSearchEngine +from ..engines.ddg_engine import DDGSearchEngine from ..engines.exa_engine import ExaSearchEngine from ..engines.tavily_engine import TavilySearchEngine -from ..engines.ddg_engine import DDGSearchEngine -from ..engines.bing_engine import BingSearchEngine -from ..utils.formatters import format_search_results, deduplicate_results +from ..utils.formatters import deduplicate_results, format_search_results logger = get_logger("web_search_tool") @@ -51,7 +51,7 @@ class WebSurfingTool(BaseTool): "bing": BingSearchEngine(), } - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: query = function_args.get("query") if not query: return {"error": "搜索查询不能为空。"} @@ -88,8 +88,8 @@ class WebSurfingTool(BaseTool): return result async def _execute_parallel_search( - self, function_args: Dict[str, Any], enabled_engines: List[str] - ) -> Dict[str, Any]: + self, function_args: dict[str, Any], enabled_engines: list[str] + ) -> dict[str, Any]: """并行搜索策略:同时使用所有启用的搜索引擎""" search_tasks = [] @@ -124,11 +124,11 @@ class WebSurfingTool(BaseTool): 
except Exception as e: logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True) - return {"error": f"执行网络搜索时发生严重错误: {str(e)}"} + return {"error": f"执行网络搜索时发生严重错误: {e!s}"} async def _execute_fallback_search( - self, function_args: Dict[str, Any], enabled_engines: List[str] - ) -> Dict[str, Any]: + self, function_args: dict[str, Any], enabled_engines: list[str] + ) -> dict[str, Any]: """回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个""" for engine_name in enabled_engines: engine = self.engines.get(engine_name) @@ -154,7 +154,7 @@ class WebSurfingTool(BaseTool): return {"error": "所有搜索引擎都失败了。"} - async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: + async def _execute_single_search(self, function_args: dict[str, Any], enabled_engines: list[str]) -> dict[str, Any]: """单一搜索策略:只使用第一个可用的搜索引擎""" for engine_name in enabled_engines: engine = self.engines.get(engine_name) @@ -174,6 +174,6 @@ class WebSurfingTool(BaseTool): except Exception as e: logger.error(f"{engine_name} 搜索失败: {e}") - return {"error": f"{engine_name} 搜索失败: {str(e)}"} + return {"error": f"{engine_name} 搜索失败: {e!s}"} return {"error": "没有可用的搜索引擎。"} diff --git a/src/plugins/built_in/web_search_tool/utils/api_key_manager.py b/src/plugins/built_in/web_search_tool/utils/api_key_manager.py index 07757cdb1..e7aba03ce 100644 --- a/src/plugins/built_in/web_search_tool/utils/api_key_manager.py +++ b/src/plugins/built_in/web_search_tool/utils/api_key_manager.py @@ -3,7 +3,9 @@ API密钥管理器,提供轮询机制 """ import itertools -from typing import List, Optional, TypeVar, Generic, Callable +from collections.abc import Callable +from typing import Generic, TypeVar + from src.common.logger import get_logger logger = get_logger("api_key_manager") @@ -16,7 +18,7 @@ class APIKeyManager(Generic[T]): API密钥管理器,支持轮询机制 """ - def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"): + def __init__(self, api_keys: list[str], client_factory: Callable[[str], T], service_name: str = "Unknown"): """ 初始化API密钥管理器 @@ -26,8 +28,8 @@ class APIKeyManager(Generic[T]): service_name: 服务名称,用于日志记录 """ self.service_name = service_name - self.clients: List[T] = [] - self.client_cycle: Optional[itertools.cycle] = None + self.clients: list[T] = [] + self.client_cycle: itertools.cycle | None = None if api_keys: # 过滤有效的API密钥,排除None、空字符串、"None"字符串等 @@ -54,7 +56,7 @@ class APIKeyManager(Generic[T]): """检查是否有可用的客户端""" return bool(self.clients and self.client_cycle) - def get_next_client(self) -> Optional[T]: + def get_next_client(self) -> T | None: """获取下一个客户端(轮询)""" if not self.is_available(): return None @@ -66,7 +68,7 @@ class APIKeyManager(Generic[T]): def create_api_key_manager_from_config( - config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str + config_keys: list[str] | None, client_factory: Callable[[str], T], service_name: str ) -> APIKeyManager[T]: """ 从配置创建API密钥管理器的便捷函数 diff --git a/src/plugins/built_in/web_search_tool/utils/formatters.py b/src/plugins/built_in/web_search_tool/utils/formatters.py index df1e4ea18..6173b0bca 100644 --- a/src/plugins/built_in/web_search_tool/utils/formatters.py +++ b/src/plugins/built_in/web_search_tool/utils/formatters.py @@ -2,10 +2,10 @@ Formatters for web search results """ -from typing import List, Dict, Any +from typing import Any -def format_search_results(results: List[Dict[str, Any]]) -> str: +def format_search_results(results: list[dict[str, Any]]) -> str: """ 格式化搜索结果为字符串 """ @@ -26,7 +26,7 @@ def 
format_search_results(results: List[Dict[str, Any]]) -> str: return formatted_string -def format_url_parse_results(results: List[Dict[str, Any]]) -> str: +def format_url_parse_results(results: list[dict[str, Any]]) -> str: """ 将成功解析的URL结果列表格式化为一段简洁的文本。 """ @@ -45,7 +45,7 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str: return "\n---\n".join(formatted_parts) -def deduplicate_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +def deduplicate_results(results: list[dict[str, Any]]) -> list[dict[str, Any]]: """ 根据URL去重搜索结果 """ diff --git a/src/plugins/built_in/web_search_tool/utils/url_utils.py b/src/plugins/built_in/web_search_tool/utils/url_utils.py index 5bdde0a55..f96d4a04a 100644 --- a/src/plugins/built_in/web_search_tool/utils/url_utils.py +++ b/src/plugins/built_in/web_search_tool/utils/url_utils.py @@ -3,10 +3,9 @@ URL processing utilities """ import re -from typing import List -def parse_urls_from_input(urls_input) -> List[str]: +def parse_urls_from_input(urls_input) -> list[str]: """ 从输入中解析URL列表 """ @@ -29,7 +28,7 @@ def parse_urls_from_input(urls_input) -> List[str]: return urls -def validate_urls(urls: List[str]) -> List[str]: +def validate_urls(urls: list[str]) -> list[str]: """ 验证URL格式,返回有效的URL列表 """ diff --git a/src/schedule/database.py b/src/schedule/database.py index b33bfb953..ccaf92b7f 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -1,7 +1,8 @@ # mmc/src/schedule/database.py -from typing import List -from sqlalchemy import select, func, update, delete + +from sqlalchemy import delete, func, select, update + from src.common.database.sqlalchemy_models import MonthlyPlan, get_db_session from src.common.logger import get_logger from src.config.config import global_config @@ -9,7 +10,7 @@ from src.config.config import global_config logger = get_logger("schedule_database") -async def add_new_plans(plans: List[str], month: str): +async def add_new_plans(plans: list[str], month: str): """ 批量添加新生成的月度计划到数据库,并确保不超过上限。 @@ -55,7 +56,7 @@ async def add_new_plans(plans: List[str], month: str): raise -async def get_active_plans_for_month(month: str) -> List[MonthlyPlan]: +async def get_active_plans_for_month(month: str) -> list[MonthlyPlan]: """ 获取指定月份所有状态为 'active' 的计划。 @@ -75,7 +76,7 @@ async def get_active_plans_for_month(month: str) -> List[MonthlyPlan]: return [] -async def mark_plans_completed(plan_ids: List[int]): +async def mark_plans_completed(plan_ids: list[int]): """ 将指定ID的计划标记为已完成。 @@ -103,7 +104,7 @@ async def mark_plans_completed(plan_ids: List[int]): raise -async def delete_plans_by_ids(plan_ids: List[int]): +async def delete_plans_by_ids(plan_ids: list[int]): """ 根据ID列表从数据库中物理删除月度计划。 @@ -134,7 +135,7 @@ async def delete_plans_by_ids(plan_ids: List[int]): raise -async def update_plan_usage(plan_ids: List[int], used_date: str): +async def update_plan_usage(plan_ids: list[int], used_date: str): """ 更新计划的使用统计信息。 @@ -182,7 +183,7 @@ async def update_plan_usage(plan_ids: List[int], used_date: str): raise -async def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_days: int = 7) -> List[MonthlyPlan]: +async def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_days: int = 7) -> list[MonthlyPlan]: """ 智能抽取月度计划用于每日日程生成。 @@ -255,7 +256,7 @@ async def archive_active_plans_for_month(month: str): raise -async def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]: +async def get_archived_plans_for_month(month: str) -> list[MonthlyPlan]: """ 获取指定月份所有状态为 'archived' 的计划。 
用于生成下个月计划时的参考。 diff --git a/src/schedule/llm_generator.py b/src/schedule/llm_generator.py index d3ec56bb6..b8f4c51bd 100644 --- a/src/schedule/llm_generator.py +++ b/src/schedule/llm_generator.py @@ -1,16 +1,18 @@ # mmc/src/schedule/llm_generator.py import asyncio -import orjson from datetime import datetime -from typing import List, Optional, Dict, Any -from lunar_python import Lunar +from typing import Any + +import orjson from json_repair import repair_json +from lunar_python import Lunar from src.common.database.sqlalchemy_models import MonthlyPlan +from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.common.logger import get_logger + from .schemas import ScheduleData logger = get_logger("schedule_llm_generator") @@ -37,7 +39,7 @@ class ScheduleLLMGenerator: def __init__(self): self.llm = LLMRequest(model_set=model_config.model_task_config.schedule_generator, request_type="schedule") - async def generate_schedule_with_llm(self, sampled_plans: List[MonthlyPlan]) -> Optional[List[Dict[str, Any]]]: + async def generate_schedule_with_llm(self, sampled_plans: list[MonthlyPlan]) -> list[dict[str, Any]] | None: now = datetime.now() today_str = now.strftime("%Y-%m-%d") weekday = now.strftime("%A") @@ -143,7 +145,7 @@ class MonthlyPlanLLMGenerator: def __init__(self): self.llm = LLMRequest(model_set=model_config.model_task_config.schedule_generator, request_type="monthly_plan") - async def generate_plans_with_llm(self, target_month: str, archived_plans: List[MonthlyPlan]) -> List[str]: + async def generate_plans_with_llm(self, target_month: str, archived_plans: list[MonthlyPlan]) -> list[str]: guidelines = global_config.planning_system.monthly_plan_guidelines or DEFAULT_MONTHLY_PLAN_GUIDELINES personality = global_config.personality.personality_core personality_side = global_config.personality.personality_side @@ -209,7 +211,7 @@ class MonthlyPlanLLMGenerator: return [] @staticmethod - def _parse_plans_response(response: str) -> List[str]: + def _parse_plans_response(response: str) -> list[str]: try: response = response.strip() lines = [line.strip() for line in response.split("\n") if line.strip()] diff --git a/src/schedule/monthly_plan_manager.py b/src/schedule/monthly_plan_manager.py index 7deaaf77d..22e19cd49 100644 --- a/src/schedule/monthly_plan_manager.py +++ b/src/schedule/monthly_plan_manager.py @@ -1,9 +1,9 @@ import asyncio from datetime import datetime, timedelta -from typing import Optional from src.common.logger import get_logger from src.manager.async_task_manager import AsyncTask, async_task_manager + from .plan_manager import PlanManager logger = get_logger("monthly_plan_manager") @@ -31,7 +31,7 @@ class MonthlyPlanManager: else: logger.info(" 每月月度计划生成任务已在运行中。") - async def ensure_and_generate_plans_if_needed(self, target_month: Optional[str] = None) -> bool: + async def ensure_and_generate_plans_if_needed(self, target_month: str | None = None) -> bool: return await self.plan_manager.ensure_and_generate_plans_if_needed(target_month) diff --git a/src/schedule/plan_manager.py b/src/schedule/plan_manager.py index 513a907d5..239bdf3c2 100644 --- a/src/schedule/plan_manager.py +++ b/src/schedule/plan_manager.py @@ -1,18 +1,18 @@ # mmc/src/schedule/plan_manager.py from datetime import datetime -from typing import List, Optional from src.common.logger import get_logger from src.config.config import global_config + from .database import ( add_new_plans, - 
get_archived_plans_for_month, archive_active_plans_for_month, - has_active_plans, - get_active_plans_for_month, delete_plans_by_ids, + get_active_plans_for_month, + get_archived_plans_for_month, get_smart_plans_for_daily_schedule, + has_active_plans, ) from .llm_generator import MonthlyPlanLLMGenerator @@ -24,7 +24,7 @@ class PlanManager: self.llm_generator = MonthlyPlanLLMGenerator() self.generation_running = False - async def ensure_and_generate_plans_if_needed(self, target_month: Optional[str] = None) -> bool: + async def ensure_and_generate_plans_if_needed(self, target_month: str | None = None) -> bool: if target_month is None: target_month = datetime.now().strftime("%Y-%m") @@ -48,7 +48,7 @@ class PlanManager: logger.info(f"当前月度计划内容:\n{plan_texts}") return True - async def _generate_monthly_plans_logic(self, target_month: Optional[str] = None) -> bool: + async def _generate_monthly_plans_logic(self, target_month: str | None = None) -> bool: if self.generation_running: logger.info("月度计划生成任务已在运行中,跳过重复启动") return False @@ -90,7 +90,7 @@ class PlanManager: except Exception: return "1900-01" - async def archive_current_month_plans(self, target_month: Optional[str] = None): + async def archive_current_month_plans(self, target_month: str | None = None): try: if target_month is None: target_month = datetime.now().strftime("%Y-%m") @@ -100,6 +100,6 @@ class PlanManager: except Exception as e: logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}") - async def get_plans_for_schedule(self, month: str, max_count: int) -> List: + async def get_plans_for_schedule(self, month: str, max_count: int) -> list: avoid_days = global_config.planning_system.avoid_repetition_days return await get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) diff --git a/src/schedule/schedule_manager.py b/src/schedule/schedule_manager.py index 115480381..9f1133df6 100644 --- a/src/schedule/schedule_manager.py +++ b/src/schedule/schedule_manager.py @@ -1,14 +1,15 @@ -import orjson import asyncio from datetime import datetime, time, timedelta -from typing import Optional, List, Dict, Any +from typing import Any +import orjson from sqlalchemy import select from src.common.database.sqlalchemy_models import Schedule, get_db_session -from src.config.config import global_config from src.common.logger import get_logger +from src.config.config import global_config from src.manager.async_task_manager import AsyncTask, async_task_manager + from .database import update_plan_usage from .llm_generator import ScheduleLLMGenerator from .plan_manager import PlanManager @@ -19,7 +20,7 @@ logger = get_logger("schedule_manager") class ScheduleManager: def __init__(self): - self.today_schedule: Optional[List[Dict[str, Any]]] = None + self.today_schedule: list[dict[str, Any]] | None = None self.llm_generator = ScheduleLLMGenerator() self.plan_manager = PlanManager() self.daily_task_started = False @@ -63,7 +64,7 @@ class ScheduleManager: logger.info("尝试生成日程作为备用方案...") await self.generate_and_save_schedule() - async def _load_schedule_from_db(self, date_str: str) -> Optional[List[Dict[str, Any]]]: + async def _load_schedule_from_db(self, date_str: str) -> list[dict[str, Any]] | None: async with get_db_session() as session: result = await session.execute(select(Schedule).filter(Schedule.date == date_str)) schedule_record = result.scalars().first() @@ -118,7 +119,7 @@ class ScheduleManager: logger.info("日程生成任务结束") @staticmethod - async def _save_schedule_to_db(date_str: str, schedule_data: List[Dict[str, Any]]): + async 
def _save_schedule_to_db(date_str: str, schedule_data: list[dict[str, Any]]): async with get_db_session() as session: schedule_json = orjson.dumps(schedule_data).decode("utf-8") result = await session.execute(select(Schedule).filter(Schedule.date == date_str)) @@ -132,13 +133,13 @@ class ScheduleManager: await session.commit() @staticmethod - def _log_generated_schedule(date_str: str, schedule_data: List[Dict[str, Any]]): + def _log_generated_schedule(date_str: str, schedule_data: list[dict[str, Any]]): schedule_str = f"✅ 成功生成并保存今天的日程 ({date_str}):\n" for item in schedule_data: schedule_str += f" - {item.get('time_range', '未知时间')}: {item.get('activity', '未知活动')}\n" logger.info(schedule_str) - def get_current_activity(self) -> Optional[str]: + def get_current_activity(self) -> str | None: if not global_config.planning_system.schedule_enable or not self.today_schedule: return None now = datetime.now().time() diff --git a/src/schedule/schemas.py b/src/schedule/schemas.py index a733731be..00508e4d8 100644 --- a/src/schedule/schemas.py +++ b/src/schedule/schemas.py @@ -1,7 +1,7 @@ # mmc/src/schedule/schemas.py from datetime import datetime, time -from typing import List + from pydantic import BaseModel, validator @@ -41,7 +41,7 @@ class ScheduleItem(BaseModel): class ScheduleData(BaseModel): """完整日程数据的Pydantic模型""" - schedule: List[ScheduleItem] + schedule: list[ScheduleItem] @validator("schedule") def validate_schedule_completeness(cls, v): @@ -67,7 +67,7 @@ class ScheduleData(BaseModel): return v @staticmethod - def _check_24_hour_coverage(time_ranges: List[tuple]) -> bool: + def _check_24_hour_coverage(time_ranges: list[tuple]) -> bool: """检查时间段是否覆盖24小时""" if not time_ranges: return False diff --git a/src/utils/message_chunker.py b/src/utils/message_chunker.py index 66a2964e1..2e98adcf1 100644 --- a/src/utils/message_chunker.py +++ b/src/utils/message_chunker.py @@ -3,10 +3,12 @@ MaiBot 端的消息切片处理模块 用于接收和重组来自 Napcat-Adapter 的切片消息 """ -import orjson -import time import asyncio -from typing import Dict, Any, Optional +import time +from typing import Any + +import orjson + from src.common.logger import get_logger logger = get_logger("message_chunker") @@ -17,7 +19,7 @@ class MessageReassembler: def __init__(self, timeout: int = 30): self.timeout = timeout - self.chunk_buffers: Dict[str, Dict[str, Any]] = {} + self.chunk_buffers: dict[str, dict[str, Any]] = {} self._cleanup_task = None async def start_cleanup_task(self): @@ -59,7 +61,7 @@ class MessageReassembler: logger.error(f"清理过期切片时出错: {e}") @staticmethod - def is_chunk_message(message: Dict[str, Any]) -> bool: + def is_chunk_message(message: dict[str, Any]) -> bool: """检查是否是来自 Ada 的切片消息""" return ( isinstance(message, dict) @@ -68,7 +70,7 @@ class MessageReassembler: and "__mmc_is_chunked__" in message ) - async def process_chunk(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]: + async def process_chunk(self, message: dict[str, Any]) -> dict[str, Any] | None: """ 处理切片消息,如果切片完整则返回重组后的消息 @@ -144,7 +146,7 @@ class MessageReassembler: logger.error(f"处理切片消息时出错: {e}") return None - def get_pending_chunks_info(self) -> Dict[str, Any]: + def get_pending_chunks_info(self) -> dict[str, Any]: """获取待处理切片信息""" info = {} for chunk_id, buffer in self.chunk_buffers.items(): diff --git a/src/utils/timing_utils.py b/src/utils/timing_utils.py index b4084d6af..36fb1f870 100644 --- a/src/utils/timing_utils.py +++ b/src/utils/timing_utils.py @@ -10,10 +10,10 @@ - 快速筛选:使用NumPy布尔索引进行高效过滤 """ -import numpy as np -from typing import Optional from functools 
import lru_cache +import numpy as np + @lru_cache(maxsize=128) def _calculate_sigma_bounds(base_interval: int, sigma_percentage: float, use_3sigma_rule: bool) -> tuple: @@ -35,8 +35,8 @@ def _calculate_sigma_bounds(base_interval: int, sigma_percentage: float, use_3si def get_normal_distributed_interval( base_interval: int, sigma_percentage: float = 0.1, - min_interval: Optional[int] = None, - max_interval: Optional[int] = None, + min_interval: int | None = None, + max_interval: int | None = None, use_3sigma_rule: bool = True, ) -> int: """ @@ -120,8 +120,8 @@ def get_normal_distributed_interval( def _generate_pure_random_interval( sigma_percentage: float, - min_interval: Optional[int] = None, - max_interval: Optional[int] = None, + min_interval: int | None = None, + max_interval: int | None = None, use_3sigma_rule: bool = True, ) -> int: """ diff --git a/ui_log_adapter.py b/ui_log_adapter.py index 58ae14f80..3fb474620 100644 --- a/ui_log_adapter.py +++ b/ui_log_adapter.py @@ -3,9 +3,9 @@ Bot服务UI日志适配器 在最小侵入的情况下捕获Bot的日志并发送到UI """ -import sys -import os import logging +import os +import sys import threading import time From 047105e5e8628d5d570d3b3708b8a8f96a5eaa0a Mon Sep 17 00:00:00 2001 From: John Richard Date: Thu, 2 Oct 2025 21:10:36 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E6=9B=B4=E6=94=B9=E9=83=A8=E5=88=86?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E6=B3=A8=E8=A7=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/database/sqlalchemy_models.py | 9 ++--- src/plugin_system/base/base_action.py | 3 +- src/plugin_system/base/base_plugin.py | 6 +--- src/plugin_system/base/plugin_base.py | 35 ++++--------------- .../built_in/social_toolkit_plugin/plugin.py | 3 +- src/schedule/database.py | 6 ++-- 6 files changed, 17 insertions(+), 45 deletions(-) diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index c89848ee3..bda9f36ec 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -760,7 +760,7 @@ async def initialize_database(): @asynccontextmanager -async def get_db_session() -> AsyncGenerator[AsyncSession | None, None]: +async def get_db_session() -> AsyncGenerator[AsyncSession]: """ 异步数据库会话上下文管理器。 在初始化失败时会yield None,调用方需要检查会话是否为None。 @@ -770,13 +770,10 @@ async def get_db_session() -> AsyncGenerator[AsyncSession | None, None]: try: _, SessionLocal = await initialize_database() if not SessionLocal: - logger.error("数据库会话工厂 (_SessionLocal) 未初始化。") - yield None - return + raise RuntimeError("数据库会话工厂 (_SessionLocal) 未初始化。") except Exception as e: logger.error(f"数据库初始化失败,无法创建会话: {e}") - yield None - return + raise try: session = SessionLocal() diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 37711794b..8fcb5dd2d 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -1,3 +1,4 @@ +# Todo: 重构Action,这里现在只剩下了报错。 import asyncio import time from abc import ABC, abstractmethod @@ -452,7 +453,7 @@ class BaseAction(ABC): # 4. 
执行Action logger.debug(f"{log_prefix} 开始执行...") - execute_result = await action_instance.execute() + execute_result = await action_instance.execute() # Todo: 修复类型错误 # 确保返回类型符合 (bool, str) 格式 is_success = execute_result[0] if isinstance(execute_result, tuple) and len(execute_result) > 0 else False message = execute_result[1] if isinstance(execute_result, tuple) and len(execute_result) > 1 else "" diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index 232365bce..cfe163a8f 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -21,10 +21,6 @@ class BasePlugin(PluginBase): - Command组件:处理命令请求 - 未来可扩展:Scheduler、Listener等 """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - @abstractmethod def get_plugin_components( self, @@ -42,7 +38,7 @@ class BasePlugin(PluginBase): Returns: List[tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...] """ - raise NotImplementedError("Subclasses must implement this method") + ... def register_plugin(self) -> bool: """注册插件及其所有组件""" diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index 8cc3312db..683e4985c 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -27,40 +27,17 @@ class PluginBase(ABC): """ # 插件基本信息(子类必须定义) - @property - @abstractmethod - def plugin_name(self) -> str: - return "" # 插件内部标识符(如 "hello_world_plugin") - - @property - @abstractmethod - def enable_plugin(self) -> bool: - return True # 是否启用插件 - - @property - @abstractmethod - def dependencies(self) -> list[str]: - return [] # 依赖的其他插件 - - @property - @abstractmethod - def python_dependencies(self) -> list[str | PythonDependency]: - return [] # Python包依赖,支持字符串列表或PythonDependency对象列表 - - @property - @abstractmethod - def config_file_name(self) -> str: - return "" # 配置文件名 + plugin_name: str + config_file_name: str + enable_plugin: bool = True + dependencies: list[str] = [] + python_dependencies: list[str | PythonDependency] = [] # manifest文件相关 manifest_file_name: str = "_manifest.json" # manifest文件名 manifest_data: dict[str, Any] = {} # manifest数据 - # 配置定义 - @property - @abstractmethod - def config_schema(self) -> dict[str, dict[str, ConfigField] | str]: - return {} + config_schema: dict[str, dict[str, ConfigField] | str] = {} config_section_descriptions: dict[str, str] = {} diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index 71bb83767..c9216199a 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -1,6 +1,7 @@ import asyncio import datetime import re +from typing import ClassVar from dateutil.parser import parse as parse_datetime @@ -542,7 +543,7 @@ class SetEmojiLikePlugin(BasePlugin): config_section_descriptions = {"plugin": "插件基本信息", "components": "插件组件"} # 配置Schema定义 - config_schema: dict = { + config_schema: ClassVar[dict] = { "plugin": { "name": ConfigField(type=str, default="set_emoji_like", description="插件名称"), "version": ConfigField(type=str, default="1.0.0", description="插件版本"), diff --git a/src/schedule/database.py b/src/schedule/database.py index ccaf92b7f..ee771ac53 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -70,7 +70,7 @@ async def get_active_plans_for_month(month: str) -> list[MonthlyPlan]: .where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") .order_by(MonthlyPlan.created_at.desc()) ) - return 
result.scalars().all() + return list(result.scalars().all()) except Exception as e: logger.error(f"查询 {month} 的有效月度计划时发生错误: {e}") return [] @@ -225,7 +225,7 @@ async def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avo plans = random.sample(plans, max_count) logger.info(f"智能抽取了 {len(plans)} 条 {month} 的月度计划用于每日日程生成。") - return plans + return list(plans) except Exception as e: logger.error(f"智能抽取 {month} 的月度计划时发生错误: {e}") @@ -269,7 +269,7 @@ async def get_archived_plans_for_month(month: str) -> list[MonthlyPlan]: result = await session.execute( select(MonthlyPlan).where(MonthlyPlan.target_month == month, MonthlyPlan.status == "archived") ) - return result.scalars().all() + return list(result.scalars().all()) except Exception as e: logger.error(f"查询 {month} 的归档月度计划时发生错误: {e}") return []