From 7923eafef389a7047755989325fb64b4ac4ad604 Mon Sep 17 00:00:00 2001 From: John Richard Date: Thu, 2 Oct 2025 20:26:01 +0800 Subject: [PATCH] =?UTF-8?q?re-style:=20=E6=A0=BC=E5=BC=8F=E5=8C=96?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __main__.py | 2 +- bot.py | 24 ++-- plugins/bilibli/__init__.py | 1 - plugins/bilibli/bilibli_base.py | 27 ++-- plugins/bilibli/plugin.py | 21 +-- plugins/echo_example/plugin.py | 25 ++-- plugins/hello_world_plugin/plugin.py | 32 ++--- pyproject.toml | 49 ++++--- scripts/expression_stats.py | 13 +- scripts/interest_value_analysis.py | 15 +-- scripts/log_viewer_optimized.py | 27 ++-- scripts/lpmm_learning_tool.py | 39 +++--- scripts/manifest_tool.py | 8 +- scripts/mongodb_to_sqlite.py | 82 +++++------ scripts/rebuild_metadata_index.py | 5 +- scripts/run_multi_stage_smoke.py | 3 +- scripts/text_length_analysis.py | 17 ++- scripts/update_prompt_imports.py | 2 +- src/__init__.py | 8 +- src/chat/__init__.py | 2 +- src/chat/antipromptinjector/__init__.py | 27 ++-- src/chat/antipromptinjector/anti_injector.py | 30 +++-- src/chat/antipromptinjector/core/__init__.py | 3 +- src/chat/antipromptinjector/core/detector.py | 21 ++- src/chat/antipromptinjector/core/shield.py | 7 +- src/chat/antipromptinjector/counter_attack.py | 6 +- .../antipromptinjector/decision/__init__.py | 5 +- .../decision/counter_attack.py | 6 +- .../decision/decision_maker.py | 2 +- src/chat/antipromptinjector/decision_maker.py | 2 +- src/chat/antipromptinjector/detector.py | 21 ++- .../antipromptinjector/management/__init__.py | 1 - .../management/statistics.py | 7 +- .../antipromptinjector/management/user_ban.py | 7 +- .../antipromptinjector/processors/__init__.py | 1 - .../processors/message_processor.py | 6 +- src/chat/antipromptinjector/types.py | 6 +- src/chat/chatter_manager.py | 23 ++-- src/chat/emoji_system/emoji_history.py | 6 +- src/chat/emoji_system/emoji_manager.py | 91 ++++++------- src/chat/energy_system/__init__.py | 20 +-- src/chat/energy_system/energy_manager.py | 36 ++--- src/chat/express/expression_learner.py | 41 +++--- src/chat/express/expression_selector.py | 32 ++--- src/chat/frequency_analyzer/analyzer.py | 11 +- src/chat/frequency_analyzer/tracker.py | 10 +- src/chat/frequency_analyzer/trigger.py | 7 +- src/chat/interest_system/__init__.py | 5 +- .../interest_system/bot_interest_manager.py | 49 +++---- src/chat/knowledge/embedding_store.py | 66 +++++---- src/chat/knowledge/ie_process.py | 19 +-- src/chat/knowledge/kg_manager.py | 42 +++--- src/chat/knowledge/knowledge_lib.py | 11 +- src/chat/knowledge/open_ie.py | 17 +-- src/chat/knowledge/qa_manager.py | 15 ++- src/chat/knowledge/utils/dyn_topk.py | 6 +- src/chat/memory_system/__init__.py | 28 ++-- .../enhanced_memory_adapter.py | 44 +++--- .../enhanced_memory_hooks.py | 12 +- .../enhanced_memory_integration.py | 22 +-- .../deprecated_backup/enhanced_reranker.py | 27 ++-- .../deprecated_backup/integration_layer.py | 26 ++-- .../memory_integration_hooks.py | 28 ++-- .../deprecated_backup/metadata_index.py | 86 ++++++------ .../multi_stage_retrieval.py | 113 ++++++++-------- .../deprecated_backup/vector_storage.py | 77 +++++------ .../enhanced_memory_activator.py | 21 ++- .../memory_system/memory_activator_new.py | 21 ++- src/chat/memory_system/memory_builder.py | 73 +++++----- src/chat/memory_system/memory_chunk.py | 77 +++++------ .../memory_system/memory_forgetting_engine.py | 22 ++- src/chat/memory_system/memory_fusion.py | 38 +++--- src/chat/memory_system/memory_manager.py | 38 +++--- .../memory_system/memory_metadata_index.py | 98 +++++++------- .../memory_system/memory_query_planner.py | 33 +++-- src/chat/memory_system/memory_system.py | 106 +++++++-------- .../memory_system/vector_memory_storage_v2.py | 66 ++++----- src/chat/message_manager/__init__.py | 4 +- src/chat/message_manager/context_manager.py | 23 ++-- .../message_manager/distribution_manager.py | 24 ++-- src/chat/message_manager/message_manager.py | 23 ++-- .../sleep_manager/sleep_manager.py | 9 +- .../sleep_manager/sleep_state.py | 9 +- .../sleep_manager/time_checker.py | 12 +- .../sleep_manager/wakeup_manager.py | 13 +- src/chat/message_receive/__init__.py | 5 +- src/chat/message_receive/bot.py | 39 +++--- src/chat/message_receive/chat_stream.py | 54 ++++---- src/chat/message_receive/message.py | 37 +++-- src/chat/message_receive/storage.py | 18 +-- .../message_receive/uni_message_sender.py | 9 +- src/chat/planner_actions/action_manager.py | 38 +++--- src/chat/planner_actions/action_modifier.py | 37 +++-- src/chat/replyer/default_generator.py | 84 ++++++------ src/chat/replyer/replyer_manager.py | 12 +- src/chat/utils/chat_message_builder.py | 81 ++++++----- src/chat/utils/memory_mappings.py | 1 - src/chat/utils/prompt.py | 91 ++++++------- src/chat/utils/statistic.py | 40 +++--- src/chat/utils/timer_calculator.py | 14 +- src/chat/utils/typo_generator.py | 12 +- src/chat/utils/utils.py | 28 ++-- src/chat/utils/utils_image.py | 40 +++--- src/chat/utils/utils_video.py | 55 ++++---- src/chat/utils/utils_video_legacy.py | 34 ++--- src/chat/utils/utils_voice.py | 8 +- src/common/cache_manager.py | 42 +++--- src/common/config_helpers.py | 8 +- .../data_models/bot_interest_data_model.py | 28 ++-- src/common/data_models/database_data_model.py | 52 +++---- src/common/data_models/info_data_model.py | 28 ++-- src/common/data_models/llm_data_model.py | 16 +-- .../data_models/message_manager_data_model.py | 27 ++-- src/common/database/database.py | 5 +- .../database/sqlalchemy_database_api.py | 70 +++++----- src/common/database/sqlalchemy_init.py | 6 +- src/common/database/sqlalchemy_models.py | 15 ++- src/common/logger.py | 14 +- src/common/message/__init__.py | 1 - src/common/message/api.py | 10 +- src/common/message_repository.py | 16 +-- src/common/remote.py | 6 +- src/common/server.py | 14 +- src/common/tcp_connector.py | 3 +- src/common/vector_db/__init__.py | 2 +- src/common/vector_db/base.py | 34 ++--- src/common/vector_db/chromadb_impl.py | 41 +++--- src/config/api_ada_configs.py | 17 +-- src/config/config.py | 92 +++++++------ src/config/config_base.py | 10 +- src/config/official_configs.py | 45 ++++--- src/individuality/individuality.py | 19 +-- src/individuality/not_using/offline_llm.py | 20 +-- src/individuality/not_using/per_bf_gen.py | 30 ++--- src/individuality/not_using/scene.py | 5 +- src/llm_models/exceptions.py | 1 - .../model_client/aiohttp_gemini_client.py | 40 +++--- src/llm_models/model_client/base_client.py | 15 ++- src/llm_models/model_client/openai_client.py | 53 ++++---- src/llm_models/payload_content/message.py | 1 - src/llm_models/payload_content/resp_format.py | 13 +- src/llm_models/utils.py | 15 ++- src/llm_models/utils_model.py | 83 ++++++------ src/main.py | 56 ++++---- src/mais4u/mai_think.py | 13 +- .../body_emotion_action_manager.py | 12 +- src/mais4u/mais4u_chat/context_web_manager.py | 12 +- src/mais4u/mais4u_chat/gift_manager.py | 12 +- src/mais4u/mais4u_chat/internal_manager.py | 2 +- src/mais4u/mais4u_chat/s4u_chat.py | 44 +++--- src/mais4u/mais4u_chat/s4u_mood_manager.py | 11 +- src/mais4u/mais4u_chat/s4u_msg_processor.py | 14 +- src/mais4u/mais4u_chat/s4u_prompt.py | 34 ++--- .../mais4u_chat/s4u_stream_generator.py | 14 +- src/mais4u/mais4u_chat/screen_manager.py | 2 +- src/mais4u/mais4u_chat/super_chat_manager.py | 14 +- src/mais4u/mais4u_chat/yes_or_no.py | 2 +- src/mais4u/openai_client.py | 29 ++-- src/mais4u/s4u_config.py | 21 +-- src/manager/async_task_manager.py | 11 +- src/manager/local_store_manager.py | 7 +- src/mood/mood_manager.py | 11 +- src/person_info/person_info.py | 19 +-- src/person_info/relationship_builder.py | 21 +-- .../relationship_builder_manager.py | 11 +- src/person_info/relationship_fetcher.py | 17 ++- src/person_info/relationship_manager.py | 23 ++-- src/plugin_system/__init__.py | 70 +++++----- src/plugin_system/apis/__init__.py | 15 ++- src/plugin_system/apis/chat_api.py | 38 +++--- .../apis/component_manage_api.py | 23 ++-- src/plugin_system/apis/config_api.py | 1 + src/plugin_system/apis/cross_context_api.py | 20 +-- src/plugin_system/apis/database_api.py | 4 +- src/plugin_system/apis/emoji_api.py | 13 +- src/plugin_system/apis/generator_api.py | 44 +++--- src/plugin_system/apis/llm_api.py | 29 ++-- src/plugin_system/apis/message_api.py | 66 ++++----- src/plugin_system/apis/permission_api.py | 18 +-- src/plugin_system/apis/person_api.py | 7 +- src/plugin_system/apis/plugin_manage_api.py | 11 +- src/plugin_system/apis/plugin_register_api.py | 2 +- src/plugin_system/apis/schedule_api.py | 22 +-- src/plugin_system/apis/send_api.py | 37 +++-- src/plugin_system/apis/tool_api.py | 8 +- src/plugin_system/base/__init__.py | 20 +-- src/plugin_system/base/base_action.py | 39 +++--- src/plugin_system/base/base_chatter.py | 8 +- src/plugin_system/base/base_command.py | 16 +-- src/plugin_system/base/base_event.py | 18 +-- src/plugin_system/base/base_events_handler.py | 8 +- src/plugin_system/base/base_plugin.py | 20 ++- src/plugin_system/base/base_tool.py | 13 +- src/plugin_system/base/command_args.py | 5 +- src/plugin_system/base/component_types.py | 63 ++++----- src/plugin_system/base/config_types.py | 6 +- src/plugin_system/base/plugin_base.py | 51 +++---- src/plugin_system/base/plus_command.py | 31 +++-- src/plugin_system/core/__init__.py | 4 +- src/plugin_system/core/component_registry.py | 127 +++++++++--------- src/plugin_system/core/event_manager.py | 46 +++---- .../core/global_announcement_manager.py | 18 ++- src/plugin_system/core/permission_manager.py | 22 +-- src/plugin_system/core/plugin_manager.py | 46 +++---- src/plugin_system/core/tool_use.py | 43 +++--- src/plugin_system/utils/dependency_alias.py | 1 - src/plugin_system/utils/dependency_config.py | 3 +- src/plugin_system/utils/dependency_manager.py | 25 ++-- src/plugin_system/utils/manifest_utils.py | 17 +-- .../utils/permission_decorators.py | 16 +-- .../affinity_flow_chatter/affinity_chatter.py | 20 +-- .../affinity_flow_chatter/interest_scoring.py | 22 +-- .../affinity_flow_chatter/plan_executor.py | 19 ++- .../affinity_flow_chatter/plan_filter.py | 27 ++-- .../affinity_flow_chatter/plan_generator.py | 5 +- .../built_in/affinity_flow_chatter/planner.py | 34 +++-- .../built_in/affinity_flow_chatter/plugin.py | 6 +- .../relationship_tracker.py | 44 +++--- .../core_actions/anti_injector_manager.py | 6 +- src/plugins/built_in/core_actions/emoji.py | 19 ++- src/plugins/built_in/core_actions/plugin.py | 15 +-- .../built_in/knowledge/lpmm_get_knowledge.py | 10 +- .../built_in/maizone_refactored/__init__.py | 5 +- .../actions/read_feed_action.py | 8 +- .../actions/send_feed_action.py | 8 +- .../commands/send_feed_command.py | 10 +- .../built_in/maizone_refactored/plugin.py | 20 ++- .../services/content_service.py | 28 ++-- .../services/cookie_service.py | 22 +-- .../services/image_service.py | 3 +- .../maizone_refactored/services/manager.py | 7 +- .../services/monitor_service.py | 4 +- .../services/qzone_service.py | 53 ++++---- .../services/reply_tracker_service.py | 16 +-- .../services/scheduler_service.py | 10 +- .../maizone_refactored/utils/history_utils.py | 14 +- .../built_in/permission_management/plugin.py | 28 ++-- .../built_in/plugin_management/plugin.py | 31 +++-- .../built_in/proactive_thinker/plugin.py | 13 +- .../proacive_thinker_event.py | 8 +- .../proactive_thinker_executor.py | 23 ++-- .../built_in/social_toolkit_plugin/plugin.py | 59 ++++---- src/plugins/built_in/tts_plugin/plugin.py | 9 +- .../built_in/web_search_tool/engines/base.py | 4 +- .../web_search_tool/engines/bing_engine.py | 18 +-- .../web_search_tool/engines/ddg_engine.py | 6 +- .../web_search_tool/engines/exa_engine.py | 8 +- .../web_search_tool/engines/tavily_engine.py | 8 +- .../built_in/web_search_tool/plugin.py | 18 ++- .../web_search_tool/tools/url_parser.py | 15 ++- .../web_search_tool/tools/web_search.py | 26 ++-- .../web_search_tool/utils/api_key_manager.py | 14 +- .../web_search_tool/utils/formatters.py | 8 +- .../web_search_tool/utils/url_utils.py | 5 +- src/schedule/database.py | 19 +-- src/schedule/llm_generator.py | 16 ++- src/schedule/monthly_plan_manager.py | 4 +- src/schedule/plan_manager.py | 16 +-- src/schedule/schedule_manager.py | 17 +-- src/schedule/schemas.py | 6 +- src/utils/message_chunker.py | 16 ++- src/utils/timing_utils.py | 12 +- ui_log_adapter.py | 4 +- 263 files changed, 3103 insertions(+), 3123 deletions(-) diff --git a/__main__.py b/__main__.py index 15bf83a4e..f6d2a3178 100644 --- a/__main__.py +++ b/__main__.py @@ -12,7 +12,7 @@ if __name__ == "__main__": # 执行bot.py的代码 bot_file = current_dir / "bot.py" - with open(bot_file, "r", encoding="utf-8") as f: + with open(bot_file, encoding="utf-8") as f: exec(f.read()) diff --git a/bot.py b/bot.py index 798247c96..985065b99 100644 --- a/bot.py +++ b/bot.py @@ -1,30 +1,30 @@ # import asyncio import asyncio import os +import platform import sys import time -import platform import traceback from pathlib import Path -from rich.traceback import install -from colorama import init, Fore + +from colorama import Fore, init from dotenv import load_dotenv # 处理.env文件 +from rich.traceback import install # maim_message imports for console input - # 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式 -from src.common.logger import initialize_logging, get_logger, shutdown_logging +from src.common.logger import get_logger, initialize_logging, shutdown_logging # UI日志适配器 initialize_logging() from src.main import MainSystem # noqa -from src import BaseMain # noqa -from src.manager.async_task_manager import async_task_manager # noqa -from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge # noqa -from src.config.config import global_config # noqa -from src.common.database.database import initialize_sql_database # noqa -from src.common.database.sqlalchemy_models import initialize_database as init_db # noqa +from src import BaseMain +from src.manager.async_task_manager import async_task_manager +from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge +from src.config.config import global_config +from src.common.database.database import initialize_sql_database +from src.common.database.sqlalchemy_models import initialize_database as init_db logger = get_logger("main") @@ -247,7 +247,7 @@ if __name__ == "__main__": # The actual shutdown logic is now in the finally block. except Exception as e: - logger.error(f"主程序发生异常: {str(e)} {str(traceback.format_exc())}") + logger.error(f"主程序发生异常: {e!s} {traceback.format_exc()!s}") exit_code = 1 # 标记发生错误 finally: # 确保 loop 在任何情况下都尝试关闭(如果存在且未关闭) diff --git a/plugins/bilibli/__init__.py b/plugins/bilibli/__init__.py index ca649acac..7f6e5c3c2 100644 --- a/plugins/bilibli/__init__.py +++ b/plugins/bilibli/__init__.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ Bilibili 插件包 提供B站视频观看体验功能,像真实用户一样浏览和评价视频 diff --git a/plugins/bilibli/bilibli_base.py b/plugins/bilibli/bilibli_base.py index 34e794fd7..c35538dba 100644 --- a/plugins/bilibli/bilibli_base.py +++ b/plugins/bilibli/bilibli_base.py @@ -1,16 +1,17 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ Bilibili 工具基础模块 提供 B 站视频信息获取和视频分析功能 """ -import re -import aiohttp import asyncio -from typing import Optional, Dict, Any -from src.common.logger import get_logger +import re +from typing import Any + +import aiohttp + from src.chat.utils.utils_video import get_video_analyzer +from src.common.logger import get_logger logger = get_logger("bilibili_tool") @@ -25,7 +26,7 @@ class BilibiliVideoAnalyzer: "Referer": "https://www.bilibili.com/", } - def extract_bilibili_url(self, text: str) -> Optional[str]: + def extract_bilibili_url(self, text: str) -> str | None: """从文本中提取哔哩哔哩视频链接""" # 哔哩哔哩短链接模式 short_pattern = re.compile(r"https?://b23\.tv/[\w]+", re.IGNORECASE) @@ -44,7 +45,7 @@ class BilibiliVideoAnalyzer: return None - async def get_video_info(self, url: str) -> Optional[Dict[str, Any]]: + async def get_video_info(self, url: str) -> dict[str, Any] | None: """获取哔哩哔哩视频基本信息""" try: logger.info(f"🔍 解析视频URL: {url}") @@ -127,7 +128,7 @@ class BilibiliVideoAnalyzer: logger.exception("详细错误信息:") return None - async def get_video_stream_url(self, aid: int, cid: int) -> Optional[str]: + async def get_video_stream_url(self, aid: int, cid: int) -> str | None: """获取视频流URL""" try: logger.info(f"🎥 获取视频流URL: aid={aid}, cid={cid}") @@ -164,7 +165,7 @@ class BilibiliVideoAnalyzer: return stream_url # 降级到FLV格式 - if "durl" in play_data and play_data["durl"]: + if play_data.get("durl"): logger.info("📹 使用FLV格式视频流") stream_url = play_data["durl"][0].get("url") if stream_url: @@ -185,7 +186,7 @@ class BilibiliVideoAnalyzer: logger.exception("详细错误信息:") return None - async def download_video_bytes(self, stream_url: str, max_size_mb: int = 100) -> Optional[bytes]: + async def download_video_bytes(self, stream_url: str, max_size_mb: int = 100) -> bytes | None: """下载视频字节数据 Args: @@ -244,7 +245,7 @@ class BilibiliVideoAnalyzer: logger.exception("详细错误信息:") return None - async def analyze_bilibili_video(self, url: str, prompt: str = None) -> Dict[str, Any]: + async def analyze_bilibili_video(self, url: str, prompt: str = None) -> dict[str, Any]: """分析哔哩哔哩视频并返回详细信息和AI分析结果""" try: logger.info(f"🎬 开始分析哔哩哔哩视频: {url}") @@ -322,10 +323,10 @@ class BilibiliVideoAnalyzer: return result except Exception as e: - error_msg = f"分析哔哩哔哩视频时发生异常: {str(e)}" + error_msg = f"分析哔哩哔哩视频时发生异常: {e!s}" logger.error(f"❌ {error_msg}") logger.exception("详细错误信息:") # 记录完整的异常堆栈 - return {"error": f"分析失败: {str(e)}"} + return {"error": f"分析失败: {e!s}"} # 全局实例 diff --git a/plugins/bilibli/plugin.py b/plugins/bilibli/plugin.py index 72129c034..41f97bdeb 100644 --- a/plugins/bilibli/plugin.py +++ b/plugins/bilibli/plugin.py @@ -1,14 +1,15 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ Bilibili 视频观看体验工具 支持哔哩哔哩视频链接解析和AI视频内容分析 """ -from typing import Dict, Any, List, Tuple, Type -from src.plugin_system import BaseTool, ToolParamType, BasePlugin, register_plugin, ComponentInfo, ConfigField -from .bilibli_base import get_bilibili_analyzer +from typing import Any + from src.common.logger import get_logger +from src.plugin_system import BasePlugin, BaseTool, ComponentInfo, ConfigField, ToolParamType, register_plugin + +from .bilibli_base import get_bilibili_analyzer logger = get_logger("bilibili_tool") @@ -41,7 +42,7 @@ class BilibiliTool(BaseTool): super().__init__(plugin_config) self.analyzer = get_bilibili_analyzer() - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行哔哩哔哩视频观看体验""" try: url = function_args.get("url", "").strip() @@ -83,7 +84,7 @@ class BilibiliTool(BaseTool): return {"name": self.name, "content": content.strip()} except Exception as e: - error_msg = f"😅 看视频的时候出了点问题: {str(e)}" + error_msg = f"😅 看视频的时候出了点问题: {e!s}" logger.error(error_msg) return {"name": self.name, "content": error_msg} @@ -104,7 +105,7 @@ class BilibiliTool(BaseTool): return base_prompt - def _format_watch_experience(self, video_info: Dict, ai_analysis: str, interest_focus: str = None) -> str: + def _format_watch_experience(self, video_info: dict, ai_analysis: str, interest_focus: str = None) -> str: """格式化观看体验报告""" # 根据播放量生成热度评价 @@ -191,8 +192,8 @@ class BilibiliPlugin(BasePlugin): # 插件基本信息 plugin_name: str = "bilibili_video_watcher" enable_plugin: bool = True - dependencies: List[str] = [] - python_dependencies: List[str] = [] + dependencies: list[str] = [] + python_dependencies: list[str] = [] config_file_name: str = "config.toml" # 配置节描述 @@ -220,6 +221,6 @@ class BilibiliPlugin(BasePlugin): }, } - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: """返回插件包含的工具组件""" return [(BilibiliTool.get_tool_info(), BilibiliTool)] diff --git a/plugins/echo_example/plugin.py b/plugins/echo_example/plugin.py index 6f99cc901..e03429805 100644 --- a/plugins/echo_example/plugin.py +++ b/plugins/echo_example/plugin.py @@ -4,14 +4,15 @@ Echo 示例插件 展示增强命令系统的使用方法 """ -from typing import List, Tuple, Type, Optional, Union +from typing import Union + from src.plugin_system import ( BasePlugin, - PlusCommand, - CommandArgs, - PlusCommandInfo, - ConfigField, ChatType, + CommandArgs, + ConfigField, + PlusCommand, + PlusCommandInfo, register_plugin, ) from src.plugin_system.base.component_types import PythonDependency @@ -27,7 +28,7 @@ class EchoCommand(PlusCommand): chat_type_allow = ChatType.ALL intercept_message = True - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: """执行echo命令""" if args.is_empty(): await self.send_text("❓ 请提供要回显的内容\n用法: /echo <内容>") @@ -56,7 +57,7 @@ class HelloCommand(PlusCommand): chat_type_allow = ChatType.ALL intercept_message = True - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: """执行hello命令""" if args.is_empty(): await self.send_text("👋 Hello! 很高兴见到你!") @@ -77,7 +78,7 @@ class InfoCommand(PlusCommand): chat_type_allow = ChatType.ALL intercept_message = True - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: """执行info命令""" info_text = ( "📋 Echo 示例插件信息\n" @@ -105,7 +106,7 @@ class TestCommand(PlusCommand): chat_type_allow = ChatType.ALL intercept_message = True - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: """执行test命令""" if args.is_empty(): help_text = ( @@ -166,8 +167,8 @@ class EchoExamplePlugin(BasePlugin): plugin_name: str = "echo_example_plugin" enable_plugin: bool = True - dependencies: List[str] = [] - python_dependencies: List[Union[str, "PythonDependency"]] = [] + dependencies: list[str] = [] + python_dependencies: list[Union[str, "PythonDependency"]] = [] config_file_name: str = "config.toml" config_schema = { @@ -187,7 +188,7 @@ class EchoExamplePlugin(BasePlugin): "commands": "命令相关配置", } - def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type]]: + def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type]]: """获取插件组件""" components = [] diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index ca7a6a13a..2c71293a1 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -1,20 +1,20 @@ -from typing import List, Tuple, Type, Dict, Any, Optional import logging import random +from typing import Any from src.plugin_system import ( - BasePlugin, - register_plugin, - ComponentInfo, - BaseEventHandler, - EventType, - BaseTool, - PlusCommand, - CommandArgs, - ChatType, - BaseAction, ActionActivationType, + BaseAction, + BaseEventHandler, + BasePlugin, + BaseTool, + ChatType, + CommandArgs, + ComponentInfo, ConfigField, + EventType, + PlusCommand, + register_plugin, ) from src.plugin_system.base.base_event import HandlerResult @@ -39,7 +39,7 @@ class GetSystemInfoTool(BaseTool): available_for_llm = True parameters = [] - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: return {"name": self.name, "content": "系统版本: 1.0.1, 状态: 运行正常"} @@ -51,7 +51,7 @@ class HelloCommand(PlusCommand): command_aliases = ["hi", "你好"] chat_type_allow = ChatType.ALL - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: greeting = str(self.get_config("greeting.message", "Hello, World! 我是一个由 MoFox_Bot 驱动的插件。")) await self.send_text(greeting) return True, "成功发送问候", True @@ -67,7 +67,7 @@ class RandomEmojiAction(BaseAction): action_require = ["当对话气氛轻松时", "可以用来回应简单的情感表达"] associated_types = ["text"] - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: emojis = ["😊", "😂", "👍", "🎉", "🤔", "🤖"] await self.send_text(random.choice(emojis)) return True, "成功发送了一个随机表情" @@ -99,9 +99,9 @@ class HelloWorldPlugin(BasePlugin): }, } - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: """根据配置文件动态注册插件的功能组件。""" - components: List[Tuple[ComponentInfo, Type]] = [] + components: list[tuple[ComponentInfo, type]] = [] components.append((StartupMessageHandler.get_handler_info(), StartupMessageHandler)) components.append((GetSystemInfoTool.get_tool_info(), GetSystemInfoTool)) diff --git a/pyproject.toml b/pyproject.toml index a67f28472..04bb07299 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ dependencies = [ "tqdm>=4.67.1", "urllib3>=2.5.0", "uvicorn>=0.35.0", + "watchdog>=6.0.0", "websockets>=15.0.1", "aiomysql>=0.2.0", "aiosqlite>=0.21.0", @@ -80,29 +81,41 @@ dependencies = [ url = "https://pypi.tuna.tsinghua.edu.cn/simple" default = true +[tool.uv.sources] +amrita = { workspace = true } + [tool.ruff] - -include = ["*.py"] - -# 行长度设置 line-length = 120 +target-version = "py310" [tool.ruff.lint] -fixable = ["ALL"] -unfixable = [] +select = [ + "F", # Pyflakes + "W", # pycodestyle warnings + "E", # pycodestyle errors + "UP", # pyupgrade + "ASYNC", # flake8-async + "C4", # flake8-comprehensions + "T10", # flake8-debugger + "PYI", # flake8-pyi + "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RUF", # Ruff-specific rules + "I", # isort + "PERF", # pylint-performance +] +ignore = [ + "E402", # module-import-not-at-top-of-file + "E501", # line-too-long + "UP037", # quoted-annotation + "RUF001", # ambiguous-unicode-character-string + "RUF002", # ambiguous-unicode-character-docstring + "RUF003", # ambiguous-unicode-character-comment +] + # 如果一个变量的名称以下划线开头,即使它未被使用,也不应该被视为错误或警告。 dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" - -# 启用的规则 -select = [ - "E", # pycodestyle 错误 - "F", # pyflakes - "B", # flake8-bugbear -] - -ignore = ["E711","E501"] - [tool.ruff.format] docstring-code-format = true indent-style = "space" @@ -124,6 +137,4 @@ skip-magic-trailing-comma = false line-ending = "auto" [dependency-groups] -lint = [ - "loguru>=0.7.3", -] +lint = ["loguru>=0.7.3"] diff --git a/scripts/expression_stats.py b/scripts/expression_stats.py index 133f3d73b..b79819493 100644 --- a/scripts/expression_stats.py +++ b/scripts/expression_stats.py @@ -1,10 +1,9 @@ -import time -import sys import os -from typing import Dict, List +import sys +import time # Add project root to Python path -from src.common.database.database_model import Expression, ChatStreams +from src.common.database.database_model import ChatStreams, Expression project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, project_root) @@ -30,7 +29,7 @@ def get_chat_name(chat_id: str) -> str: return f"查询失败 ({chat_id})" -def calculate_time_distribution(expressions) -> Dict[str, int]: +def calculate_time_distribution(expressions) -> dict[str, int]: """Calculate distribution of last active time in days""" now = time.time() distribution = { @@ -64,7 +63,7 @@ def calculate_time_distribution(expressions) -> Dict[str, int]: return distribution -def calculate_count_distribution(expressions) -> Dict[str, int]: +def calculate_count_distribution(expressions) -> dict[str, int]: """Calculate distribution of count values""" distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0} for expr in expressions: @@ -86,7 +85,7 @@ def calculate_count_distribution(expressions) -> Dict[str, int]: return distribution -def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]: +def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> list[Expression]: """Get top N most used expressions for a specific chat_id""" return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n) diff --git a/scripts/interest_value_analysis.py b/scripts/interest_value_analysis.py index bce37b4a2..e464c905c 100644 --- a/scripts/interest_value_analysis.py +++ b/scripts/interest_value_analysis.py @@ -1,7 +1,6 @@ -import time -import sys import os -from typing import Dict, List, Tuple, Optional +import sys +import time from datetime import datetime # Add project root to Python path @@ -35,7 +34,7 @@ def format_timestamp(timestamp: float) -> str: return "未知时间" -def calculate_interest_value_distribution(messages) -> Dict[str, int]: +def calculate_interest_value_distribution(messages) -> dict[str, int]: """Calculate distribution of interest_value""" distribution = { "0.000-0.010": 0, @@ -76,7 +75,7 @@ def calculate_interest_value_distribution(messages) -> Dict[str, int]: return distribution -def get_interest_value_stats(messages) -> Dict[str, float]: +def get_interest_value_stats(messages) -> dict[str, float]: """Calculate basic statistics for interest_value""" values = [ float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0 @@ -97,7 +96,7 @@ def get_interest_value_stats(messages) -> Dict[str, float]: } -def get_available_chats() -> List[Tuple[str, str, int]]: +def get_available_chats() -> list[tuple[str, str, int]]: """Get all available chats with message counts""" try: # 获取所有有消息的chat_id @@ -130,7 +129,7 @@ def get_available_chats() -> List[Tuple[str, str, int]]: return [] -def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: +def get_time_range_input() -> tuple[float | None, float | None]: """Get time range input from user""" print("\n时间范围选择:") print("1. 最近1天") @@ -170,7 +169,7 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: def analyze_interest_values( - chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None + chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None ) -> None: """Analyze interest values with optional filters""" diff --git a/scripts/log_viewer_optimized.py b/scripts/log_viewer_optimized.py index 2103e5486..f38dafa64 100644 --- a/scripts/log_viewer_optimized.py +++ b/scripts/log_viewer_optimized.py @@ -1,13 +1,14 @@ -import tkinter as tk -from tkinter import ttk, messagebox, filedialog, colorchooser -import orjson -from pathlib import Path -import threading -import toml -from datetime import datetime -from collections import defaultdict import os +import threading import time +import tkinter as tk +from collections import defaultdict +from datetime import datetime +from pathlib import Path +from tkinter import colorchooser, filedialog, messagebox, ttk + +import orjson +import toml class LogIndex: @@ -409,7 +410,7 @@ class AsyncLogLoader: file_size = os.path.getsize(file_path) processed_size = 0 - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: line_count = 0 batch_size = 1000 # 批量处理 @@ -561,7 +562,7 @@ class LogViewer: try: if config_path.exists(): - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: bot_config = toml.load(f) if "log" in bot_config: self.log_config.update(bot_config["log"]) @@ -575,7 +576,7 @@ class LogViewer: try: if viewer_config_path.exists(): - with open(viewer_config_path, "r", encoding="utf-8") as f: + with open(viewer_config_path, encoding="utf-8") as f: viewer_config = toml.load(f) if "viewer" in viewer_config: self.viewer_config.update(viewer_config["viewer"]) @@ -843,7 +844,7 @@ class LogViewer: mapping_file = Path("config/module_mapping.json") if mapping_file.exists(): try: - with open(mapping_file, "r", encoding="utf-8") as f: + with open(mapping_file, encoding="utf-8") as f: custom_mapping = orjson.loads(f.read()) self.module_name_mapping.update(custom_mapping) except Exception as e: @@ -1172,7 +1173,7 @@ class LogViewer: """读取新的日志条目并返回它们""" new_entries = [] new_modules = set() # 收集新发现的模块 - with open(self.current_log_file, "r", encoding="utf-8") as f: + with open(self.current_log_file, encoding="utf-8") as f: f.seek(from_position) line_count = self.log_index.total_entries for line in f: diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index 9caafc7fd..58aa91c64 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -1,36 +1,37 @@ import asyncio +import datetime import os import shutil import sys -import orjson -import datetime -from pathlib import Path from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path from threading import Lock -from typing import Optional + +import orjson from json_repair import repair_json # 将项目根目录添加到 sys.path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from src.common.logger import get_logger -from src.chat.knowledge.utils.hash import get_sha256 -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config -from src.chat.knowledge.open_ie import OpenIE -from src.chat.knowledge.embedding_store import EmbeddingManager -from src.chat.knowledge.kg_manager import KGManager from rich.progress import ( - Progress, BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, TimeElapsedColumn, TimeRemainingColumn, - TaskProgressColumn, - MofNCompleteColumn, - SpinnerColumn, - TextColumn, ) +from src.chat.knowledge.embedding_store import EmbeddingManager +from src.chat.knowledge.kg_manager import KGManager +from src.chat.knowledge.open_ie import OpenIE +from src.chat.knowledge.utils.hash import get_sha256 +from src.common.logger import get_logger +from src.config.config import model_config +from src.llm_models.utils_model import LLMRequest + logger = get_logger("LPMM_LearningTool") ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data") @@ -59,7 +60,7 @@ def clear_cache(): def process_text_file(file_path): - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: raw = f.read() return [p.strip() for p in raw.split("\n\n") if p.strip()] @@ -86,7 +87,7 @@ def preprocess_raw_data(): # --- 模块二:信息提取 --- -def _parse_and_repair_json(json_string: str) -> Optional[dict]: +def _parse_and_repair_json(json_string: str) -> dict | None: """ 尝试解析JSON字符串,如果失败则尝试修复并重新解析。 @@ -249,7 +250,7 @@ def extract_information(paragraphs_dict, model_set): # --- 模块三:数据导入 --- -async def import_data(openie_obj: Optional[OpenIE] = None): +async def import_data(openie_obj: OpenIE | None = None): """ 将OpenIE数据导入知识库(Embedding Store 和 KG) diff --git a/scripts/manifest_tool.py b/scripts/manifest_tool.py index 6f9a3a6d0..c18b6a208 100644 --- a/scripts/manifest_tool.py +++ b/scripts/manifest_tool.py @@ -4,11 +4,13 @@ 提供插件manifest文件的创建、验证和管理功能 """ +import argparse import os import sys -import argparse -import orjson from pathlib import Path + +import orjson + from src.common.logger import get_logger from src.plugin_system.utils.manifest_utils import ( ManifestValidator, @@ -124,7 +126,7 @@ def validate_manifest_file(plugin_dir: str) -> bool: return False try: - with open(manifest_path, "r", encoding="utf-8") as f: + with open(manifest_path, encoding="utf-8") as f: manifest_data = orjson.loads(f.read()) validator = ManifestValidator() diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py index 789c5860a..36b7aa9ab 100644 --- a/scripts/mongodb_to_sqlite.py +++ b/scripts/mongodb_to_sqlite.py @@ -1,46 +1,48 @@ import os -import orjson -import sys # 新增系统模块导入 # import time import pickle +import sys # 新增系统模块导入 from pathlib import Path +import orjson + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from typing import Dict, Any, List, Optional, Type from dataclasses import dataclass, field from datetime import datetime +from typing import Any + +from peewee import Field, IntegrityError, Model from pymongo import MongoClient from pymongo.errors import ConnectionFailure -from peewee import Model, Field, IntegrityError # Rich 进度条和显示组件 from rich.console import Console +from rich.panel import Panel from rich.progress import ( - Progress, - TextColumn, BarColumn, - TaskProgressColumn, - TimeRemainingColumn, - TimeElapsedColumn, + Progress, SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, ) from rich.table import Table -from rich.panel import Panel -# from rich.text import Text +# from rich.text import Text from src.common.database.database import db from src.common.database.sqlalchemy_models import ( ChatStreams, Emoji, - Messages, - Images, - ImageDescriptions, - PersonInfo, - Knowledges, - ThinkingLog, - GraphNodes, GraphEdges, + GraphNodes, + ImageDescriptions, + Images, + Knowledges, + Messages, + PersonInfo, + ThinkingLog, ) from src.common.logger import get_logger @@ -54,12 +56,12 @@ class MigrationConfig: """迁移配置类""" mongo_collection: str - target_model: Type[Model] - field_mapping: Dict[str, str] + target_model: type[Model] + field_mapping: dict[str, str] batch_size: int = 500 enable_validation: bool = True skip_duplicates: bool = True - unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段 + unique_fields: list[str] = field(default_factory=list) # 用于重复检查的字段 # 数据验证相关类已移除 - 用户要求不要数据验证 @@ -73,7 +75,7 @@ class MigrationCheckpoint: processed_count: int last_processed_id: Any timestamp: datetime - batch_errors: List[Dict[str, Any]] = field(default_factory=list) + batch_errors: list[dict[str, Any]] = field(default_factory=list) @dataclass @@ -88,11 +90,11 @@ class MigrationStats: duplicate_count: int = 0 validation_errors: int = 0 batch_insert_count: int = 0 - errors: List[Dict[str, Any]] = field(default_factory=list) - start_time: Optional[datetime] = None - end_time: Optional[datetime] = None + errors: list[dict[str, Any]] = field(default_factory=list) + start_time: datetime | None = None + end_time: datetime | None = None - def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None): + def add_error(self, doc_id: Any, error: str, doc_data: dict | None = None): """添加错误记录""" self.errors.append( {"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data} @@ -108,10 +110,10 @@ class MigrationStats: class MongoToSQLiteMigrator: """MongoDB到SQLite数据迁移器 - 使用Peewee ORM""" - def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None): + def __init__(self, mongo_uri: str | None = None, database_name: str | None = None): self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot") self.mongo_uri = mongo_uri or self._build_mongo_uri() - self.mongo_client: Optional[MongoClient] = None + self.mongo_client: MongoClient | None = None self.mongo_db = None # 迁移配置 @@ -142,7 +144,7 @@ class MongoToSQLiteMigrator: else: return f"mongodb://{host}:{port}/{self.database_name}" - def _initialize_migration_configs(self) -> List[MigrationConfig]: + def _initialize_migration_configs(self) -> list[MigrationConfig]: """初始化迁移配置""" return [ # 表情包迁移配置 MigrationConfig( @@ -306,7 +308,7 @@ class MongoToSQLiteMigrator: ), ] - def _initialize_validation_rules(self) -> Dict[str, Any]: + def _initialize_validation_rules(self) -> dict[str, Any]: """数据验证已禁用 - 返回空字典""" return {} @@ -337,7 +339,7 @@ class MongoToSQLiteMigrator: self.mongo_client.close() logger.info("MongoDB连接已关闭") - def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any: + def _get_nested_value(self, document: dict[str, Any], field_path: str) -> Any: """获取嵌套字段的值""" if "." not in field_path: return document.get(field_path) @@ -434,7 +436,7 @@ class MongoToSQLiteMigrator: return None - def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool: + def _validate_data(self, collection_name: str, data: dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool: """数据验证已禁用 - 始终返回True""" return True @@ -454,7 +456,7 @@ class MongoToSQLiteMigrator: except Exception as e: logger.warning(f"保存断点失败: {e}") - def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]: + def _load_checkpoint(self, collection_name: str) -> MigrationCheckpoint | None: """加载迁移断点""" checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl" if not checkpoint_file.exists(): @@ -467,7 +469,7 @@ class MongoToSQLiteMigrator: logger.warning(f"加载断点失败: {e}") return None - def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int: + def _batch_insert(self, model: type[Model], data_list: list[dict[str, Any]]) -> int: """批量插入数据""" if not data_list: return 0 @@ -494,7 +496,7 @@ class MongoToSQLiteMigrator: return success_count def _check_duplicate_by_unique_fields( - self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str] + self, model: type[Model], data: dict[str, Any], unique_fields: list[str] ) -> bool: """根据唯一字段检查重复""" if not unique_fields: @@ -512,7 +514,7 @@ class MongoToSQLiteMigrator: logger.debug(f"重复检查失败: {e}") return False - def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]: + def _create_model_instance(self, model: type[Model], data: dict[str, Any]) -> Model | None: """使用ORM创建模型实例""" try: # 过滤掉不存在的字段 @@ -669,7 +671,7 @@ class MongoToSQLiteMigrator: return stats - def migrate_all(self) -> Dict[str, MigrationStats]: + def migrate_all(self) -> dict[str, MigrationStats]: """执行所有迁移任务""" logger.info("开始执行数据库迁移...") @@ -730,7 +732,7 @@ class MongoToSQLiteMigrator: self._print_migration_summary(all_stats) return all_stats - def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]): + def _print_migration_summary(self, all_stats: dict[str, MigrationStats]): """使用Rich打印美观的迁移汇总信息""" # 计算总体统计 total_processed = sum(stats.processed_count for stats in all_stats.values()) @@ -857,7 +859,7 @@ class MongoToSQLiteMigrator: """添加新的迁移配置""" self.migration_configs.append(config) - def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]: + def migrate_single_collection(self, collection_name: str) -> MigrationStats | None: """迁移单个指定的集合""" config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None) if not config: @@ -875,7 +877,7 @@ class MongoToSQLiteMigrator: finally: self.disconnect_mongodb() - def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str): + def export_error_report(self, all_stats: dict[str, MigrationStats], filepath: str): """导出错误报告""" error_report = { "timestamp": datetime.now().isoformat(), diff --git a/scripts/rebuild_metadata_index.py b/scripts/rebuild_metadata_index.py index d1990fecc..b4d786019 100644 --- a/scripts/rebuild_metadata_index.py +++ b/scripts/rebuild_metadata_index.py @@ -1,17 +1,16 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ 从现有ChromaDB数据重建JSON元数据索引 """ import asyncio -import sys import os +import sys sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from src.chat.memory_system.memory_system import MemorySystem from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndexEntry +from src.chat.memory_system.memory_system import MemorySystem from src.common.logger import get_logger logger = get_logger(__name__) diff --git a/scripts/run_multi_stage_smoke.py b/scripts/run_multi_stage_smoke.py index 000336244..634f97210 100644 --- a/scripts/run_multi_stage_smoke.py +++ b/scripts/run_multi_stage_smoke.py @@ -1,12 +1,11 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ 轻量烟雾测试:初始化 MemorySystem 并运行一次检索,验证 MemoryMetadata.source 访问不再报错 """ import asyncio -import sys import os +import sys sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) diff --git a/scripts/text_length_analysis.py b/scripts/text_length_analysis.py index 5a329b93c..818b5f6e1 100644 --- a/scripts/text_length_analysis.py +++ b/scripts/text_length_analysis.py @@ -1,8 +1,7 @@ -import time -import sys import os import re -from typing import Dict, List, Tuple, Optional +import sys +import time from datetime import datetime # Add project root to Python path @@ -63,7 +62,7 @@ def format_timestamp(timestamp: float) -> str: return "未知时间" -def calculate_text_length_distribution(messages) -> Dict[str, int]: +def calculate_text_length_distribution(messages) -> dict[str, int]: """Calculate distribution of processed_plain_text length""" distribution = { "0": 0, # 空文本 @@ -126,7 +125,7 @@ def calculate_text_length_distribution(messages) -> Dict[str, int]: return distribution -def get_text_length_stats(messages) -> Dict[str, float]: +def get_text_length_stats(messages) -> dict[str, float]: """Calculate basic statistics for processed_plain_text length""" lengths = [] null_count = 0 @@ -168,7 +167,7 @@ def get_text_length_stats(messages) -> Dict[str, float]: } -def get_available_chats() -> List[Tuple[str, str, int]]: +def get_available_chats() -> list[tuple[str, str, int]]: """Get all available chats with message counts""" try: # 获取所有有消息的chat_id,排除特殊类型消息 @@ -202,7 +201,7 @@ def get_available_chats() -> List[Tuple[str, str, int]]: return [] -def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: +def get_time_range_input() -> tuple[float | None, float | None]: """Get time range input from user""" print("\n时间范围选择:") print("1. 最近1天") @@ -241,7 +240,7 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: return None, None -def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]: +def get_top_longest_messages(messages, top_n: int = 10) -> list[tuple[str, int, str, str]]: """Get top N longest messages""" message_lengths = [] @@ -266,7 +265,7 @@ def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, def analyze_text_lengths( - chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None + chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None ) -> None: """Analyze processed_plain_text lengths with optional filters""" diff --git a/scripts/update_prompt_imports.py b/scripts/update_prompt_imports.py index 227491ec2..3917c9408 100644 --- a/scripts/update_prompt_imports.py +++ b/scripts/update_prompt_imports.py @@ -30,7 +30,7 @@ def update_prompt_imports(file_path): print(f"文件不存在: {file_path}") return False - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: content = f.read() # 替换导入语句 diff --git a/src/__init__.py b/src/__init__.py index bdb90be85..d23d01ddb 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,13 +1,15 @@ import random -from typing import List, Optional, Sequence -from colorama import init, Fore +from collections.abc import Sequence +from typing import List, Optional + +from colorama import Fore, init from src.common.logger import get_logger egg = get_logger("小彩蛋") -def weighted_choice(data: Sequence[str], weights: Optional[List[float]] = None) -> str: +def weighted_choice(data: Sequence[str], weights: list[float] | None = None) -> str: """ 从 data 中按权重随机返回一条。 若 weights 为 None,则所有元素权重默认为 1。 diff --git a/src/chat/__init__.py b/src/chat/__init__.py index a569c0226..2f7da45ce 100644 --- a/src/chat/__init__.py +++ b/src/chat/__init__.py @@ -3,8 +3,8 @@ MaiBot模块系统 包含聊天、情绪、记忆、日程等功能模块 """ -from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager +from src.chat.message_receive.chat_stream import get_chat_manager # 导出主要组件供外部使用 __all__ = [ diff --git a/src/chat/antipromptinjector/__init__.py b/src/chat/antipromptinjector/__init__.py index e5a672c86..fb45f006a 100644 --- a/src/chat/antipromptinjector/__init__.py +++ b/src/chat/antipromptinjector/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ MaiBot 反注入系统模块 @@ -14,25 +13,25 @@ MaiBot 反注入系统模块 """ from .anti_injector import AntiPromptInjector, get_anti_injector, initialize_anti_injector -from .types import DetectionResult, ProcessResult -from .core import PromptInjectionDetector, MessageShield -from .processors.message_processor import MessageProcessor -from .management import AntiInjectionStatistics, UserBanManager +from .core import MessageShield, PromptInjectionDetector from .decision import CounterAttackGenerator, ProcessingDecisionMaker +from .management import AntiInjectionStatistics, UserBanManager +from .processors.message_processor import MessageProcessor +from .types import DetectionResult, ProcessResult __all__ = [ + "AntiInjectionStatistics", "AntiPromptInjector", + "CounterAttackGenerator", + "DetectionResult", + "MessageProcessor", + "MessageShield", + "ProcessResult", + "ProcessingDecisionMaker", + "PromptInjectionDetector", + "UserBanManager", "get_anti_injector", "initialize_anti_injector", - "DetectionResult", - "ProcessResult", - "PromptInjectionDetector", - "MessageShield", - "MessageProcessor", - "AntiInjectionStatistics", - "UserBanManager", - "CounterAttackGenerator", - "ProcessingDecisionMaker", ] diff --git a/src/chat/antipromptinjector/anti_injector.py b/src/chat/antipromptinjector/anti_injector.py index b2c2e3232..23ff3a7ee 100644 --- a/src/chat/antipromptinjector/anti_injector.py +++ b/src/chat/antipromptinjector/anti_injector.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ LLM反注入系统主模块 @@ -12,15 +11,16 @@ LLM反注入系统主模块 """ import time -from typing import Optional, Tuple, Dict, Any +from typing import Any from src.common.logger import get_logger from src.config.config import global_config -from .types import ProcessResult -from .core import PromptInjectionDetector, MessageShield -from .processors.message_processor import MessageProcessor -from .management import AntiInjectionStatistics, UserBanManager + +from .core import MessageShield, PromptInjectionDetector from .decision import CounterAttackGenerator, ProcessingDecisionMaker +from .management import AntiInjectionStatistics, UserBanManager +from .processors.message_processor import MessageProcessor +from .types import ProcessResult logger = get_logger("anti_injector") @@ -43,7 +43,7 @@ class AntiPromptInjector: async def process_message( self, message_data: dict, chat_stream=None - ) -> Tuple[ProcessResult, Optional[str], Optional[str]]: + ) -> tuple[ProcessResult, str | None, str | None]: """处理字典格式的消息并返回结果 Args: @@ -102,7 +102,7 @@ class AntiPromptInjector: await self.statistics.update_stats(error_count=1) # 异常情况下直接阻止消息 - return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {str(e)}" + return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {e!s}" finally: # 更新处理时间统计 @@ -111,7 +111,7 @@ class AntiPromptInjector: async def _process_message_internal( self, text_to_detect: str, user_id: str, platform: str, processed_plain_text: str, start_time: float - ) -> Tuple[ProcessResult, Optional[str], Optional[str]]: + ) -> tuple[ProcessResult, str | None, str | None]: """内部消息处理逻辑(共用的检测核心)""" # 如果是纯引用消息,直接允许通过 @@ -218,7 +218,7 @@ class AntiPromptInjector: return ProcessResult.ALLOWED, None, "消息检查通过" async def handle_message_storage( - self, result: ProcessResult, modified_content: Optional[str], reason: str, message_data: dict + self, result: ProcessResult, modified_content: str | None, reason: str, message_data: dict ) -> None: """处理违禁消息的数据库存储,根据处理模式决定如何处理""" if result == ProcessResult.BLOCKED_INJECTION or result == ProcessResult.COUNTER_ATTACK: @@ -253,9 +253,10 @@ class AntiPromptInjector: async def _delete_message_from_storage(message_data: dict) -> None: """从数据库中删除违禁消息记录""" try: - from src.common.database.sqlalchemy_models import Messages, get_db_session from sqlalchemy import delete + from src.common.database.sqlalchemy_models import Messages, get_db_session + message_id = message_data.get("message_id") if not message_id: logger.warning("无法删除消息:缺少message_id") @@ -279,9 +280,10 @@ class AntiPromptInjector: async def _update_message_in_storage(message_data: dict, new_content: str) -> None: """更新数据库中的消息内容为加盾版本""" try: - from src.common.database.sqlalchemy_models import Messages, get_db_session from sqlalchemy import update + from src.common.database.sqlalchemy_models import Messages, get_db_session + message_id = message_data.get("message_id") if not message_id: logger.warning("无法更新消息:缺少message_id") @@ -305,7 +307,7 @@ class AntiPromptInjector: except Exception as e: logger.error(f"更新消息内容失败: {e}") - async def get_stats(self) -> Dict[str, Any]: + async def get_stats(self) -> dict[str, Any]: """获取统计信息""" return await self.statistics.get_stats() @@ -315,7 +317,7 @@ class AntiPromptInjector: # 全局反注入器实例 -_global_injector: Optional[AntiPromptInjector] = None +_global_injector: AntiPromptInjector | None = None def get_anti_injector() -> AntiPromptInjector: diff --git a/src/chat/antipromptinjector/core/__init__.py b/src/chat/antipromptinjector/core/__init__.py index f4087c4f3..5f751d823 100644 --- a/src/chat/antipromptinjector/core/__init__.py +++ b/src/chat/antipromptinjector/core/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 反注入系统核心检测模块 @@ -10,4 +9,4 @@ from .detector import PromptInjectionDetector from .shield import MessageShield -__all__ = ["PromptInjectionDetector", "MessageShield"] +__all__ = ["MessageShield", "PromptInjectionDetector"] diff --git a/src/chat/antipromptinjector/core/detector.py b/src/chat/antipromptinjector/core/detector.py index 39e65db8b..202c9bb5b 100644 --- a/src/chat/antipromptinjector/core/detector.py +++ b/src/chat/antipromptinjector/core/detector.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 提示词注入检测器模块 @@ -8,19 +7,19 @@ 3. 缓存机制优化性能 """ +import hashlib import re import time -import hashlib -from typing import Dict, List from dataclasses import asdict from src.common.logger import get_logger from src.config.config import global_config -from ..types import DetectionResult # 导入LLM API from src.plugin_system.apis import llm_api +from ..types import DetectionResult + logger = get_logger("anti_injector.detector") @@ -30,8 +29,8 @@ class PromptInjectionDetector: def __init__(self): """初始化检测器""" self.config = global_config.anti_prompt_injection - self._cache: Dict[str, DetectionResult] = {} - self._compiled_patterns: List[re.Pattern] = [] + self._cache: dict[str, DetectionResult] = {} + self._compiled_patterns: list[re.Pattern] = [] self._compile_patterns() def _compile_patterns(self): @@ -224,7 +223,7 @@ class PromptInjectionDetector: matched_patterns=[], processing_time=processing_time, detection_method="llm", - reason=f"LLM检测出错: {str(e)}", + reason=f"LLM检测出错: {e!s}", ) @staticmethod @@ -250,7 +249,7 @@ class PromptInjectionDetector: 请客观分析,避免误判正常对话。""" @staticmethod - def _parse_llm_response(response: str) -> Dict: + def _parse_llm_response(response: str) -> dict: """解析LLM响应""" try: lines = response.strip().split("\n") @@ -280,7 +279,7 @@ class PromptInjectionDetector: except Exception as e: logger.error(f"解析LLM响应失败: {e}") - return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"} + return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"} async def detect(self, message: str) -> DetectionResult: """执行检测""" @@ -331,7 +330,7 @@ class PromptInjectionDetector: return final_result - def _merge_results(self, results: List[DetectionResult]) -> DetectionResult: + def _merge_results(self, results: list[DetectionResult]) -> DetectionResult: """合并多个检测结果""" if not results: return DetectionResult(reason="无检测结果") @@ -384,7 +383,7 @@ class PromptInjectionDetector: if expired_keys: logger.debug(f"清理了{len(expired_keys)}个过期缓存项") - def get_cache_stats(self) -> Dict: + def get_cache_stats(self) -> dict: """获取缓存统计信息""" return { "cache_size": len(self._cache), diff --git a/src/chat/antipromptinjector/core/shield.py b/src/chat/antipromptinjector/core/shield.py index c7a2e78bc..399ec9025 100644 --- a/src/chat/antipromptinjector/core/shield.py +++ b/src/chat/antipromptinjector/core/shield.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 消息加盾模块 @@ -6,8 +5,6 @@ 主要通过注入系统提示词来指导AI安全响应。 """ -from typing import List - from src.common.logger import get_logger from src.config.config import global_config @@ -35,7 +32,7 @@ class MessageShield: return SAFETY_SYSTEM_PROMPT @staticmethod - def is_shield_needed(confidence: float, matched_patterns: List[str]) -> bool: + def is_shield_needed(confidence: float, matched_patterns: list[str]) -> bool: """判断是否需要加盾 Args: @@ -60,7 +57,7 @@ class MessageShield: return False @staticmethod - def create_safety_summary(confidence: float, matched_patterns: List[str]) -> str: + def create_safety_summary(confidence: float, matched_patterns: list[str]) -> str: """创建安全处理摘要 Args: diff --git a/src/chat/antipromptinjector/counter_attack.py b/src/chat/antipromptinjector/counter_attack.py index 7c2bd86c5..2a094e419 100644 --- a/src/chat/antipromptinjector/counter_attack.py +++ b/src/chat/antipromptinjector/counter_attack.py @@ -1,15 +1,13 @@ -# -*- coding: utf-8 -*- """ 反击消息生成模块 负责生成个性化的反击消息回应提示词注入攻击 """ -from typing import Optional - from src.common.logger import get_logger from src.config.config import global_config from src.plugin_system.apis import llm_api + from .types import DetectionResult logger = get_logger("anti_injector.counter_attack") @@ -55,7 +53,7 @@ class CounterAttackGenerator: async def generate_counter_attack_message( self, original_message: str, detection_result: DetectionResult - ) -> Optional[str]: + ) -> str | None: """生成反击消息 Args: diff --git a/src/chat/antipromptinjector/decision/__init__.py b/src/chat/antipromptinjector/decision/__init__.py index 5778ca4ed..358147066 100644 --- a/src/chat/antipromptinjector/decision/__init__.py +++ b/src/chat/antipromptinjector/decision/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 反注入系统决策模块 @@ -7,7 +6,7 @@ - counter_attack: 反击消息生成器 """ -from .decision_maker import ProcessingDecisionMaker from .counter_attack import CounterAttackGenerator +from .decision_maker import ProcessingDecisionMaker -__all__ = ["ProcessingDecisionMaker", "CounterAttackGenerator"] +__all__ = ["CounterAttackGenerator", "ProcessingDecisionMaker"] diff --git a/src/chat/antipromptinjector/decision/counter_attack.py b/src/chat/antipromptinjector/decision/counter_attack.py index 9d6aac2ff..ad305b9c4 100644 --- a/src/chat/antipromptinjector/decision/counter_attack.py +++ b/src/chat/antipromptinjector/decision/counter_attack.py @@ -1,15 +1,13 @@ -# -*- coding: utf-8 -*- """ 反击消息生成模块 负责生成个性化的反击消息回应提示词注入攻击 """ -from typing import Optional - from src.common.logger import get_logger from src.config.config import global_config from src.plugin_system.apis import llm_api + from ..types import DetectionResult logger = get_logger("anti_injector.counter_attack") @@ -55,7 +53,7 @@ class CounterAttackGenerator: async def generate_counter_attack_message( self, original_message: str, detection_result: DetectionResult - ) -> Optional[str]: + ) -> str | None: """生成反击消息 Args: diff --git a/src/chat/antipromptinjector/decision/decision_maker.py b/src/chat/antipromptinjector/decision/decision_maker.py index 12a2c95b5..be3d3ccfb 100644 --- a/src/chat/antipromptinjector/decision/decision_maker.py +++ b/src/chat/antipromptinjector/decision/decision_maker.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 处理决策器模块 @@ -6,6 +5,7 @@ """ from src.common.logger import get_logger + from ..types import DetectionResult logger = get_logger("anti_injector.decision_maker") diff --git a/src/chat/antipromptinjector/decision_maker.py b/src/chat/antipromptinjector/decision_maker.py index 972253fab..893da059f 100644 --- a/src/chat/antipromptinjector/decision_maker.py +++ b/src/chat/antipromptinjector/decision_maker.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 处理决策器模块 @@ -6,6 +5,7 @@ """ from src.common.logger import get_logger + from .types import DetectionResult logger = get_logger("anti_injector.decision_maker") diff --git a/src/chat/antipromptinjector/detector.py b/src/chat/antipromptinjector/detector.py index 6c1e3b4bd..59d1132b1 100644 --- a/src/chat/antipromptinjector/detector.py +++ b/src/chat/antipromptinjector/detector.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 提示词注入检测器模块 @@ -8,19 +7,19 @@ 3. 缓存机制优化性能 """ +import hashlib import re import time -import hashlib -from typing import Dict, List from dataclasses import asdict from src.common.logger import get_logger from src.config.config import global_config -from .types import DetectionResult # 导入LLM API from src.plugin_system.apis import llm_api +from .types import DetectionResult + logger = get_logger("anti_injector.detector") @@ -30,8 +29,8 @@ class PromptInjectionDetector: def __init__(self): """初始化检测器""" self.config = global_config.anti_prompt_injection - self._cache: Dict[str, DetectionResult] = {} - self._compiled_patterns: List[re.Pattern] = [] + self._cache: dict[str, DetectionResult] = {} + self._compiled_patterns: list[re.Pattern] = [] self._compile_patterns() def _compile_patterns(self): @@ -221,7 +220,7 @@ class PromptInjectionDetector: matched_patterns=[], processing_time=processing_time, detection_method="llm", - reason=f"LLM检测出错: {str(e)}", + reason=f"LLM检测出错: {e!s}", ) @staticmethod @@ -247,7 +246,7 @@ class PromptInjectionDetector: 请客观分析,避免误判正常对话。""" @staticmethod - def _parse_llm_response(response: str) -> Dict: + def _parse_llm_response(response: str) -> dict: """解析LLM响应""" try: lines = response.strip().split("\n") @@ -277,7 +276,7 @@ class PromptInjectionDetector: except Exception as e: logger.error(f"解析LLM响应失败: {e}") - return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"} + return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"} async def detect(self, message: str) -> DetectionResult: """执行检测""" @@ -328,7 +327,7 @@ class PromptInjectionDetector: return final_result - def _merge_results(self, results: List[DetectionResult]) -> DetectionResult: + def _merge_results(self, results: list[DetectionResult]) -> DetectionResult: """合并多个检测结果""" if not results: return DetectionResult(reason="无检测结果") @@ -381,7 +380,7 @@ class PromptInjectionDetector: if expired_keys: logger.debug(f"清理了{len(expired_keys)}个过期缓存项") - def get_cache_stats(self) -> Dict: + def get_cache_stats(self) -> dict: """获取缓存统计信息""" return { "cache_size": len(self._cache), diff --git a/src/chat/antipromptinjector/management/__init__.py b/src/chat/antipromptinjector/management/__init__.py index eaef392c4..28b1bcee2 100644 --- a/src/chat/antipromptinjector/management/__init__.py +++ b/src/chat/antipromptinjector/management/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 反注入系统管理模块 diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index 9d44faa78..0525754f1 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 反注入系统统计模块 @@ -6,12 +5,12 @@ """ import datetime -from typing import Dict, Any +from typing import Any from sqlalchemy import select -from src.common.logger import get_logger from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session +from src.common.logger import get_logger from src.config.config import global_config logger = get_logger("anti_injector.statistics") @@ -94,7 +93,7 @@ class AntiInjectionStatistics: except Exception as e: logger.error(f"更新统计数据失败: {e}") - async def get_stats(self) -> Dict[str, Any]: + async def get_stats(self) -> dict[str, Any]: """获取统计信息""" try: # 检查反注入系统是否启用 diff --git a/src/chat/antipromptinjector/management/user_ban.py b/src/chat/antipromptinjector/management/user_ban.py index b965a08af..f1b82a8dc 100644 --- a/src/chat/antipromptinjector/management/user_ban.py +++ b/src/chat/antipromptinjector/management/user_ban.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 用户封禁管理模块 @@ -6,12 +5,12 @@ """ import datetime -from typing import Optional, Tuple from sqlalchemy import select -from src.common.logger import get_logger from src.common.database.sqlalchemy_models import BanUser, get_db_session +from src.common.logger import get_logger + from ..types import DetectionResult logger = get_logger("anti_injector.user_ban") @@ -28,7 +27,7 @@ class UserBanManager: """ self.config = config - async def check_user_ban(self, user_id: str, platform: str) -> Optional[Tuple[bool, Optional[str], str]]: + async def check_user_ban(self, user_id: str, platform: str) -> tuple[bool, str | None, str] | None: """检查用户是否被封禁 Args: diff --git a/src/chat/antipromptinjector/processors/__init__.py b/src/chat/antipromptinjector/processors/__init__.py index 1db74557f..40de37df9 100644 --- a/src/chat/antipromptinjector/processors/__init__.py +++ b/src/chat/antipromptinjector/processors/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 反注入系统消息处理模块 diff --git a/src/chat/antipromptinjector/processors/message_processor.py b/src/chat/antipromptinjector/processors/message_processor.py index 935848c2d..0e37efc0d 100644 --- a/src/chat/antipromptinjector/processors/message_processor.py +++ b/src/chat/antipromptinjector/processors/message_processor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 消息内容处理模块 @@ -6,10 +5,9 @@ """ import re -from typing import Optional -from src.common.logger import get_logger from src.chat.message_receive.message import MessageRecv +from src.common.logger import get_logger logger = get_logger("anti_injector.message_processor") @@ -66,7 +64,7 @@ class MessageProcessor: return new_content @staticmethod - def check_whitelist(message: MessageRecv, whitelist: list) -> Optional[tuple]: + def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None: """检查用户白名单 Args: diff --git a/src/chat/antipromptinjector/types.py b/src/chat/antipromptinjector/types.py index 81d775ffc..ac436cc90 100644 --- a/src/chat/antipromptinjector/types.py +++ b/src/chat/antipromptinjector/types.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 反注入系统数据类型定义模块 @@ -10,7 +9,6 @@ """ import time -from typing import List, Optional from dataclasses import dataclass, field from enum import Enum @@ -31,8 +29,8 @@ class DetectionResult: is_injection: bool = False confidence: float = 0.0 - matched_patterns: List[str] = field(default_factory=list) - llm_analysis: Optional[str] = None + matched_patterns: list[str] = field(default_factory=list) + llm_analysis: str | None = None processing_time: float = 0.0 detection_method: str = "unknown" reason: str = "" diff --git a/src/chat/chatter_manager.py b/src/chat/chatter_manager.py index d22d39440..d8eda9baa 100644 --- a/src/chat/chatter_manager.py +++ b/src/chat/chatter_manager.py @@ -1,10 +1,11 @@ -from typing import Dict, List, Optional, Any import time -from src.plugin_system.base.base_chatter import BaseChatter -from src.common.data_models.message_manager_data_model import StreamContext +from typing import Any + from src.chat.planner_actions.action_manager import ChatterActionManager -from src.plugin_system.base.component_types import ChatType +from src.common.data_models.message_manager_data_model import StreamContext from src.common.logger import get_logger +from src.plugin_system.base.base_chatter import BaseChatter +from src.plugin_system.base.component_types import ChatType logger = get_logger("chatter_manager") @@ -12,8 +13,8 @@ logger = get_logger("chatter_manager") class ChatterManager: def __init__(self, action_manager: ChatterActionManager): self.action_manager = action_manager - self.chatter_classes: Dict[ChatType, List[type]] = {} - self.instances: Dict[str, BaseChatter] = {} + self.chatter_classes: dict[ChatType, list[type]] = {} + self.instances: dict[str, BaseChatter] = {} # 管理器统计 self.stats = { @@ -46,21 +47,21 @@ class ChatterManager: self.stats["chatters_registered"] += 1 - def get_chatter_class(self, chat_type: ChatType) -> Optional[type]: + def get_chatter_class(self, chat_type: ChatType) -> type | None: """获取指定聊天类型的聊天处理器类""" if chat_type in self.chatter_classes: return self.chatter_classes[chat_type][0] return None - def get_supported_chat_types(self) -> List[ChatType]: + def get_supported_chat_types(self) -> list[ChatType]: """获取支持的聊天类型列表""" return list(self.chatter_classes.keys()) - def get_registered_chatters(self) -> Dict[ChatType, List[type]]: + def get_registered_chatters(self) -> dict[ChatType, list[type]]: """获取已注册的聊天处理器""" return self.chatter_classes.copy() - def get_stream_instance(self, stream_id: str) -> Optional[BaseChatter]: + def get_stream_instance(self, stream_id: str) -> BaseChatter | None: """获取指定流的聊天处理器实例""" return self.instances.get(stream_id) @@ -139,7 +140,7 @@ class ChatterManager: logger.error(f"处理流 {stream_id} 时发生错误: {e}") raise - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: """获取管理器统计信息""" stats = self.stats.copy() stats["active_instances"] = len(self.instances) diff --git a/src/chat/emoji_system/emoji_history.py b/src/chat/emoji_system/emoji_history.py index dadd152a1..0e7d6a6e1 100644 --- a/src/chat/emoji_system/emoji_history.py +++ b/src/chat/emoji_system/emoji_history.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- """ 表情包发送历史记录模块 """ -from typing import List, Dict from collections import deque from src.common.logger import get_logger @@ -14,7 +12,7 @@ MAX_HISTORY_SIZE = 5 # 每个聊天会话最多保留最近5条表情历史 # 使用一个全局字典在内存中存储历史记录 # 键是 chat_id,值是一个 deque 对象 -_history_cache: Dict[str, deque] = {} +_history_cache: dict[str, deque] = {} def add_emoji_to_history(chat_id: str, emoji_description: str): @@ -38,7 +36,7 @@ def add_emoji_to_history(chat_id: str, emoji_description: str): logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中") -def get_recent_emojis(chat_id: str, limit: int = 5) -> List[str]: +def get_recent_emojis(chat_id: str, limit: int = 5) -> list[str]: """ 从内存中获取最近发送的表情包描述列表。 diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index cd472ec0c..62552a201 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -1,23 +1,24 @@ import asyncio import base64 +import binascii import hashlib +import io import os import random +import re import time import traceback -import io -import re -import binascii +from typing import Any, Optional -from typing import Optional, Tuple, List, Any from PIL import Image from rich.traceback import install from sqlalchemy import select + +from src.chat.utils.utils_image import get_image_manager, image_path_to_base64 from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import Emoji, Images from src.common.logger import get_logger from src.config.config import global_config, model_config -from src.chat.utils.utils_image import image_path_to_base64, get_image_manager from src.llm_models.utils_model import LLMRequest install(extra_lines=3) @@ -47,14 +48,14 @@ class MaiEmoji: self.embedding = [] self.hash = "" # 初始为空,在创建实例时会计算 self.description = "" - self.emotion: List[str] = [] + self.emotion: list[str] = [] self.usage_count = 0 self.last_used_time = time.time() self.register_time = time.time() self.is_deleted = False # 标记是否已被删除 self.format = "" - async def initialize_hash_format(self) -> Optional[bool]: + async def initialize_hash_format(self) -> bool | None: """从文件创建表情包实例, 计算哈希值和格式""" try: # 使用 full_path 检查文件是否存在 @@ -105,7 +106,7 @@ class MaiEmoji: self.is_deleted = True return None except Exception as e: - logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {str(e)}") + logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {e!s}") logger.error(traceback.format_exc()) self.is_deleted = True return None @@ -142,7 +143,7 @@ class MaiEmoji: self.path = EMOJI_REGISTERED_DIR # self.filename 保持不变 except Exception as move_error: - logger.error(f"[错误] 移动文件失败: {str(move_error)}") + logger.error(f"[错误] 移动文件失败: {move_error!s}") # 如果移动失败,尝试将实例状态恢复?暂时不处理,仅返回失败 return False @@ -174,11 +175,11 @@ class MaiEmoji: return True except Exception as db_error: - logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}") + logger.error(f"[错误] 保存数据库失败 ({self.filename}): {db_error!s}") return False except Exception as e: - logger.error(f"[错误] 注册表情包失败 ({self.filename}): {str(e)}") + logger.error(f"[错误] 注册表情包失败 ({self.filename}): {e!s}") logger.error(traceback.format_exc()) return False @@ -198,7 +199,7 @@ class MaiEmoji: os.remove(file_to_delete) logger.debug(f"[删除] 文件: {file_to_delete}") except Exception as e: - logger.error(f"[错误] 删除文件失败 {file_to_delete}: {str(e)}") + logger.error(f"[错误] 删除文件失败 {file_to_delete}: {e!s}") # 文件删除失败,但仍然尝试删除数据库记录 # 2. 删除数据库记录 @@ -214,7 +215,7 @@ class MaiEmoji: result = 1 # Successfully deleted one record await session.commit() except Exception as e: - logger.error(f"[错误] 删除数据库记录时出错: {str(e)}") + logger.error(f"[错误] 删除数据库记录时出错: {e!s}") result = 0 if result > 0: @@ -233,11 +234,11 @@ class MaiEmoji: return False except Exception as e: - logger.error(f"[错误] 删除表情包失败 ({self.filename}): {str(e)}") + logger.error(f"[错误] 删除表情包失败 ({self.filename}): {e!s}") return False -def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]: +def _emoji_objects_to_readable_list(emoji_objects: list["MaiEmoji"]) -> list[str]: """将表情包对象列表转换为可读的字符串列表 参数: @@ -256,7 +257,7 @@ def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str return emoji_info_list -def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]: +def _to_emoji_objects(data: Any) -> tuple[list["MaiEmoji"], int]: emoji_objects = [] load_errors = 0 emoji_data_list = list(data) @@ -300,7 +301,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]: logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}") load_errors += 1 except Exception as e: - logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {str(e)}") + logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {e!s}") load_errors += 1 return emoji_objects, load_errors @@ -335,7 +336,7 @@ async def clear_temp_emoji() -> None: logger.debug(f"[清理] 删除: {filename}") -async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], removed_count: int) -> int: +async def clean_unused_emojis(emoji_dir: str, emoji_objects: list["MaiEmoji"], removed_count: int) -> int: """清理指定目录中未被 emoji_objects 追踪的表情包文件""" if not os.path.exists(emoji_dir): logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") @@ -361,7 +362,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r logger.info(f"[清理] 删除未追踪的表情包文件: {file_full_path}") cleaned_count += 1 except Exception as e: - logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {str(e)}") + logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {e!s}") if cleaned_count > 0: logger.info(f"[清理] 在目录 {emoji_dir} 中清理了 {cleaned_count} 个破损表情包。") @@ -369,7 +370,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。") except Exception as e: - logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}") + logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {e!s}") return removed_count + cleaned_count @@ -437,9 +438,9 @@ class EmojiManager: emoji_update.last_used_time = time.time() # Update last used time await session.commit() except Exception as e: - logger.error(f"记录表情使用失败: {str(e)}") + logger.error(f"记录表情使用失败: {e!s}") - async def get_emoji_for_text(self, text_emotion: str) -> Optional[Tuple[str, str, str]]: + async def get_emoji_for_text(self, text_emotion: str) -> tuple[str, str, str] | None: """ 根据文本内容,使用LLM选择一个合适的表情包。 @@ -531,7 +532,7 @@ class EmojiManager: return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion except Exception as e: - logger.error(f"使用LLM获取表情包时发生错误: {str(e)}") + logger.error(f"使用LLM获取表情包时发生错误: {e!s}") logger.error(traceback.format_exc()) return None @@ -578,7 +579,7 @@ class EmojiManager: continue except Exception as item_error: - logger.error(f"[错误] 处理表情包记录时出错 ({emoji.filename}): {str(item_error)}") + logger.error(f"[错误] 处理表情包记录时出错 ({emoji.filename}): {item_error!s}") # 即使出错,也尝试继续检查下一个 continue @@ -597,7 +598,7 @@ class EmojiManager: logger.info(f"[检查] 已检查 {total_count} 个表情包记录,全部完好") except Exception as e: - logger.error(f"[错误] 检查表情包完整性失败: {str(e)}") + logger.error(f"[错误] 检查表情包完整性失败: {e!s}") logger.error(traceback.format_exc()) async def start_periodic_check_register(self) -> None: @@ -651,7 +652,7 @@ class EmojiManager: os.remove(file_path) logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}") except Exception as e: - logger.error(f"[错误] 扫描表情包目录失败: {str(e)}") + logger.error(f"[错误] 扫描表情包目录失败: {e!s}") await asyncio.sleep(global_config.emoji.check_interval * 60) @@ -674,11 +675,11 @@ class EmojiManager: logger.warning(f"[数据库] 加载过程中出现 {load_errors} 个错误。") except Exception as e: - logger.error(f"[错误] 从数据库加载所有表情包对象失败: {str(e)}") + logger.error(f"[错误] 从数据库加载所有表情包对象失败: {e!s}") self.emoji_objects = [] # 加载失败则清空列表 self.emoji_num = 0 - async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]: + async def get_emoji_from_db(self, emoji_hash: str | None = None) -> list["MaiEmoji"]: """获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找) 参数: @@ -707,7 +708,7 @@ class EmojiManager: return emoji_objects except Exception as e: - logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}") + logger.error(f"[错误] 从数据库获取表情包对象失败: {e!s}") return [] async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]: @@ -725,7 +726,7 @@ class EmojiManager: return emoji return None # 如果循环结束还没找到,则返回 None - async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]: + async def get_emoji_tag_by_hash(self, emoji_hash: str) -> str | None: """根据哈希值获取已注册表情包的描述 Args: @@ -753,10 +754,10 @@ class EmojiManager: return None except Exception as e: - logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}") + logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {e!s}") return None - async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]: + async def get_emoji_description_by_hash(self, emoji_hash: str) -> str | None: """根据哈希值获取已注册表情包的描述 Args: @@ -787,7 +788,7 @@ class EmojiManager: return None except Exception as e: - logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}") + logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {e!s}") return None async def delete_emoji(self, emoji_hash: str) -> bool: @@ -823,7 +824,7 @@ class EmojiManager: return False except Exception as e: - logger.error(f"[错误] 删除表情包失败: {str(e)}") + logger.error(f"[错误] 删除表情包失败: {e!s}") logger.error(traceback.format_exc()) return False @@ -909,11 +910,11 @@ class EmojiManager: return False except Exception as e: - logger.error(f"[错误] 替换表情包失败: {str(e)}") + logger.error(f"[错误] 替换表情包失败: {e!s}") logger.error(traceback.format_exc()) return False - async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]: + async def build_emoji_description(self, image_base64: str) -> tuple[str, list[str]]: """ 获取表情包的详细描述和情感关键词列表。 @@ -976,14 +977,14 @@ class EmojiManager: # 4. 内容审核,确保表情包符合规定 if global_config.emoji.content_filtration: - prompt = f''' + prompt = f""" 请根据以下标准审核这个表情包: 1. 主题必须符合:"{global_config.emoji.filtration_prompt}"。 2. 内容健康,不含色情、暴力、政治敏感等元素。 3. 必须是表情包,而不是普通的聊天截图或视频截图。 4. 表情包中的文字数量(如果有)不能超过5个。 这个表情包是否完全满足以上所有要求?请只回答“是”或“否”。 - ''' + """ content, _ = await self.vlm.generate_response_for_image( prompt, image_base64, image_format, temperature=0.1, max_tokens=10 ) @@ -1023,7 +1024,7 @@ class EmojiManager: return final_description, emotions except Exception as e: - logger.error(f"构建表情包描述时发生严重错误: {str(e)}") + logger.error(f"构建表情包描述时发生严重错误: {e!s}") logger.error(traceback.format_exc()) return "", [] @@ -1058,7 +1059,7 @@ class EmojiManager: os.remove(file_full_path) logger.info(f"[清理] 删除重复的待注册文件: {filename}") except Exception as e: - logger.error(f"[错误] 删除重复文件失败: {str(e)}") + logger.error(f"[错误] 删除重复文件失败: {e!s}") return False # 返回 False 表示未注册新表情 # 3. 构建描述和情感 @@ -1075,7 +1076,7 @@ class EmojiManager: os.remove(file_full_path) logger.info(f"[清理] 删除描述生成失败的文件: {filename}") except Exception as e: - logger.error(f"[错误] 删除描述生成失败文件时出错: {str(e)}") + logger.error(f"[错误] 删除描述生成失败文件时出错: {e!s}") return False new_emoji.description = description new_emoji.emotion = emotions @@ -1086,7 +1087,7 @@ class EmojiManager: os.remove(file_full_path) logger.info(f"[清理] 删除描述生成异常的文件: {filename}") except Exception as e: - logger.error(f"[错误] 删除描述生成异常文件时出错: {str(e)}") + logger.error(f"[错误] 删除描述生成异常文件时出错: {e!s}") return False # 4. 检查容量并决定是否替换或直接注册 @@ -1100,7 +1101,7 @@ class EmojiManager: os.remove(file_full_path) # new_emoji 的 full_path 此时还是源路径 logger.info(f"[清理] 删除替换失败的新表情文件: {filename}") except Exception as e: - logger.error(f"[错误] 删除替换失败文件时出错: {str(e)}") + logger.error(f"[错误] 删除替换失败文件时出错: {e!s}") return False # 替换成功时,replace_a_emoji 内部已处理 new_emoji 的注册和添加到列表 return True @@ -1122,11 +1123,11 @@ class EmojiManager: os.remove(file_full_path) logger.info(f"[清理] 删除注册失败的源文件: {filename}") except Exception as e: - logger.error(f"[错误] 删除注册失败源文件时出错: {str(e)}") + logger.error(f"[错误] 删除注册失败源文件时出错: {e!s}") return False except Exception as e: - logger.error(f"[错误] 注册表情包时发生未预期错误 ({filename}): {str(e)}") + logger.error(f"[错误] 注册表情包时发生未预期错误 ({filename}): {e!s}") logger.error(traceback.format_exc()) # 尝试删除源文件以避免循环处理 if os.path.exists(file_full_path): diff --git a/src/chat/energy_system/__init__.py b/src/chat/energy_system/__init__.py index 6cdf96da5..570e183e6 100644 --- a/src/chat/energy_system/__init__.py +++ b/src/chat/energy_system/__init__.py @@ -4,24 +4,24 @@ """ from .energy_manager import ( - EnergyManager, - EnergyLevel, - EnergyComponent, - EnergyCalculator, - InterestEnergyCalculator, ActivityEnergyCalculator, + EnergyCalculator, + EnergyComponent, + EnergyLevel, + EnergyManager, + InterestEnergyCalculator, RecencyEnergyCalculator, RelationshipEnergyCalculator, energy_manager, ) __all__ = [ - "EnergyManager", - "EnergyLevel", - "EnergyComponent", - "EnergyCalculator", - "InterestEnergyCalculator", "ActivityEnergyCalculator", + "EnergyCalculator", + "EnergyComponent", + "EnergyLevel", + "EnergyManager", + "InterestEnergyCalculator", "RecencyEnergyCalculator", "RelationshipEnergyCalculator", "energy_manager", diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py index 4a92349bf..0bfb6fc4f 100644 --- a/src/chat/energy_system/energy_manager.py +++ b/src/chat/energy_system/energy_manager.py @@ -4,10 +4,10 @@ """ import time -from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict +from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum -from abc import ABC, abstractmethod +from typing import Any, TypedDict from src.common.logger import get_logger from src.config.config import global_config @@ -51,8 +51,8 @@ class EnergyContext(TypedDict): """能量计算上下文""" stream_id: str - messages: List[Any] - user_id: Optional[str] + messages: list[Any] + user_id: str | None class EnergyResult(TypedDict): @@ -61,7 +61,7 @@ class EnergyResult(TypedDict): energy: float level: EnergyLevel distribution_interval: float - component_scores: Dict[str, float] + component_scores: dict[str, float] cached: bool @@ -69,7 +69,7 @@ class EnergyCalculator(ABC): """能量计算器抽象基类""" @abstractmethod - def calculate(self, context: Dict[str, Any]) -> float: + def calculate(self, context: dict[str, Any]) -> float: """计算能量值""" pass @@ -82,7 +82,7 @@ class EnergyCalculator(ABC): class InterestEnergyCalculator(EnergyCalculator): """兴趣度能量计算器""" - def calculate(self, context: Dict[str, Any]) -> float: + def calculate(self, context: dict[str, Any]) -> float: """基于消息兴趣度计算能量""" messages = context.get("messages", []) if not messages: @@ -120,7 +120,7 @@ class ActivityEnergyCalculator(EnergyCalculator): def __init__(self): self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1} - def calculate(self, context: Dict[str, Any]) -> float: + def calculate(self, context: dict[str, Any]) -> float: """基于活跃度计算能量""" messages = context.get("messages", []) if not messages: @@ -150,7 +150,7 @@ class ActivityEnergyCalculator(EnergyCalculator): class RecencyEnergyCalculator(EnergyCalculator): """最近性能量计算器""" - def calculate(self, context: Dict[str, Any]) -> float: + def calculate(self, context: dict[str, Any]) -> float: """基于最近性计算能量""" messages = context.get("messages", []) if not messages: @@ -197,7 +197,7 @@ class RecencyEnergyCalculator(EnergyCalculator): class RelationshipEnergyCalculator(EnergyCalculator): """关系能量计算器""" - async def calculate(self, context: Dict[str, Any]) -> float: + async def calculate(self, context: dict[str, Any]) -> float: """基于关系计算能量""" user_id = context.get("user_id") if not user_id: @@ -223,7 +223,7 @@ class EnergyManager: """能量管理器 - 统一管理所有能量计算""" def __init__(self) -> None: - self.calculators: List[EnergyCalculator] = [ + self.calculators: list[EnergyCalculator] = [ InterestEnergyCalculator(), ActivityEnergyCalculator(), RecencyEnergyCalculator(), @@ -231,14 +231,14 @@ class EnergyManager: ] # 能量缓存 - self.energy_cache: Dict[str, Tuple[float, float]] = {} # stream_id -> (energy, timestamp) + self.energy_cache: dict[str, tuple[float, float]] = {} # stream_id -> (energy, timestamp) self.cache_ttl: int = 60 # 1分钟缓存 # AFC阈值配置 - self.thresholds: Dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2} + self.thresholds: dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2} # 统计信息 - self.stats: Dict[str, Union[int, float, str]] = { + self.stats: dict[str, int | float | str] = { "total_calculations": 0, "cache_hits": 0, "cache_misses": 0, @@ -272,7 +272,7 @@ class EnergyManager: except Exception as e: logger.warning(f"加载AFC阈值失败,使用默认值: {e}") - async def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float: + async def calculate_focus_energy(self, stream_id: str, messages: list[Any], user_id: str | None = None) -> float: """计算聊天流的focus_energy""" start_time = time.time() @@ -297,7 +297,7 @@ class EnergyManager: } # 计算各组件能量 - component_scores: Dict[str, float] = {} + component_scores: dict[str, float] = {} total_weight = 0.0 for calculator in self.calculators: @@ -437,7 +437,7 @@ class EnergyManager: if expired_keys: logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存") - def get_statistics(self) -> Dict[str, Any]: + def get_statistics(self) -> dict[str, Any]: """获取统计信息""" return { "cache_size": len(self.energy_cache), @@ -446,7 +446,7 @@ class EnergyManager: "performance_stats": self.stats.copy(), } - def update_thresholds(self, new_thresholds: Dict[str, float]) -> None: + def update_thresholds(self, new_thresholds: dict[str, float]) -> None: """更新阈值""" self.thresholds.update(new_thresholds) diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 596322ebd..f9e0e68af 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -1,21 +1,20 @@ -import time -import random -import orjson import os +import random +import time from datetime import datetime +from typing import Any -from typing import List, Dict, Optional, Any, Tuple - -from src.common.logger import get_logger -from src.common.database.sqlalchemy_database_api import get_db_session +import orjson from sqlalchemy import select -from src.common.database.sqlalchemy_models import Expression -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config, global_config -from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import Expression +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest MAX_EXPRESSION_COUNT = 300 DECAY_DAYS = 30 # 30天衰减到0.01 @@ -193,7 +192,7 @@ class ExpressionLearner: logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}") return False - async def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]: + async def get_expression_by_chat_id(self) -> tuple[list[dict[str, float]], list[dict[str, float]]]: """ 获取指定chat_id的style和grammar表达方式 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作 @@ -341,7 +340,7 @@ class ExpressionLearner: return [] # 按chat_id分组 - chat_dict: Dict[str, List[Dict[str, Any]]] = {} + chat_dict: dict[str, list[dict[str, Any]]] = {} for chat_id, situation, style in learnt_expressions: if chat_id not in chat_dict: chat_dict[chat_id] = [] @@ -398,7 +397,7 @@ class ExpressionLearner: return learnt_expressions return None - async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]: + async def learn_expression(self, type: str, num: int = 10) -> tuple[list[tuple[str, str, str]], str] | None: """从指定聊天流学习表达方式 Args: @@ -416,7 +415,7 @@ class ExpressionLearner: current_time = time.time() # 获取上次学习时间 - random_msg: Optional[List[Dict[str, Any]]] = await get_raw_msg_by_timestamp_with_chat_inclusive( + random_msg: list[dict[str, Any]] | None = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_learning_time, timestamp_end=current_time, @@ -447,16 +446,16 @@ class ExpressionLearner: logger.debug(f"学习{type_str}的response: {response}") - expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id) + expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id) return expressions, chat_id @staticmethod - def parse_expression_response(response: str, chat_id: str) -> List[Tuple[str, str, str]]: + def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]: """ 解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组 """ - expressions: List[Tuple[str, str, str]] = [] + expressions: list[tuple[str, str, str]] = [] for line in response.splitlines(): line = line.strip() if not line: @@ -562,7 +561,7 @@ class ExpressionLearnerManager: if not os.path.exists(expr_file): continue try: - with open(expr_file, "r", encoding="utf-8") as f: + with open(expr_file, encoding="utf-8") as f: expressions = orjson.loads(f.read()) if not isinstance(expressions, list): diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index ff4083a3b..431d55b46 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -1,18 +1,18 @@ -import orjson -import time -import random import hashlib +import random +import time +from typing import Any -from typing import List, Dict, Tuple, Optional, Any +import orjson from json_repair import repair_json - -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger from sqlalchemy import select -from src.common.database.sqlalchemy_models import Expression + from src.chat.utils.prompt import Prompt, global_prompt_manager from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import Expression +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("expression_selector") @@ -45,7 +45,7 @@ def init_prompt(): Prompt(expression_evaluation_prompt, "expression_evaluation_prompt") -def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]: +def weighted_sample(population: list[dict], weights: list[float], k: int) -> list[dict]: """按权重随机抽样""" if not population or not weights or k <= 0: return [] @@ -95,7 +95,7 @@ class ExpressionSelector: return False @staticmethod - def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]: + def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None: """解析'platform:id:type'为chat_id(与get_stream_id一致)""" try: parts = stream_config_str.split(":") @@ -114,7 +114,7 @@ class ExpressionSelector: except Exception: return None - def get_related_chat_ids(self, chat_id: str) -> List[str]: + def get_related_chat_ids(self, chat_id: str) -> list[str]: """根据expression.rules配置,获取与当前chat_id相关的所有chat_id(包括自身)""" rules = global_config.expression.rules current_group = None @@ -139,7 +139,7 @@ class ExpressionSelector: async def get_random_expressions( self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float - ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: # sourcery skip: extract-duplicate-method, move-assign # 支持多chat_id合并抽选 related_chat_ids = self.get_related_chat_ids(chat_id) @@ -195,7 +195,7 @@ class ExpressionSelector: return selected_style, selected_grammar @staticmethod - async def update_expressions_count_batch(expressions_to_update: List[Dict[str, Any]], increment: float = 0.1): + async def update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1): """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库""" if not expressions_to_update: return @@ -240,8 +240,8 @@ class ExpressionSelector: chat_info: str, max_num: int = 10, min_num: int = 5, - target_message: Optional[str] = None, - ) -> List[Dict[str, Any]]: + target_message: str | None = None, + ) -> list[dict[str, Any]]: # sourcery skip: inline-variable, list-comprehension """使用LLM选择适合的表达方式""" diff --git a/src/chat/frequency_analyzer/analyzer.py b/src/chat/frequency_analyzer/analyzer.py index 1493c47ea..a3e6addea 100644 --- a/src/chat/frequency_analyzer/analyzer.py +++ b/src/chat/frequency_analyzer/analyzer.py @@ -16,8 +16,7 @@ Chat Frequency Analyzer """ import time as time_module -from datetime import datetime, timedelta, time -from typing import List, Tuple, Optional +from datetime import datetime, time, timedelta from .tracker import chat_frequency_tracker @@ -42,7 +41,7 @@ class ChatFrequencyAnalyzer: self._cache_ttl_seconds = 60 * 30 # 缓存30分钟 @staticmethod - def _find_peak_windows(timestamps: List[float]) -> List[Tuple[datetime, datetime]]: + def _find_peak_windows(timestamps: list[float]) -> list[tuple[datetime, datetime]]: """ 使用滑动窗口算法来识别时间戳列表中的高峰时段。 @@ -59,7 +58,7 @@ class ChatFrequencyAnalyzer: datetimes = [datetime.fromtimestamp(ts) for ts in timestamps] datetimes.sort() - peak_windows: List[Tuple[datetime, datetime]] = [] + peak_windows: list[tuple[datetime, datetime]] = [] window_start_idx = 0 for i in range(len(datetimes)): @@ -83,7 +82,7 @@ class ChatFrequencyAnalyzer: return peak_windows - def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]: + def get_peak_chat_times(self, chat_id: str) -> list[tuple[time, time]]: """ 获取指定用户的高峰聊天时间段。 @@ -116,7 +115,7 @@ class ChatFrequencyAnalyzer: return peak_time_windows - def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool: + def is_in_peak_time(self, chat_id: str, now: datetime | None = None) -> bool: """ 检查当前时间是否处于用户的高峰聊天时段内。 diff --git a/src/chat/frequency_analyzer/tracker.py b/src/chat/frequency_analyzer/tracker.py index 3621cb5b4..371fc6351 100644 --- a/src/chat/frequency_analyzer/tracker.py +++ b/src/chat/frequency_analyzer/tracker.py @@ -1,8 +1,8 @@ -import orjson import time -from typing import Dict, List, Optional from pathlib import Path +import orjson + from src.common.logger import get_logger # 数据存储路径 @@ -19,10 +19,10 @@ class ChatFrequencyTracker: """ def __init__(self): - self._timestamps: Dict[str, List[float]] = self._load_timestamps() + self._timestamps: dict[str, list[float]] = self._load_timestamps() @staticmethod - def _load_timestamps() -> Dict[str, List[float]]: + def _load_timestamps() -> dict[str, list[float]]: """从本地文件加载时间戳数据。""" if not TRACKER_FILE.exists(): return {} @@ -61,7 +61,7 @@ class ChatFrequencyTracker: logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}") self._save_timestamps() - def get_timestamps_for_chat(self, chat_id: str) -> Optional[List[float]]: + def get_timestamps_for_chat(self, chat_id: str) -> list[float] | None: """ 获取指定聊天的所有时间戳记录。 diff --git a/src/chat/frequency_analyzer/trigger.py b/src/chat/frequency_analyzer/trigger.py index 2d8e8b56f..9d8a4fea0 100644 --- a/src/chat/frequency_analyzer/trigger.py +++ b/src/chat/frequency_analyzer/trigger.py @@ -18,11 +18,10 @@ Frequency-Based Proactive Trigger import asyncio import time from datetime import datetime -from typing import Dict, Optional from src.common.logger import get_logger -# AFC manager has been moved to chatter plugin +# AFC manager has been moved to chatter plugin # TODO: 需要重新实现主动思考和睡眠管理功能 from .analyzer import chat_frequency_analyzer @@ -42,10 +41,10 @@ class FrequencyBasedTrigger: def __init__(self): # TODO: 需要重新实现睡眠管理器 - self._task: Optional[asyncio.Task] = None + self._task: asyncio.Task | None = None # 记录上次为用户触发的时间,用于冷却控制 # 格式: { "chat_id": timestamp } - self._last_triggered: Dict[str, float] = {} + self._last_triggered: dict[str, float] = {} async def _run_trigger_cycle(self): """触发器的主要循环逻辑。""" diff --git a/src/chat/interest_system/__init__.py b/src/chat/interest_system/__init__.py index e05cbeebf..0d1a9bbe8 100644 --- a/src/chat/interest_system/__init__.py +++ b/src/chat/interest_system/__init__.py @@ -3,13 +3,14 @@ 提供机器人兴趣标签和智能匹配功能 """ -from .bot_interest_manager import BotInterestManager, bot_interest_manager from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult +from .bot_interest_manager import BotInterestManager, bot_interest_manager + __all__ = [ "BotInterestManager", - "bot_interest_manager", "BotInterestTag", "BotPersonalityInterests", "InterestMatchResult", + "bot_interest_manager", ] diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index 8fee48d1c..b26095f4c 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -3,17 +3,18 @@ 基于人设生成兴趣标签,并使用embedding计算匹配度 """ -import orjson import traceback -from typing import List, Dict, Optional, Any from datetime import datetime +from typing import Any + import numpy as np +import orjson from sqlalchemy import select +from src.common.config_helpers import resolve_embedding_dimension +from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult from src.common.logger import get_logger from src.config.config import global_config -from src.common.config_helpers import resolve_embedding_dimension -from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult logger = get_logger("bot_interest_manager") @@ -22,8 +23,8 @@ class BotInterestManager: """机器人兴趣标签管理器""" def __init__(self): - self.current_interests: Optional[BotPersonalityInterests] = None - self.embedding_cache: Dict[str, List[float]] = {} # embedding缓存 + self.current_interests: BotPersonalityInterests | None = None + self.embedding_cache: dict[str, list[float]] = {} # embedding缓存 self._initialized = False # Embedding客户端配置 @@ -31,7 +32,7 @@ class BotInterestManager: self.embedding_config = None configured_dim = resolve_embedding_dimension() self.embedding_dimension = int(configured_dim) if configured_dim else 0 - self._detected_embedding_dimension: Optional[int] = None + self._detected_embedding_dimension: int | None = None @property def is_initialized(self) -> bool: @@ -145,7 +146,7 @@ class BotInterestManager: async def _generate_interests_from_personality( self, personality_description: str, personality_id: str - ) -> Optional[BotPersonalityInterests]: + ) -> BotPersonalityInterests | None: """根据人设生成兴趣标签""" try: logger.info("🎨 开始根据人设生成兴趣标签...") @@ -226,14 +227,14 @@ class BotInterestManager: traceback.print_exc() raise - async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]: + async def _call_llm_for_interest_generation(self, prompt: str) -> str | None: """调用LLM生成兴趣标签""" try: logger.info("🔧 配置LLM客户端...") # 使用llm_api来处理请求 - from src.plugin_system.apis import llm_api from src.config.config import model_config + from src.plugin_system.apis import llm_api # 构建完整的提示词,明确要求只返回纯JSON full_prompt = f"""你是一个专业的机器人人设分析师,擅长根据人设描述生成合适的兴趣标签。 @@ -342,7 +343,7 @@ class BotInterestManager: logger.info(f"🗃️ 总缓存大小: {len(self.embedding_cache)}") logger.info("=" * 50) - async def _get_embedding(self, text: str) -> List[float]: + async def _get_embedding(self, text: str) -> list[float]: """获取文本的embedding向量""" if not hasattr(self, "embedding_request"): raise RuntimeError("❌ Embedding请求客户端未初始化") @@ -383,7 +384,7 @@ class BotInterestManager: else: raise RuntimeError(f"❌ 返回的embedding为空: {embedding}") - async def _generate_message_embedding(self, message_text: str, keywords: List[str]) -> List[float]: + async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> list[float]: """为消息生成embedding向量""" # 组合消息文本和关键词作为embedding输入 if keywords: @@ -399,7 +400,7 @@ class BotInterestManager: return embedding async def _calculate_similarity_scores( - self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str] + self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str] ): """计算消息与兴趣标签的相似度分数""" try: @@ -428,7 +429,7 @@ class BotInterestManager: except Exception as e: logger.error(f"❌ 计算相似度分数失败: {e}") - async def calculate_interest_match(self, message_text: str, keywords: List[str] = None) -> InterestMatchResult: + async def calculate_interest_match(self, message_text: str, keywords: list[str] = None) -> InterestMatchResult: """计算消息与机器人兴趣的匹配度""" if not self.current_interests or not self._initialized: raise RuntimeError("❌ 兴趣标签系统未初始化") @@ -528,7 +529,7 @@ class BotInterestManager: ) return result - def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]: + def _calculate_keyword_match_bonus(self, keywords: list[str], matched_tags: list[str]) -> dict[str, float]: """计算关键词直接匹配奖励""" if not keywords or not matched_tags: return {} @@ -610,7 +611,7 @@ class BotInterestManager: return previous_row[-1] - def _calculate_cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float: + def _calculate_cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float: """计算余弦相似度""" try: vec1 = np.array(vec1) @@ -629,16 +630,17 @@ class BotInterestManager: logger.error(f"计算余弦相似度失败: {e}") return 0.0 - async def _load_interests_from_database(self, personality_id: str) -> Optional[BotPersonalityInterests]: + async def _load_interests_from_database(self, personality_id: str) -> BotPersonalityInterests | None: """从数据库加载兴趣标签""" try: logger.debug(f"从数据库加载兴趣标签, personality_id: {personality_id}") # 导入SQLAlchemy相关模块 - from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests - from src.common.database.sqlalchemy_database_api import get_db_session import orjson + from src.common.database.sqlalchemy_database_api import get_db_session + from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests + async with get_db_session() as session: # 查询最新的兴趣标签配置 db_interests = ( @@ -716,10 +718,11 @@ class BotInterestManager: logger.info(f"🔄 版本: {interests.version}") # 导入SQLAlchemy相关模块 - from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests - from src.common.database.sqlalchemy_database_api import get_db_session import orjson + from src.common.database.sqlalchemy_database_api import get_db_session + from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests + # 将兴趣标签转换为JSON格式 tags_data = [] for tag in interests.interest_tags: @@ -803,11 +806,11 @@ class BotInterestManager: logger.error("🔍 错误详情:") traceback.print_exc() - def get_current_interests(self) -> Optional[BotPersonalityInterests]: + def get_current_interests(self) -> BotPersonalityInterests | None: """获取当前的兴趣标签配置""" return self.current_interests - def get_interest_stats(self) -> Dict[str, Any]: + def get_interest_stats(self) -> dict[str, Any]: """获取兴趣系统统计信息""" if not self.current_interests: return {"initialized": False} diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index f6fae8d6c..7ef04f985 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -1,33 +1,31 @@ -from dataclasses import dataclass -import orjson -import os -import math import asyncio +import math +import os from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Dict, List, Tuple - -import numpy as np -import pandas as pd +from dataclasses import dataclass # import tqdm import faiss - -from .utils.hash import get_sha256 -from .global_logger import logger -from rich.traceback import install +import numpy as np +import orjson +import pandas as pd from rich.progress import ( - Progress, BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, TimeElapsedColumn, TimeRemainingColumn, - TaskProgressColumn, - MofNCompleteColumn, - SpinnerColumn, - TextColumn, ) -from src.config.config import global_config -from src.common.config_helpers import resolve_embedding_dimension +from rich.traceback import install +from src.common.config_helpers import resolve_embedding_dimension +from src.config.config import global_config + +from .global_logger import logger +from .utils.hash import get_sha256 install(extra_lines=3) @@ -79,7 +77,7 @@ def cosine_similarity(a, b): class EmbeddingStoreItem: """嵌入库中的项""" - def __init__(self, item_hash: str, embedding: List[float], content: str): + def __init__(self, item_hash: str, embedding: list[float], content: str): self.hash = item_hash self.embedding = embedding self.str = content @@ -127,7 +125,7 @@ class EmbeddingStore: self.idx2hash = None @staticmethod - def _get_embedding(s: str) -> List[float]: + def _get_embedding(s: str) -> list[float]: """获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题""" # 创建新的事件循环并在完成后立即关闭 loop = asyncio.new_event_loop() @@ -135,8 +133,8 @@ class EmbeddingStore: try: # 创建新的LLMRequest实例 - from src.llm_models.utils_model import LLMRequest from src.config.config import model_config + from src.llm_models.utils_model import LLMRequest llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") @@ -161,8 +159,8 @@ class EmbeddingStore: @staticmethod def _get_embeddings_batch_threaded( - strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None - ) -> List[Tuple[str, List[float]]]: + strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None + ) -> list[tuple[str, list[float]]]: """使用多线程批量获取嵌入向量 Args: @@ -192,8 +190,8 @@ class EmbeddingStore: chunk_results = [] # 为每个线程创建独立的LLMRequest实例 - from src.llm_models.utils_model import LLMRequest from src.config.config import model_config + from src.llm_models.utils_model import LLMRequest try: # 创建线程专用的LLM实例 @@ -303,7 +301,7 @@ class EmbeddingStore: path = self.get_test_file_path() if not os.path.exists(path): return None - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return orjson.loads(f.read()) def check_embedding_model_consistency(self): @@ -345,7 +343,7 @@ class EmbeddingStore: logger.info("嵌入模型一致性校验通过。") return True - def batch_insert_strs(self, strs: List[str], times: int) -> None: + def batch_insert_strs(self, strs: list[str], times: int) -> None: """向库中存入字符串(使用多线程优化)""" if not strs: return @@ -481,7 +479,7 @@ class EmbeddingStore: if os.path.exists(self.idx2hash_file_path): logger.info(f"正在加载{self.namespace}嵌入库的idx2hash映射...") logger.debug(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射") - with open(self.idx2hash_file_path, "r") as f: + with open(self.idx2hash_file_path) as f: self.idx2hash = orjson.loads(f.read()) logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功") else: @@ -511,7 +509,7 @@ class EmbeddingStore: self.faiss_index = faiss.IndexFlatIP(embedding_dim) self.faiss_index.add(embeddings) - def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]: + def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]: """搜索最相似的k个项,以余弦相似度为度量 Args: query: 查询的embedding @@ -575,11 +573,11 @@ class EmbeddingManager: """对所有嵌入库做模型一致性校验""" return self.paragraphs_embedding_store.check_embedding_model_consistency() - def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): + def _store_pg_into_embedding(self, raw_paragraphs: dict[str, str]): """将段落编码存入Embedding库""" self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1) - def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): + def _store_ent_into_embedding(self, triple_list_data: dict[str, list[list[str]]]): """将实体编码存入Embedding库""" entities = set() for triple_list in triple_list_data.values(): @@ -588,7 +586,7 @@ class EmbeddingManager: entities.add(triple[2]) self.entities_embedding_store.batch_insert_strs(list(entities), times=2) - def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): + def _store_rel_into_embedding(self, triple_list_data: dict[str, list[list[str]]]): """将关系编码存入Embedding库""" graph_triples = [] # a list of unique relation triple (in tuple) from all chunks for triples in triple_list_data.values(): @@ -606,8 +604,8 @@ class EmbeddingManager: def store_new_data_set( self, - raw_paragraphs: Dict[str, str], - triple_list_data: Dict[str, List[List[str]]], + raw_paragraphs: dict[str, str], + triple_list_data: dict[str, list[list[str]]], ): if not self.check_all_embedding_model_consistency(): raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。") diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py index 457396d0a..e74b7d127 100644 --- a/src/chat/knowledge/ie_process.py +++ b/src/chat/knowledge/ie_process.py @@ -1,14 +1,15 @@ import asyncio -import orjson import time -from typing import List, Union -from .global_logger import logger -from . import prompt_template -from .knowledge_lib import INVALID_ENTITY -from src.llm_models.utils_model import LLMRequest +import orjson from json_repair import repair_json +from src.llm_models.utils_model import LLMRequest + +from . import prompt_template +from .global_logger import logger +from .knowledge_lib import INVALID_ENTITY + def _extract_json_from_text(text: str): # sourcery skip: assign-if-exp, extract-method @@ -46,7 +47,7 @@ def _extract_json_from_text(text: str): return [] -def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: +def _entity_extract(llm_req: LLMRequest, paragraph: str) -> list[str]: # sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" entity_extract_context = prompt_template.build_entity_extract_context(paragraph) @@ -92,7 +93,7 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: return entity_extract_result -def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]: +def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> list[list[str]]: """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" rdf_extract_context = prompt_template.build_rdf_triple_extract_context( paragraph, entities=orjson.dumps(entities).decode("utf-8") @@ -141,7 +142,7 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> def info_extract_from_str( llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str -) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]: +) -> tuple[None, None] | tuple[list[str], list[list[str]]]: try_count = 0 while True: try: diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py index 6d0585226..f590fad7d 100644 --- a/src/chat/knowledge/kg_manager.py +++ b/src/chat/knowledge/kg_manager.py @@ -1,28 +1,26 @@ -import orjson import os import time -from typing import Dict, List, Tuple import numpy as np +import orjson import pandas as pd +from quick_algo import di_graph, pagerank from rich.progress import ( - Progress, BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, TimeElapsedColumn, TimeRemainingColumn, - TaskProgressColumn, - MofNCompleteColumn, - SpinnerColumn, - TextColumn, ) -from quick_algo import di_graph, pagerank - -from .utils.hash import get_sha256 -from .embedding_store import EmbeddingManager, EmbeddingStoreItem from src.config.config import global_config +from .embedding_store import EmbeddingManager, EmbeddingStoreItem from .global_logger import logger +from .utils.hash import get_sha256 def _get_kg_dir(): @@ -87,7 +85,7 @@ class KGManager: raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在") # 加载段落hash - with open(self.pg_hash_file_path, "r", encoding="utf-8") as f: + with open(self.pg_hash_file_path, encoding="utf-8") as f: data = orjson.loads(f.read()) self.stored_paragraph_hashes = set(data["stored_paragraph_hashes"]) @@ -100,8 +98,8 @@ class KGManager: def _build_edges_between_ent( self, - node_to_node: Dict[Tuple[str, str], float], - triple_list_data: Dict[str, List[List[str]]], + node_to_node: dict[tuple[str, str], float], + triple_list_data: dict[str, list[list[str]]], ): """构建实体节点之间的关系,同时统计实体出现次数""" for triple_list in triple_list_data.values(): @@ -124,8 +122,8 @@ class KGManager: @staticmethod def _build_edges_between_ent_pg( - node_to_node: Dict[Tuple[str, str], float], - triple_list_data: Dict[str, List[List[str]]], + node_to_node: dict[tuple[str, str], float], + triple_list_data: dict[str, list[list[str]]], ): """构建实体节点与文段节点之间的关系""" for idx in triple_list_data: @@ -136,8 +134,8 @@ class KGManager: @staticmethod def _synonym_connect( - node_to_node: Dict[Tuple[str, str], float], - triple_list_data: Dict[str, List[List[str]]], + node_to_node: dict[tuple[str, str], float], + triple_list_data: dict[str, list[list[str]]], embedding_manager: EmbeddingManager, ) -> int: """同义词连接""" @@ -208,7 +206,7 @@ class KGManager: def _update_graph( self, - node_to_node: Dict[Tuple[str, str], float], + node_to_node: dict[tuple[str, str], float], embedding_manager: EmbeddingManager, ): """更新KG图结构 @@ -280,7 +278,7 @@ class KGManager: def build_kg( self, - triple_list_data: Dict[str, List[List[str]]], + triple_list_data: dict[str, list[list[str]]], embedding_manager: EmbeddingManager, ): """增量式构建KG @@ -317,8 +315,8 @@ class KGManager: def kg_search( self, - relation_search_result: List[Tuple[Tuple[str, str, str], float]], - paragraph_search_result: List[Tuple[str, float]], + relation_search_result: list[tuple[tuple[str, str, str], float]], + paragraph_search_result: list[tuple[str, float]], embed_manager: EmbeddingManager, ): """RAG搜索与PageRank diff --git a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py index ccc3cd090..a1f49f314 100644 --- a/src/chat/knowledge/knowledge_lib.py +++ b/src/chat/knowledge/knowledge_lib.py @@ -1,10 +1,11 @@ -from src.chat.knowledge.embedding_store import EmbeddingManager -from src.chat.knowledge.qa_manager import QAManager -from src.chat.knowledge.kg_manager import KGManager -from src.chat.knowledge.global_logger import logger -from src.config.config import global_config import os +from src.chat.knowledge.embedding_store import EmbeddingManager +from src.chat.knowledge.global_logger import logger +from src.chat.knowledge.kg_manager import KGManager +from src.chat.knowledge.qa_manager import QAManager +from src.config.config import global_config + INVALID_ENTITY = [ "", "你", diff --git a/src/chat/knowledge/open_ie.py b/src/chat/knowledge/open_ie.py index 23b3032d5..aa01c6c2f 100644 --- a/src/chat/knowledge/open_ie.py +++ b/src/chat/knowledge/open_ie.py @@ -1,14 +1,15 @@ -import orjson -import os import glob -from typing import Any, Dict, List +import os +from typing import Any +import orjson + +from .knowledge_lib import DATA_PATH, INVALID_ENTITY, ROOT_PATH -from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH # from src.manager.local_store_manager import local_storage -def _filter_invalid_entities(entities: List[str]) -> List[str]: +def _filter_invalid_entities(entities: list[str]) -> list[str]: """过滤无效的实体""" valid_entities = set() for entity in entities: @@ -20,7 +21,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]: return list(valid_entities) -def _filter_invalid_triples(triples: List[List[str]]) -> List[List[str]]: +def _filter_invalid_triples(triples: list[list[str]]) -> list[list[str]]: """过滤无效的三元组""" unique_triples = set() valid_triples = [] @@ -62,7 +63,7 @@ class OpenIE: def __init__( self, - docs: List[Dict[str, Any]], + docs: list[dict[str, Any]], avg_ent_chars, avg_ent_words, ): @@ -112,7 +113,7 @@ class OpenIE: json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json"))) data_list = [] for file in json_files: - with open(file, "r", encoding="utf-8") as f: + with open(file, encoding="utf-8") as f: data = orjson.loads(f.read()) data_list.append(data) if not data_list: diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index c340fc30e..b08fb24e0 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -1,15 +1,16 @@ import time -from typing import Tuple, List, Dict, Optional, Any +from typing import Any + +from src.chat.utils.utils import get_embedding +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest -from .global_logger import logger from .embedding_store import EmbeddingManager +from .global_logger import logger from .kg_manager import KGManager # from .lpmmconfig import global_config from .utils.dyn_topk import dyn_select_top_k -from src.llm_models.utils_model import LLMRequest -from src.chat.utils.utils import get_embedding -from src.config.config import global_config, model_config MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度 @@ -26,7 +27,7 @@ class QAManager: async def process_query( self, question: str - ) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]: + ) -> tuple[list[tuple[str, float, float]], dict[str, float] | None] | None: """处理查询""" # 生成问题的Embedding @@ -98,7 +99,7 @@ class QAManager: return result, ppr_node_weights - async def get_knowledge(self, question: str) -> Optional[Dict[str, Any]]: + async def get_knowledge(self, question: str) -> dict[str, Any] | None: """ 获取知识,返回结构化字典 diff --git a/src/chat/knowledge/utils/dyn_topk.py b/src/chat/knowledge/utils/dyn_topk.py index df9e470dc..106a68da4 100644 --- a/src/chat/knowledge/utils/dyn_topk.py +++ b/src/chat/knowledge/utils/dyn_topk.py @@ -1,9 +1,9 @@ -from typing import List, Any, Tuple +from typing import Any def dyn_select_top_k( - score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float -) -> List[Tuple[Any, float, float]]: + score: list[tuple[Any, float]], jmp_factor: float, var_factor: float +) -> list[tuple[Any, float, float]]: """动态TopK选择""" # 检查输入列表是否为空 if not score: diff --git a/src/chat/memory_system/__init__.py b/src/chat/memory_system/__init__.py index a1c176a10..d3c5feea4 100644 --- a/src/chat/memory_system/__init__.py +++ b/src/chat/memory_system/__init__.py @@ -1,37 +1,35 @@ -# -*- coding: utf-8 -*- """ 简化记忆系统模块 移除即时记忆和长期记忆分类,实现统一记忆架构和智能遗忘机制 """ # 核心数据结构 +# 激活器 +from .enhanced_memory_activator import MemoryActivator, enhanced_memory_activator, memory_activator from .memory_chunk import ( + ConfidenceLevel, + ContentStructure, + ImportanceLevel, MemoryChunk, MemoryMetadata, - ContentStructure, MemoryType, - ImportanceLevel, - ConfidenceLevel, create_memory_chunk, ) +# 兼容性别名 +from .memory_chunk import MemoryChunk as Memory + # 遗忘引擎 -from .memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig, get_memory_forgetting_engine - -# Vector DB存储系统 -from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage - -# 记忆核心系统 -from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system +from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine # 记忆管理器 from .memory_manager import MemoryManager, MemoryResult, memory_manager -# 激活器 -from .enhanced_memory_activator import MemoryActivator, memory_activator, enhanced_memory_activator +# 记忆核心系统 +from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system -# 兼容性别名 -from .memory_chunk import MemoryChunk as Memory +# Vector DB存储系统 +from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage __all__ = [ # 核心数据结构 diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py index aae09c08b..cf93ceaf0 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py @@ -1,17 +1,17 @@ -# -*- coding: utf-8 -*- """ 增强记忆系统适配器 将增强记忆系统集成到现有MoFox Bot架构中 """ import time -from typing import Dict, List, Optional, Any from dataclasses import dataclass +from typing import Any -from src.common.logger import get_logger -from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm + +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) @@ -47,10 +47,10 @@ class AdapterConfig: class EnhancedMemoryAdapter: """增强记忆系统适配器""" - def __init__(self, llm_model: LLMRequest, config: Optional[AdapterConfig] = None): + def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None): self.llm_model = llm_model self.config = config or AdapterConfig() - self.integration_layer: Optional[MemoryIntegrationLayer] = None + self.integration_layer: MemoryIntegrationLayer | None = None self._initialized = False # 统计信息 @@ -96,7 +96,7 @@ class EnhancedMemoryAdapter: # 如果初始化失败,禁用增强记忆功能 self.config.enable_enhanced_memory = False - async def process_conversation_memory(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]: """处理对话记忆,以上下文为唯一输入""" if not self._initialized or not self.config.enable_enhanced_memory: return {"success": False, "error": "Enhanced memory not available"} @@ -105,7 +105,7 @@ class EnhancedMemoryAdapter: self.adapter_stats["total_processed"] += 1 try: - payload_context: Dict[str, Any] = dict(context or {}) + payload_context: dict[str, Any] = dict(context or {}) conversation_text = payload_context.get("conversation_text") if not conversation_text: @@ -146,8 +146,8 @@ class EnhancedMemoryAdapter: return {"success": False, "error": str(e)} async def retrieve_relevant_memories( - self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: Optional[int] = None - ) -> List[MemoryChunk]: + self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None + ) -> list[MemoryChunk]: """检索相关记忆""" if not self._initialized or not self.config.enable_enhanced_memory: return [] @@ -166,7 +166,7 @@ class EnhancedMemoryAdapter: return [] async def get_memory_context_for_prompt( - self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5 + self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5 ) -> str: """获取用于提示词的记忆上下文""" memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories) @@ -186,7 +186,7 @@ class EnhancedMemoryAdapter: return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config) - async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]: + async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]: """获取增强记忆系统摘要""" if not self._initialized or not self.config.enable_enhanced_memory: return {"available": False, "reason": "Not initialized or disabled"} @@ -222,7 +222,7 @@ class EnhancedMemoryAdapter: new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed self.adapter_stats["average_processing_time"] = new_avg - def get_adapter_stats(self) -> Dict[str, Any]: + def get_adapter_stats(self) -> dict[str, Any]: """获取适配器统计信息""" return self.adapter_stats.copy() @@ -253,7 +253,7 @@ class EnhancedMemoryAdapter: # 全局适配器实例 -_enhanced_memory_adapter: Optional[EnhancedMemoryAdapter] = None +_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter: @@ -292,8 +292,8 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest): async def process_conversation_with_enhanced_memory( - context: Dict[str, Any], llm_model: Optional[LLMRequest] = None -) -> Dict[str, Any]: + context: dict[str, Any], llm_model: LLMRequest | None = None +) -> dict[str, Any]: """使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息""" if not llm_model: # 获取默认的LLM模型 @@ -323,10 +323,10 @@ async def process_conversation_with_enhanced_memory( async def retrieve_memories_with_enhanced_system( query: str, user_id: str, - context: Optional[Dict[str, Any]] = None, + context: dict[str, Any] | None = None, limit: int = 10, - llm_model: Optional[LLMRequest] = None, -) -> List[MemoryChunk]: + llm_model: LLMRequest | None = None, +) -> list[MemoryChunk]: """使用增强记忆系统检索记忆""" if not llm_model: # 获取默认的LLM模型 @@ -345,9 +345,9 @@ async def retrieve_memories_with_enhanced_system( async def get_memory_context_for_prompt( query: str, user_id: str, - context: Optional[Dict[str, Any]] = None, + context: dict[str, Any] | None = None, max_memories: int = 5, - llm_model: Optional[LLMRequest] = None, + llm_model: LLMRequest | None = None, ) -> str: """获取用于提示词的记忆上下文""" if not llm_model: diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py index a1b374510..2794332cf 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py @@ -1,15 +1,15 @@ -# -*- coding: utf-8 -*- """ 增强记忆系统钩子 用于在消息处理过程中自动构建和检索记忆 """ -from typing import Dict, List, Any, Optional from datetime import datetime +from typing import Any + +from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager from src.common.logger import get_logger from src.config.config import global_config -from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager logger = get_logger(__name__) @@ -27,7 +27,7 @@ class EnhancedMemoryHooks: user_id: str, chat_id: str, message_id: str, - context: Optional[Dict[str, Any]] = None, + context: dict[str, Any] | None = None, ) -> bool: """ 处理消息并构建记忆 @@ -106,8 +106,8 @@ class EnhancedMemoryHooks: user_id: str, chat_id: str, limit: int = 5, - extra_context: Optional[Dict[str, Any]] = None, - ) -> List[Dict[str, Any]]: + extra_context: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: """ 为回复获取相关记忆 diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py index 913c2aed0..8583f7dd2 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py @@ -1,19 +1,19 @@ -# -*- coding: utf-8 -*- """ 增强记忆系统集成脚本 用于在现有系统中无缝集成增强记忆功能 """ -from typing import Dict, Any, Optional +from typing import Any + +from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks from src.common.logger import get_logger -from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks logger = get_logger(__name__) async def process_user_message_memory( - message_content: str, user_id: str, chat_id: str, message_id: str, context: Optional[Dict[str, Any]] = None + message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None ) -> bool: """ 处理用户消息并构建记忆 @@ -44,8 +44,8 @@ async def process_user_message_memory( async def get_relevant_memories_for_response( - query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: + query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None +) -> dict[str, Any]: """ 为回复获取相关记忆 @@ -74,7 +74,7 @@ async def get_relevant_memories_for_response( return {"has_memories": False, "memories": [], "memory_count": 0} -def format_memories_for_prompt(memories: Dict[str, Any]) -> str: +def format_memories_for_prompt(memories: dict[str, Any]) -> str: """ 格式化记忆信息用于Prompt @@ -114,7 +114,7 @@ async def cleanup_memory_system(): logger.error(f"记忆系统清理失败: {e}") -def get_memory_system_status() -> Dict[str, Any]: +def get_memory_system_status() -> dict[str, Any]: """ 获取记忆系统状态 @@ -133,7 +133,7 @@ def get_memory_system_status() -> Dict[str, Any]: # 便捷函数 async def remember_message( - message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: Optional[Dict[str, Any]] = None + message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None ) -> bool: """ 便捷的记忆构建函数 @@ -159,8 +159,8 @@ async def recall_memories( user_id: str = "default_user", chat_id: str = "default_chat", limit: int = 5, - context: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: + context: dict[str, Any] | None = None, +) -> dict[str, Any]: """ 便捷的记忆检索函数 diff --git a/src/chat/memory_system/deprecated_backup/enhanced_reranker.py b/src/chat/memory_system/deprecated_backup/enhanced_reranker.py index e5b368460..c35b9de53 100644 --- a/src/chat/memory_system/deprecated_backup/enhanced_reranker.py +++ b/src/chat/memory_system/deprecated_backup/enhanced_reranker.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 增强重排序器 实现文档设计的多维度评分模型 @@ -6,12 +5,12 @@ import math import time -from typing import Dict, List, Optional, Any, Tuple from dataclasses import dataclass from enum import Enum +from typing import Any -from src.common.logger import get_logger from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +from src.common.logger import get_logger logger = get_logger(__name__) @@ -44,7 +43,7 @@ class ReRankingConfig: freq_max_score: float = 5.0 # 最大频率得分 # 类型匹配权重映射 - type_match_weights: Dict[str, Dict[str, float]] = None + type_match_weights: dict[str, dict[str, float]] = None def __post_init__(self): """初始化类型匹配权重""" @@ -157,7 +156,7 @@ class IntentClassifier: ], } - def classify_intent(self, query: str, context: Dict[str, Any]) -> IntentType: + def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType: """识别对话意图""" if not query: return IntentType.UNKNOWN @@ -165,7 +164,7 @@ class IntentClassifier: query_lower = query.lower() # 统计各意图的匹配分数 - intent_scores = {intent: 0 for intent in IntentType} + intent_scores = dict.fromkeys(IntentType, 0) for intent, patterns in self.patterns.items(): for pattern in patterns: @@ -187,7 +186,7 @@ class IntentClassifier: class EnhancedReRanker: """增强重排序器 - 实现文档设计的多维度评分模型""" - def __init__(self, config: Optional[ReRankingConfig] = None): + def __init__(self, config: ReRankingConfig | None = None): self.config = config or ReRankingConfig() self.intent_classifier = IntentClassifier() @@ -210,10 +209,10 @@ class EnhancedReRanker: def rerank_memories( self, query: str, - candidate_memories: List[Tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity) - context: Dict[str, Any], + candidate_memories: list[tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity) + context: dict[str, Any], limit: int = 10, - ) -> List[Tuple[str, MemoryChunk, float]]: + ) -> list[tuple[str, MemoryChunk, float]]: """ 对候选记忆进行重排序 @@ -341,11 +340,11 @@ default_reranker = EnhancedReRanker() def rerank_candidate_memories( query: str, - candidate_memories: List[Tuple[str, MemoryChunk, float]], - context: Dict[str, Any], + candidate_memories: list[tuple[str, MemoryChunk, float]], + context: dict[str, Any], limit: int = 10, - config: Optional[ReRankingConfig] = None, -) -> List[Tuple[str, MemoryChunk, float]]: + config: ReRankingConfig | None = None, +) -> list[tuple[str, MemoryChunk, float]]: """ 便捷函数:对候选记忆进行重排序 """ diff --git a/src/chat/memory_system/deprecated_backup/integration_layer.py b/src/chat/memory_system/deprecated_backup/integration_layer.py index 5b9282a84..c7a27b8cb 100644 --- a/src/chat/memory_system/deprecated_backup/integration_layer.py +++ b/src/chat/memory_system/deprecated_backup/integration_layer.py @@ -1,18 +1,18 @@ -# -*- coding: utf-8 -*- """ 增强记忆系统集成层 现在只管理新的增强记忆系统,旧系统已被完全移除 """ -import time import asyncio -from typing import Dict, List, Optional, Any +import time from dataclasses import dataclass from enum import Enum +from typing import Any -from src.common.logger import get_logger from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem + from src.chat.memory_system.memory_chunk import MemoryChunk +from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) @@ -40,12 +40,12 @@ class IntegrationConfig: class MemoryIntegrationLayer: """记忆系统集成层 - 现在只管理增强记忆系统""" - def __init__(self, llm_model: LLMRequest, config: Optional[IntegrationConfig] = None): + def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None): self.llm_model = llm_model self.config = config or IntegrationConfig() # 只初始化增强记忆系统 - self.enhanced_memory: Optional[EnhancedMemorySystem] = None + self.enhanced_memory: EnhancedMemorySystem | None = None # 集成统计 self.integration_stats = { @@ -113,7 +113,7 @@ class MemoryIntegrationLayer: logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True) raise - async def process_conversation(self, context: Dict[str, Any]) -> Dict[str, Any]: + async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]: """处理对话记忆,仅使用上下文信息""" if not self._initialized or not self.enhanced_memory: return {"success": False, "error": "Memory system not available"} @@ -150,10 +150,10 @@ class MemoryIntegrationLayer: async def retrieve_relevant_memories( self, query: str, - user_id: Optional[str] = None, - context: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None, - ) -> List[MemoryChunk]: + user_id: str | None = None, + context: dict[str, Any] | None = None, + limit: int | None = None, + ) -> list[MemoryChunk]: """检索相关记忆""" if not self._initialized or not self.enhanced_memory: return [] @@ -172,7 +172,7 @@ class MemoryIntegrationLayer: logger.error(f"检索相关记忆失败: {e}", exc_info=True) return [] - async def get_system_status(self) -> Dict[str, Any]: + async def get_system_status(self) -> dict[str, Any]: """获取系统状态""" if not self._initialized: return {"status": "not_initialized"} @@ -193,7 +193,7 @@ class MemoryIntegrationLayer: logger.error(f"获取系统状态失败: {e}", exc_info=True) return {"status": "error", "error": str(e)} - def get_integration_stats(self) -> Dict[str, Any]: + def get_integration_stats(self) -> dict[str, Any]: """获取集成统计信息""" return self.integration_stats.copy() diff --git a/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py b/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py index 4659389cb..a37e4c548 100644 --- a/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py +++ b/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py @@ -1,20 +1,20 @@ -# -*- coding: utf-8 -*- """ 记忆系统集成钩子 提供与现有MoFox Bot系统的无缝集成点 """ import time -from typing import Dict, Optional, Any from dataclasses import dataclass +from typing import Any -from src.common.logger import get_logger from src.chat.memory_system.enhanced_memory_adapter import ( + get_memory_context_for_prompt, process_conversation_with_enhanced_memory, retrieve_memories_with_enhanced_system, - get_memory_context_for_prompt, ) +from src.common.logger import get_logger + logger = get_logger(__name__) @@ -24,7 +24,7 @@ class HookResult: success: bool data: Any = None - error: Optional[str] = None + error: str | None = None processing_time: float = 0.0 @@ -125,8 +125,8 @@ class MemoryIntegrationHooks: # 尝试注册到事件系统 try: - from src.plugin_system.core.event_manager import event_manager from src.plugin_system.base.component_types import EventType + from src.plugin_system.core.event_manager import event_manager # 注册消息后处理事件 event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler) @@ -238,11 +238,11 @@ class MemoryIntegrationHooks: # 钩子处理器方法 - async def _on_message_processed_handler(self, event_data: Dict[str, Any]) -> HookResult: + async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult: """事件系统的消息处理处理器""" return await self._on_message_processed_hook(event_data) - async def _on_message_processed_hook(self, message_data: Dict[str, Any]) -> HookResult: + async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult: """消息后处理钩子""" start_time = time.time() @@ -289,7 +289,7 @@ class MemoryIntegrationHooks: logger.error(f"消息处理钩子执行异常: {e}", exc_info=True) return HookResult(success=False, error=str(e), processing_time=processing_time) - async def _on_chat_stream_save_hook(self, chat_stream_data: Dict[str, Any]) -> HookResult: + async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult: """聊天流保存钩子""" start_time = time.time() @@ -345,7 +345,7 @@ class MemoryIntegrationHooks: logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True) return HookResult(success=False, error=str(e), processing_time=processing_time) - async def _on_pre_response_hook(self, response_data: Dict[str, Any]) -> HookResult: + async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult: """回复前钩子""" start_time = time.time() @@ -380,7 +380,7 @@ class MemoryIntegrationHooks: logger.error(f"回复前钩子执行异常: {e}", exc_info=True) return HookResult(success=False, error=str(e), processing_time=processing_time) - async def _on_knowledge_query_hook(self, query_data: Dict[str, Any]) -> HookResult: + async def _on_knowledge_query_hook(self, query_data: dict[str, Any]) -> HookResult: """知识库查询钩子""" start_time = time.time() @@ -411,7 +411,7 @@ class MemoryIntegrationHooks: logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True) return HookResult(success=False, error=str(e), processing_time=processing_time) - async def _on_prompt_building_hook(self, prompt_data: Dict[str, Any]) -> HookResult: + async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult: """提示词构建钩子""" start_time = time.time() @@ -459,7 +459,7 @@ class MemoryIntegrationHooks: new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions self.hook_stats["average_hook_time"] = new_avg - def get_hook_stats(self) -> Dict[str, Any]: + def get_hook_stats(self) -> dict[str, Any]: """获取钩子统计信息""" return self.hook_stats.copy() @@ -501,7 +501,7 @@ class MemoryMaintenanceTask: # 全局钩子实例 -_memory_hooks: Optional[MemoryIntegrationHooks] = None +_memory_hooks: MemoryIntegrationHooks | None = None async def get_memory_integration_hooks() -> MemoryIntegrationHooks: diff --git a/src/chat/memory_system/deprecated_backup/metadata_index.py b/src/chat/memory_system/deprecated_backup/metadata_index.py index f7ab8ecda..8c89e5c34 100644 --- a/src/chat/memory_system/deprecated_backup/metadata_index.py +++ b/src/chat/memory_system/deprecated_backup/metadata_index.py @@ -1,20 +1,20 @@ -# -*- coding: utf-8 -*- """ 元数据索引系统 为记忆系统提供多维度的精准过滤和查询能力 """ +import threading import time -import orjson -from typing import Dict, List, Optional, Tuple, Set, Any, Union +from collections import defaultdict from dataclasses import dataclass from enum import Enum -import threading -from collections import defaultdict from pathlib import Path +from typing import Any +import orjson + +from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk, MemoryType from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel logger = get_logger(__name__) @@ -40,21 +40,21 @@ class IndexType(Enum): class IndexQuery: """索引查询条件""" - user_ids: Optional[List[str]] = None - memory_types: Optional[List[MemoryType]] = None - subjects: Optional[List[str]] = None - keywords: Optional[List[str]] = None - tags: Optional[List[str]] = None - categories: Optional[List[str]] = None - time_range: Optional[Tuple[float, float]] = None - confidence_levels: Optional[List[ConfidenceLevel]] = None - importance_levels: Optional[List[ImportanceLevel]] = None - min_relationship_score: Optional[float] = None - max_relationship_score: Optional[float] = None - min_access_count: Optional[int] = None - semantic_hashes: Optional[List[str]] = None - limit: Optional[int] = None - sort_by: Optional[str] = None # "created_at", "access_count", "relevance_score" + user_ids: list[str] | None = None + memory_types: list[MemoryType] | None = None + subjects: list[str] | None = None + keywords: list[str] | None = None + tags: list[str] | None = None + categories: list[str] | None = None + time_range: tuple[float, float] | None = None + confidence_levels: list[ConfidenceLevel] | None = None + importance_levels: list[ImportanceLevel] | None = None + min_relationship_score: float | None = None + max_relationship_score: float | None = None + min_access_count: int | None = None + semantic_hashes: list[str] | None = None + limit: int | None = None + sort_by: str | None = None # "created_at", "access_count", "relevance_score" sort_order: str = "desc" # "asc", "desc" @@ -62,10 +62,10 @@ class IndexQuery: class IndexResult: """索引结果""" - memory_ids: List[str] + memory_ids: list[str] total_count: int query_time: float - filtered_by: List[str] + filtered_by: list[str] class MetadataIndexManager: @@ -94,7 +94,7 @@ class MetadataIndexManager: self.access_frequency_index = [] # [(access_count, memory_id), ...] # 内存缓存 - self.memory_metadata_cache: Dict[str, Dict[str, Any]] = {} + self.memory_metadata_cache: dict[str, dict[str, Any]] = {} # 统计信息 self.index_stats = { @@ -140,7 +140,7 @@ class MetadataIndexManager: return key @staticmethod - def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]: + def _serialize_metadata_entry(metadata: dict[str, Any]) -> dict[str, Any]: serialized = {} for field_name, value in metadata.items(): if isinstance(value, Enum): @@ -149,7 +149,7 @@ class MetadataIndexManager: serialized[field_name] = value return serialized - async def index_memories(self, memories: List[MemoryChunk]): + async def index_memories(self, memories: list[MemoryChunk]): """为记忆建立索引""" if not memories: return @@ -375,7 +375,7 @@ class MetadataIndexManager: logger.error(f"❌ 元数据查询失败: {e}", exc_info=True) return IndexResult(memory_ids=[], total_count=0, query_time=0.0, filtered_by=[]) - def _get_candidate_memories(self, query: IndexQuery) -> Set[str]: + def _get_candidate_memories(self, query: IndexQuery) -> set[str]: """获取候选记忆ID集合""" candidate_ids = set() @@ -444,7 +444,7 @@ class MetadataIndexManager: return candidate_ids - def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]: + def _collect_index_matches(self, index_type: IndexType, token: str | Enum | None) -> set[str]: """根据给定token收集索引匹配,支持部分匹配""" mapping = self.indices.get(index_type) if mapping is None: @@ -461,7 +461,7 @@ class MetadataIndexManager: if not key: return set() - matches: Set[str] = set(mapping.get(key, set())) + matches: set[str] = set(mapping.get(key, set())) if matches: return set(matches) @@ -477,7 +477,7 @@ class MetadataIndexManager: return matches - def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]: + def _apply_filters(self, candidate_ids: set[str], query: IndexQuery) -> list[str]: """应用过滤条件""" filtered_ids = list(candidate_ids) @@ -545,7 +545,7 @@ class MetadataIndexManager: created_at = self.memory_metadata_cache[memory_id]["created_at"] return start_time <= created_at <= end_time - def _sort_memories(self, memory_ids: List[str], sort_by: str, sort_order: str) -> List[str]: + def _sort_memories(self, memory_ids: list[str], sort_by: str, sort_order: str) -> list[str]: """对记忆进行排序""" if sort_by == "created_at": # 使用时间索引(已经有序) @@ -582,7 +582,7 @@ class MetadataIndexManager: return memory_ids - def _get_applied_filters(self, query: IndexQuery) -> List[str]: + def _get_applied_filters(self, query: IndexQuery) -> list[str]: """获取应用的过滤器列表""" filters = [] if query.memory_types: @@ -686,11 +686,11 @@ class MetadataIndexManager: except Exception as e: logger.error(f"❌ 移除记忆索引失败: {e}") - async def get_memory_metadata(self, memory_id: str) -> Optional[Dict[str, Any]]: + async def get_memory_metadata(self, memory_id: str) -> dict[str, Any] | None: """获取记忆元数据""" return self.memory_metadata_cache.get(memory_id) - async def get_user_memory_ids(self, user_id: str, limit: Optional[int] = None) -> List[str]: + async def get_user_memory_ids(self, user_id: str, limit: int | None = None) -> list[str]: """获取用户的所有记忆ID""" user_memory_ids = list(self.indices[IndexType.USER_ID].get(user_id, set())) @@ -699,7 +699,7 @@ class MetadataIndexManager: return user_memory_ids - async def get_memory_statistics(self, user_id: Optional[str] = None) -> Dict[str, Any]: + async def get_memory_statistics(self, user_id: str | None = None) -> dict[str, Any]: """获取记忆统计信息""" stats = { "total_memories": self.index_stats["total_memories"], @@ -784,7 +784,7 @@ class MetadataIndexManager: logger.info("正在保存元数据索引...") # 保存各类索引 - indices_data: Dict[str, Dict[str, List[str]]] = {} + indices_data: dict[str, dict[str, list[str]]] = {} for index_type, index_data in self.indices.items(): serialized_index = {} for key, values in index_data.items(): @@ -839,7 +839,7 @@ class MetadataIndexManager: # 加载各类索引 indices_file = self.index_path / "indices.json" if indices_file.exists(): - with open(indices_file, "r", encoding="utf-8") as f: + with open(indices_file, encoding="utf-8") as f: indices_data = orjson.loads(f.read()) for index_type_value, index_data in indices_data.items(): @@ -853,25 +853,25 @@ class MetadataIndexManager: # 加载时间索引 time_index_file = self.index_path / "time_index.json" if time_index_file.exists(): - with open(time_index_file, "r", encoding="utf-8") as f: + with open(time_index_file, encoding="utf-8") as f: self.time_index = orjson.loads(f.read()) # 加载关系分索引 relationship_index_file = self.index_path / "relationship_index.json" if relationship_index_file.exists(): - with open(relationship_index_file, "r", encoding="utf-8") as f: + with open(relationship_index_file, encoding="utf-8") as f: self.relationship_index = orjson.loads(f.read()) # 加载访问频率索引 access_frequency_index_file = self.index_path / "access_frequency_index.json" if access_frequency_index_file.exists(): - with open(access_frequency_index_file, "r", encoding="utf-8") as f: + with open(access_frequency_index_file, encoding="utf-8") as f: self.access_frequency_index = orjson.loads(f.read()) # 加载元数据缓存 metadata_cache_file = self.index_path / "metadata_cache.json" if metadata_cache_file.exists(): - with open(metadata_cache_file, "r", encoding="utf-8") as f: + with open(metadata_cache_file, encoding="utf-8") as f: cache_data = orjson.loads(f.read()) # 转换置信度和重要性为枚举类型 @@ -914,7 +914,7 @@ class MetadataIndexManager: # 加载统计信息 stats_file = self.index_path / "index_stats.json" if stats_file.exists(): - with open(stats_file, "r", encoding="utf-8") as f: + with open(stats_file, encoding="utf-8") as f: self.index_stats = orjson.loads(f.read()) # 更新记忆计数 @@ -1004,7 +1004,7 @@ class MetadataIndexManager: if len(self.indices[IndexType.CATEGORY][category]) < min_frequency: del self.indices[IndexType.CATEGORY][category] - def get_index_stats(self) -> Dict[str, Any]: + def get_index_stats(self) -> dict[str, Any]: """获取索引统计信息""" stats = self.index_stats.copy() if stats["total_queries"] > 0: diff --git a/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py b/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py index bc0a1a0f4..f13792603 100644 --- a/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py +++ b/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py @@ -1,19 +1,19 @@ -# -*- coding: utf-8 -*- """ 多阶段召回机制 实现粗粒度到细粒度的记忆检索优化 """ import time -from typing import Dict, List, Optional, Set, Any from dataclasses import dataclass, field from enum import Enum -import orjson +from typing import Any -from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +import orjson from src.chat.memory_system.enhanced_reranker import EnhancedReRanker, ReRankingConfig +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +from src.common.logger import get_logger + logger = get_logger(__name__) @@ -73,11 +73,11 @@ class StageResult: """阶段结果""" stage: RetrievalStage - memory_ids: List[str] + memory_ids: list[str] processing_time: float filtered_count: int score_threshold: float - details: List[Dict[str, Any]] = field(default_factory=list) + details: list[dict[str, Any]] = field(default_factory=list) @dataclass @@ -86,17 +86,17 @@ class RetrievalResult: query: str user_id: str - final_memories: List[MemoryChunk] - stage_results: List[StageResult] + final_memories: list[MemoryChunk] + stage_results: list[StageResult] total_processing_time: float total_filtered: int - retrieval_stats: Dict[str, Any] + retrieval_stats: dict[str, Any] class MultiStageRetrieval: """多阶段召回系统""" - def __init__(self, config: Optional[RetrievalConfig] = None): + def __init__(self, config: RetrievalConfig | None = None): self.config = config or RetrievalConfig.from_global_config() # 初始化增强重排序器 @@ -124,11 +124,11 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], + context: dict[str, Any], metadata_index, vector_storage, - all_memories_cache: Dict[str, MemoryChunk], - limit: Optional[int] = None, + all_memories_cache: dict[str, MemoryChunk], + limit: int | None = None, ) -> RetrievalResult: """多阶段记忆检索""" start_time = time.time() @@ -136,7 +136,7 @@ class MultiStageRetrieval: stage_results = [] current_memory_ids = set() - memory_debug_info: Dict[str, Dict[str, Any]] = {} + memory_debug_info: dict[str, dict[str, Any]] = {} try: logger.debug(f"开始多阶段检索:query='{query}', user_id='{user_id}'") @@ -311,11 +311,11 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], + context: dict[str, Any], metadata_index, - all_memories_cache: Dict[str, MemoryChunk], + all_memories_cache: dict[str, MemoryChunk], *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None, + debug_log: dict[str, dict[str, Any]] | None = None, ) -> StageResult: """阶段1:元数据过滤""" start_time = time.time() @@ -345,7 +345,7 @@ class MultiStageRetrieval: result = await metadata_index.query_memories(index_query) result_ids = list(result.memory_ids) filtered_count = max(0, len(all_memories_cache) - len(result_ids)) - details: List[Dict[str, Any]] = [] + details: list[dict[str, Any]] = [] # 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆 if not result_ids: @@ -440,12 +440,12 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], + context: dict[str, Any], vector_storage, - candidate_ids: Set[str], - all_memories_cache: Dict[str, MemoryChunk], + candidate_ids: set[str], + all_memories_cache: dict[str, MemoryChunk], *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None, + debug_log: dict[str, dict[str, Any]] | None = None, ) -> StageResult: """阶段2:向量搜索""" start_time = time.time() @@ -479,8 +479,8 @@ class MultiStageRetrieval: # 过滤候选记忆 filtered_memories = [] - details: List[Dict[str, Any]] = [] - raw_details: List[Dict[str, Any]] = [] + details: list[dict[str, Any]] = [] + raw_details: list[dict[str, Any]] = [] threshold = self.config.vector_similarity_threshold for memory_id, similarity in search_result: @@ -561,7 +561,7 @@ class MultiStageRetrieval: ) def _create_text_search_fallback( - self, candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], query_text: str, start_time: float + self, candidate_ids: set[str], all_memories_cache: dict[str, MemoryChunk], query_text: str, start_time: float ) -> StageResult: """当向量搜索失败时,使用文本搜索作为回退策略""" try: @@ -618,18 +618,18 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], - candidate_ids: Set[str], - all_memories_cache: Dict[str, MemoryChunk], + context: dict[str, Any], + candidate_ids: set[str], + all_memories_cache: dict[str, MemoryChunk], *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None, + debug_log: dict[str, dict[str, Any]] | None = None, ) -> StageResult: """阶段3:语义重排序""" start_time = time.time() try: reranked_memories = [] - details: List[Dict[str, Any]] = [] + details: list[dict[str, Any]] = [] threshold = self.config.semantic_similarity_threshold for memory_id in candidate_ids: @@ -704,19 +704,19 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], - candidate_ids: List[str], - all_memories_cache: Dict[str, MemoryChunk], + context: dict[str, Any], + candidate_ids: list[str], + all_memories_cache: dict[str, MemoryChunk], limit: int, *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None, + debug_log: dict[str, dict[str, Any]] | None = None, ) -> StageResult: """阶段4:上下文过滤""" start_time = time.time() try: final_memories = [] - details: List[Dict[str, Any]] = [] + details: list[dict[str, Any]] = [] for memory_id in candidate_ids: if memory_id not in all_memories_cache: @@ -793,12 +793,12 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], - all_memories_cache: Dict[str, MemoryChunk], + context: dict[str, Any], + all_memories_cache: dict[str, MemoryChunk], limit: int, *, - excluded_ids: Optional[Set[str]] = None, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None, + excluded_ids: set[str] | None = None, + debug_log: dict[str, dict[str, Any]] | None = None, ) -> StageResult: """回退检索阶段 - 当主检索失败时使用更宽松的策略""" start_time = time.time() @@ -881,8 +881,8 @@ class MultiStageRetrieval: ) async def _generate_query_embedding( - self, query: str, context: Dict[str, Any], vector_storage - ) -> Optional[List[float]]: + self, query: str, context: dict[str, Any], vector_storage + ) -> list[float] | None: """生成查询向量""" try: query_plan = context.get("query_plan") @@ -916,7 +916,7 @@ class MultiStageRetrieval: logger.error(f"生成查询向量时发生异常: {e}", exc_info=True) return None - async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float: + async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float: """计算语义相似度 - 简化优化版本,提升召回率""" try: query_plan = context.get("query_plan") @@ -947,9 +947,10 @@ class MultiStageRetrieval: # 核心匹配策略2:词汇匹配 word_score = 0.0 try: - import jieba import re + import jieba + # 分词处理 query_words = list(jieba.cut(query_text)) + re.findall(r"[a-zA-Z]+", query_text) memory_words = list(jieba.cut(memory_text)) + re.findall(r"[a-zA-Z]+", memory_text) @@ -1059,7 +1060,7 @@ class MultiStageRetrieval: logger.warning(f"计算语义相似度失败: {e}") return 0.0 - async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float: + async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float: """计算上下文相关度""" try: score = 0.0 @@ -1132,7 +1133,7 @@ class MultiStageRetrieval: return 0.0 async def _calculate_final_score( - self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float + self, query: str, memory: MemoryChunk, context: dict[str, Any], context_score: float ) -> float: """计算最终评分""" try: @@ -1184,7 +1185,7 @@ class MultiStageRetrieval: logger.warning(f"计算最终评分失败: {e}") return 0.0 - def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float: + def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: list[str] | None) -> float: if not required_subjects: return 0.0 @@ -1229,7 +1230,7 @@ class MultiStageRetrieval: except Exception: return 0.5 - def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]: + def _extract_memory_types_from_context(self, context: dict[str, Any]) -> list[MemoryType]: """从上下文中提取记忆类型""" try: query_plan = context.get("query_plan") @@ -1256,10 +1257,10 @@ class MultiStageRetrieval: except Exception: return [] - def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]: + def _extract_keywords_from_query(self, query: str, query_plan: Any | None = None) -> list[str]: """从查询中提取关键词""" try: - extracted: List[str] = [] + extracted: list[str] = [] if query_plan and getattr(query_plan, "required_keywords", None): extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)]) @@ -1283,7 +1284,7 @@ class MultiStageRetrieval: except Exception: return [] - def _update_retrieval_stats(self, total_time: float, stage_results: List[StageResult]): + def _update_retrieval_stats(self, total_time: float, stage_results: list[StageResult]): """更新检索统计""" self.retrieval_stats["total_queries"] += 1 @@ -1306,7 +1307,7 @@ class MultiStageRetrieval: ] stage_stat["avg_time"] = new_stage_avg - def get_retrieval_stats(self) -> Dict[str, Any]: + def get_retrieval_stats(self) -> dict[str, Any]: """获取检索统计信息""" return self.retrieval_stats.copy() @@ -1328,12 +1329,12 @@ class MultiStageRetrieval: self, query: str, user_id: str, - context: Dict[str, Any], - candidate_ids: List[str], - all_memories_cache: Dict[str, MemoryChunk], + context: dict[str, Any], + candidate_ids: list[str], + all_memories_cache: dict[str, MemoryChunk], limit: int, *, - debug_log: Optional[Dict[str, Dict[str, Any]]] = None, + debug_log: dict[str, dict[str, Any]] | None = None, ) -> StageResult: """阶段5:增强重排序 - 使用多维度评分模型""" start_time = time.time() diff --git a/src/chat/memory_system/deprecated_backup/vector_storage.py b/src/chat/memory_system/deprecated_backup/vector_storage.py index 5d2e4fb91..d5d974486 100644 --- a/src/chat/memory_system/deprecated_backup/vector_storage.py +++ b/src/chat/memory_system/deprecated_backup/vector_storage.py @@ -1,24 +1,23 @@ -# -*- coding: utf-8 -*- """ 向量数据库存储接口 为记忆系统提供高效的向量存储和语义搜索能力 """ -import time -import orjson import asyncio -from typing import Dict, List, Optional, Tuple, Any -from dataclasses import dataclass import threading +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any import numpy as np -from pathlib import Path +import orjson -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config -from src.common.config_helpers import resolve_embedding_dimension from src.chat.memory_system.memory_chunk import MemoryChunk +from src.common.config_helpers import resolve_embedding_dimension +from src.common.logger import get_logger +from src.config.config import model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) @@ -48,7 +47,7 @@ class VectorStorageConfig: class VectorStorageManager: """向量存储管理器""" - def __init__(self, config: Optional[VectorStorageConfig] = None): + def __init__(self, config: VectorStorageConfig | None = None): self.config = config or VectorStorageConfig() resolved_dimension = resolve_embedding_dimension(self.config.dimension) @@ -68,8 +67,8 @@ class VectorStorageManager: self.index_to_memory_id = {} # vector index -> memory_id # 内存缓存 - self.memory_cache: Dict[str, MemoryChunk] = {} - self.vector_cache: Dict[str, List[float]] = {} + self.memory_cache: dict[str, MemoryChunk] = {} + self.vector_cache: dict[str, list[float]] = {} # 统计信息 self.storage_stats = { @@ -125,7 +124,7 @@ class VectorStorageManager: ) logger.info("✅ 嵌入模型初始化完成") - async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]: + async def generate_query_embedding(self, query_text: str) -> list[float] | None: """生成查询向量,用于记忆召回""" if not query_text: logger.warning("查询文本为空,无法生成向量") @@ -155,7 +154,7 @@ class VectorStorageManager: logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True) return None - async def store_memories(self, memories: List[MemoryChunk]): + async def store_memories(self, memories: list[MemoryChunk]): """存储记忆向量""" if not memories: return @@ -231,7 +230,7 @@ class VectorStorageManager: logger.debug("记忆 %s 缺少可用展示文本,使用占位符生成嵌入输入", memory.memory_id) return memory.memory_id - async def _batch_generate_and_store_embeddings(self, memory_texts: List[Tuple[str, str]]): + async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]): """批量生成和存储嵌入向量""" if not memory_texts: return @@ -253,12 +252,12 @@ class VectorStorageManager: except Exception as e: logger.error(f"❌ 批量生成嵌入向量失败: {e}") - async def _batch_generate_embeddings(self, memory_ids: List[str], texts: List[str]) -> Dict[str, List[float]]: + async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]: """批量生成嵌入向量""" if not texts: return {} - results: Dict[str, List[float]] = {} + results: dict[str, list[float]] = {} try: semaphore = asyncio.Semaphore(min(4, max(1, len(texts)))) @@ -281,7 +280,9 @@ class VectorStorageManager: logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc) results[memory_id] = [] - tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)] + tasks = [ + asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False) + ] await asyncio.gather(*tasks, return_exceptions=True) except Exception as e: @@ -291,7 +292,7 @@ class VectorStorageManager: return results - async def _add_single_memory(self, memory: MemoryChunk, embedding: List[float]): + async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]): """添加单个记忆到向量存储""" with self._lock: try: @@ -337,7 +338,7 @@ class VectorStorageManager: except Exception as e: logger.error(f"❌ 添加记忆到向量存储失败: {e}") - def _normalize_vector(self, vector: List[float]) -> List[float]: + def _normalize_vector(self, vector: list[float]) -> list[float]: """L2归一化向量""" if not vector: return vector @@ -357,12 +358,12 @@ class VectorStorageManager: async def search_similar_memories( self, - query_vector: Optional[List[float]] = None, + query_vector: list[float] | None = None, *, - query_text: Optional[str] = None, + query_text: str | None = None, limit: int = 10, - scope_id: Optional[str] = None, - ) -> List[Tuple[str, float]]: + scope_id: str | None = None, + ) -> list[tuple[str, float]]: """搜索相似记忆""" start_time = time.time() @@ -379,7 +380,7 @@ class VectorStorageManager: logger.warning("查询向量生成失败") return [] - scope_filter: Optional[str] = None + scope_filter: str | None = None if isinstance(scope_id, str): normalized_scope = scope_id.strip().lower() if normalized_scope and normalized_scope not in {"global", "global_memory"}: @@ -491,7 +492,7 @@ class VectorStorageManager: logger.error(f"❌ 向量搜索失败: {e}", exc_info=True) return [] - async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]: + async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None: """根据ID获取记忆""" # 先检查缓存 if memory_id in self.memory_cache: @@ -501,7 +502,7 @@ class VectorStorageManager: self.storage_stats["total_searches"] += 1 return None - async def update_memory_embedding(self, memory_id: str, new_embedding: List[float]): + async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]): """更新记忆的嵌入向量""" with self._lock: try: @@ -636,7 +637,7 @@ class VectorStorageManager: # 加载记忆缓存 cache_file = self.storage_path / "memory_cache.json" if cache_file.exists(): - with open(cache_file, "r", encoding="utf-8") as f: + with open(cache_file, encoding="utf-8") as f: cache_data = orjson.loads(f.read()) self.memory_cache = { @@ -646,13 +647,13 @@ class VectorStorageManager: # 加载向量缓存 vector_cache_file = self.storage_path / "vector_cache.json" if vector_cache_file.exists(): - with open(vector_cache_file, "r", encoding="utf-8") as f: + with open(vector_cache_file, encoding="utf-8") as f: self.vector_cache = orjson.loads(f.read()) # 加载映射关系 mapping_file = self.storage_path / "id_mapping.json" if mapping_file.exists(): - with open(mapping_file, "r", encoding="utf-8") as f: + with open(mapping_file, encoding="utf-8") as f: mapping_data = orjson.loads(f.read()) raw_memory_to_index = mapping_data.get("memory_id_to_index", {}) self.memory_id_to_index = { @@ -689,7 +690,7 @@ class VectorStorageManager: # 加载统计信息 stats_file = self.storage_path / "storage_stats.json" if stats_file.exists(): - with open(stats_file, "r", encoding="utf-8") as f: + with open(stats_file, encoding="utf-8") as f: self.storage_stats = orjson.loads(f.read()) # 更新向量计数 @@ -806,7 +807,7 @@ class VectorStorageManager: if invalid_memory_ids: logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用") - def get_storage_stats(self) -> Dict[str, Any]: + def get_storage_stats(self) -> dict[str, Any]: """获取存储统计信息""" stats = self.storage_stats.copy() if stats["total_searches"] > 0: @@ -821,11 +822,11 @@ class SimpleVectorIndex: def __init__(self, dimension: int): self.dimension = dimension - self.vectors: List[List[float]] = [] - self.vector_ids: List[int] = [] + self.vectors: list[list[float]] = [] + self.vector_ids: list[int] = [] self.next_id = 0 - def add_vector(self, vector: List[float]) -> int: + def add_vector(self, vector: list[float]) -> int: """添加向量""" if len(vector) != self.dimension: raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}") @@ -837,7 +838,7 @@ class SimpleVectorIndex: return vector_id - def search(self, query_vector: List[float], limit: int) -> List[Tuple[int, float]]: + def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]: """搜索相似向量""" if len(query_vector) != self.dimension: raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}") @@ -853,7 +854,7 @@ class SimpleVectorIndex: return results[:limit] - def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float: + def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float: """计算余弦相似度""" try: dot_product = sum(x * y for x, y in zip(v1, v2, strict=False)) diff --git a/src/chat/memory_system/enhanced_memory_activator.py b/src/chat/memory_system/enhanced_memory_activator.py index 7570715ee..22b44c7a1 100644 --- a/src/chat/memory_system/enhanced_memory_activator.py +++ b/src/chat/memory_system/enhanced_memory_activator.py @@ -1,25 +1,24 @@ -# -*- coding: utf-8 -*- """ 记忆激活器 记忆系统的激活器组件 """ import difflib -import orjson -from typing import List, Dict, Optional from datetime import datetime +import orjson from json_repair import repair_json -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger -from src.chat.utils.prompt import Prompt, global_prompt_manager + from src.chat.memory_system.memory_manager import MemoryResult +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("memory_activator") -def get_keywords_from_json(json_str) -> List: +def get_keywords_from_json(json_str) -> list: """ 从JSON字符串中提取关键词列表 @@ -81,7 +80,7 @@ class MemoryActivator: self.cached_keywords = set() # 用于缓存历史关键词 self.last_memory_query_time = 0 # 上次查询记忆的时间 - async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]: + async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]: """ 激活记忆 """ @@ -155,7 +154,7 @@ class MemoryActivator: return self.running_memory - async def _query_unified_memory(self, keywords: List[str], query_text: str) -> List[MemoryResult]: + async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]: """查询统一记忆系统""" try: # 使用记忆系统 @@ -198,7 +197,7 @@ class MemoryActivator: logger.error(f"查询统一记忆失败: {e}") return [] - async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]: + async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None: """ 获取即时记忆 - 兼容原有接口(使用统一存储) """ diff --git a/src/chat/memory_system/memory_activator_new.py b/src/chat/memory_system/memory_activator_new.py index 491034de4..0b4e9a938 100644 --- a/src/chat/memory_system/memory_activator_new.py +++ b/src/chat/memory_system/memory_activator_new.py @@ -1,25 +1,24 @@ -# -*- coding: utf-8 -*- """ 记忆激活器 记忆系统的激活器组件 """ import difflib -import orjson -from typing import List, Dict, Optional from datetime import datetime +import orjson from json_repair import repair_json -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger -from src.chat.utils.prompt import Prompt, global_prompt_manager + from src.chat.memory_system.memory_manager import MemoryResult +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("memory_activator") -def get_keywords_from_json(json_str) -> List: +def get_keywords_from_json(json_str) -> list: """ 从JSON字符串中提取关键词列表 @@ -81,7 +80,7 @@ class MemoryActivator: self.cached_keywords = set() # 用于缓存历史关键词 self.last_memory_query_time = 0 # 上次查询记忆的时间 - async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]: + async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]: """ 激活记忆 """ @@ -155,7 +154,7 @@ class MemoryActivator: return self.running_memory - async def _query_unified_memory(self, keywords: List[str], query_text: str) -> List[MemoryResult]: + async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]: """查询统一记忆系统""" try: # 使用记忆系统 @@ -198,7 +197,7 @@ class MemoryActivator: logger.error(f"查询统一记忆失败: {e}") return [] - async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]: + async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None: """ 获取即时记忆 - 兼容原有接口(使用统一存储) """ diff --git a/src/chat/memory_system/memory_builder.py b/src/chat/memory_system/memory_builder.py index 0c3f47043..a2f936028 100644 --- a/src/chat/memory_system/memory_builder.py +++ b/src/chat/memory_system/memory_builder.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 记忆构建模块 从对话流中提取高质量、结构化记忆单元 @@ -33,19 +32,19 @@ import time from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Union, Type +from typing import Any import orjson -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest from src.chat.memory_system.memory_chunk import ( - MemoryChunk, - MemoryType, ConfidenceLevel, ImportanceLevel, + MemoryChunk, + MemoryType, create_memory_chunk, ) +from src.common.logger import get_logger +from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) @@ -62,8 +61,8 @@ class ExtractionStrategy(Enum): class ExtractionResult: """提取结果""" - memories: List[MemoryChunk] - confidence_scores: List[float] + memories: list[MemoryChunk] + confidence_scores: list[float] extraction_time: float strategy_used: ExtractionStrategy @@ -85,8 +84,8 @@ class MemoryBuilder: } async def build_memories( - self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: float - ) -> List[MemoryChunk]: + self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float + ) -> list[MemoryChunk]: """从对话中构建记忆""" start_time = time.time() @@ -116,8 +115,8 @@ class MemoryBuilder: raise async def _extract_with_llm( - self, text: str, context: Dict[str, Any], user_id: str, timestamp: float - ) -> List[MemoryChunk]: + self, text: str, context: dict[str, Any], user_id: str, timestamp: float + ) -> list[MemoryChunk]: """使用LLM提取记忆""" try: prompt = self._build_llm_extraction_prompt(text, context) @@ -135,7 +134,7 @@ class MemoryBuilder: logger.error(f"LLM提取失败: {e}") raise MemoryExtractionError(str(e)) from e - def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str: + def _build_llm_extraction_prompt(self, text: str, context: dict[str, Any]) -> str: """构建LLM提取提示""" current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") message_type = context.get("message_type", "normal") @@ -315,7 +314,7 @@ class MemoryBuilder: return prompt - def _extract_json_payload(self, response: str) -> Optional[str]: + def _extract_json_payload(self, response: str) -> str | None: """从模型响应中提取JSON部分,兼容Markdown代码块等格式""" if not response: return None @@ -338,8 +337,8 @@ class MemoryBuilder: return stripped if stripped.startswith("{") and stripped.endswith("}") else None def _parse_llm_response( - self, response: str, user_id: str, timestamp: float, context: Dict[str, Any] - ) -> List[MemoryChunk]: + self, response: str, user_id: str, timestamp: float, context: dict[str, Any] + ) -> list[MemoryChunk]: """解析LLM响应""" if not response: raise MemoryExtractionError("LLM未返回任何响应") @@ -385,7 +384,7 @@ class MemoryBuilder: bot_display = self._clean_subject_text(bot_display) - memories: List[MemoryChunk] = [] + memories: list[MemoryChunk] = [] for mem_data in memory_list: try: @@ -460,7 +459,7 @@ class MemoryBuilder: return memories - def _parse_enum_value(self, enum_cls: Type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum: + def _parse_enum_value(self, enum_cls: type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum: """解析枚举值,兼容数字/字符串表示""" if isinstance(raw_value, enum_cls): return raw_value @@ -514,7 +513,7 @@ class MemoryBuilder: ) return default - def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]: + def _collect_bot_identifiers(self, context: dict[str, Any] | None) -> set[str]: identifiers: set[str] = {"bot", "机器人", "ai助手"} if not context: return identifiers @@ -540,7 +539,7 @@ class MemoryBuilder: return identifiers - def _collect_system_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]: + def _collect_system_identifiers(self, context: dict[str, Any] | None) -> set[str]: identifiers: set[str] = set() if not context: return identifiers @@ -568,8 +567,8 @@ class MemoryBuilder: return identifiers - def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]: - participants: List[str] = [] + def _resolve_conversation_participants(self, context: dict[str, Any] | None, user_id: str) -> list[str]: + participants: list[str] = [] if context: candidate_keys = [ @@ -609,7 +608,7 @@ class MemoryBuilder: if not participants: participants = ["对话参与者"] - deduplicated: List[str] = [] + deduplicated: list[str] = [] seen = set() for name in participants: key = name.lower() @@ -620,7 +619,7 @@ class MemoryBuilder: return deduplicated - def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str: + def _resolve_user_display(self, context: dict[str, Any] | None, user_id: str) -> str: candidate_keys = [ "user_display_name", "user_name", @@ -683,7 +682,7 @@ class MemoryBuilder: return False - def _split_subject_string(self, value: str) -> List[str]: + def _split_subject_string(self, value: str) -> list[str]: if not value: return [] @@ -699,12 +698,12 @@ class MemoryBuilder: subject: Any, bot_identifiers: set[str], system_identifiers: set[str], - default_subjects: List[str], - bot_display: Optional[str] = None, - ) -> List[str]: + default_subjects: list[str], + bot_display: str | None = None, + ) -> list[str]: defaults = default_subjects or ["对话参与者"] - raw_candidates: List[str] = [] + raw_candidates: list[str] = [] if isinstance(subject, list): for item in subject: if isinstance(item, str): @@ -716,7 +715,7 @@ class MemoryBuilder: elif subject is not None: raw_candidates.extend(self._split_subject_string(str(subject))) - normalized: List[str] = [] + normalized: list[str] = [] bot_primary = self._clean_subject_text(bot_display or "") for candidate in raw_candidates: @@ -741,7 +740,7 @@ class MemoryBuilder: if not normalized: normalized = list(defaults) - deduplicated: List[str] = [] + deduplicated: list[str] = [] seen = set() for name in normalized: key = name.lower() @@ -752,7 +751,7 @@ class MemoryBuilder: return deduplicated - def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]: + def _extract_value_from_object(self, obj: str | dict[str, Any] | list[Any], keys: list[str]) -> str | None: if isinstance(obj, dict): for key in keys: value = obj.get(key) @@ -773,9 +772,7 @@ class MemoryBuilder: return obj.strip() or None return None - def _compose_display_text( - self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]] - ) -> str: + def _compose_display_text(self, subjects: list[str], predicate: str, obj: str | dict[str, Any] | list[Any]) -> str: subject_phrase = "、".join(subjects) if subjects else "对话参与者" predicate = (predicate or "").strip() @@ -841,7 +838,7 @@ class MemoryBuilder: return f"{subject_phrase}{predicate}".strip() return subject_phrase - def _validate_and_enhance_memories(self, memories: List[MemoryChunk], context: Dict[str, Any]) -> List[MemoryChunk]: + def _validate_and_enhance_memories(self, memories: list[MemoryChunk], context: dict[str, Any]) -> list[MemoryChunk]: """验证和增强记忆""" validated_memories = [] @@ -876,7 +873,7 @@ class MemoryBuilder: return True - def _enhance_memory(self, memory: MemoryChunk, context: Dict[str, Any]) -> MemoryChunk: + def _enhance_memory(self, memory: MemoryChunk, context: dict[str, Any]) -> MemoryChunk: """增强记忆块""" # 时间规范化处理 self._normalize_time_in_memory(memory) @@ -985,7 +982,7 @@ class MemoryBuilder: total_confidence / self.extraction_stats["successful_extractions"] ) - def get_extraction_stats(self) -> Dict[str, Any]: + def get_extraction_stats(self) -> dict[str, Any]: """获取提取统计信息""" return self.extraction_stats.copy() diff --git a/src/chat/memory_system/memory_chunk.py b/src/chat/memory_system/memory_chunk.py index b5b609af6..dcce6eb64 100644 --- a/src/chat/memory_system/memory_chunk.py +++ b/src/chat/memory_system/memory_chunk.py @@ -1,18 +1,19 @@ -# -*- coding: utf-8 -*- """ 结构化记忆单元设计 实现高质量、结构化的记忆单元,符合文档设计规范 """ +import hashlib import time import uuid -import orjson -from typing import Dict, List, Optional, Any, Union, Iterable +from collections.abc import Iterable from dataclasses import dataclass, field from enum import Enum -import hashlib +from typing import Any import numpy as np +import orjson + from src.common.logger import get_logger logger = get_logger(__name__) @@ -56,17 +57,17 @@ class ImportanceLevel(Enum): class ContentStructure: """主谓宾结构,包含自然语言描述""" - subject: Union[str, List[str]] + subject: str | list[str] predicate: str - object: Union[str, Dict] + object: str | dict display: str = "" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典格式""" return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display} @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure": + def from_dict(cls, data: dict[str, Any]) -> "ContentStructure": """从字典创建实例""" return cls( subject=data.get("subject", ""), @@ -75,7 +76,7 @@ class ContentStructure: display=data.get("display", ""), ) - def to_subject_list(self) -> List[str]: + def to_subject_list(self) -> list[str]: """将主语转换为列表形式""" if isinstance(self.subject, list): return [s for s in self.subject if isinstance(s, str) and s.strip()] @@ -99,7 +100,7 @@ class MemoryMetadata: # 基础信息 memory_id: str # 唯一标识符 user_id: str # 用户ID - chat_id: Optional[str] = None # 聊天ID(群聊或私聊) + chat_id: str | None = None # 聊天ID(群聊或私聊) # 时间信息 created_at: float = 0.0 # 创建时间戳 @@ -124,9 +125,9 @@ class MemoryMetadata: last_forgetting_check: float = 0.0 # 上次遗忘检查时间 # 来源信息 - source_context: Optional[str] = None # 来源上下文片段 + source_context: str | None = None # 来源上下文片段 # 兼容旧字段: 一些代码或旧版本可能直接访问 metadata.source - source: Optional[str] = None + source: str | None = None def __post_init__(self): """后初始化处理""" @@ -209,7 +210,7 @@ class MemoryMetadata: # 设置最小和最大阈值 return max(7.0, min(threshold, 365.0)) # 7天到1年之间 - def should_forget(self, current_time: Optional[float] = None) -> bool: + def should_forget(self, current_time: float | None = None) -> bool: """判断是否应该遗忘""" if current_time is None: current_time = time.time() @@ -222,7 +223,7 @@ class MemoryMetadata: return days_since_activation > self.forgetting_threshold - def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool: + def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool: """判断是否处于休眠状态(长期未激活)""" if current_time is None: current_time = time.time() @@ -230,7 +231,7 @@ class MemoryMetadata: days_since_last_access = (current_time - self.last_accessed) / 86400 return days_since_last_access > inactive_days - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典格式""" return { "memory_id": self.memory_id, @@ -252,7 +253,7 @@ class MemoryMetadata: } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "MemoryMetadata": + def from_dict(cls, data: dict[str, Any]) -> "MemoryMetadata": """从字典创建实例""" return cls( memory_id=data.get("memory_id", ""), @@ -286,17 +287,17 @@ class MemoryChunk: memory_type: MemoryType # 记忆类型 # 扩展信息 - keywords: List[str] = field(default_factory=list) # 关键词列表 - tags: List[str] = field(default_factory=list) # 标签列表 - categories: List[str] = field(default_factory=list) # 分类列表 + keywords: list[str] = field(default_factory=list) # 关键词列表 + tags: list[str] = field(default_factory=list) # 标签列表 + categories: list[str] = field(default_factory=list) # 分类列表 # 语义信息 - embedding: Optional[List[float]] = None # 语义向量 - semantic_hash: Optional[str] = None # 语义哈希值 + embedding: list[float] | None = None # 语义向量 + semantic_hash: str | None = None # 语义哈希值 # 关联信息 - related_memories: List[str] = field(default_factory=list) # 关联记忆ID列表 - temporal_context: Optional[Dict[str, Any]] = None # 时间上下文 + related_memories: list[str] = field(default_factory=list) # 关联记忆ID列表 + temporal_context: dict[str, Any] | None = None # 时间上下文 def __post_init__(self): """后初始化处理""" @@ -310,7 +311,7 @@ class MemoryChunk: try: # 使用向量和内容生成稳定的哈希 - content_str = f"{self.content.subject}:{self.content.predicate}:{str(self.content.object)}" + content_str = f"{self.content.subject}:{self.content.predicate}:{self.content.object!s}" embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding])) hash_input = f"{content_str}|{embedding_str}" @@ -342,7 +343,7 @@ class MemoryChunk: return self.content.display or str(self.content) @property - def subjects(self) -> List[str]: + def subjects(self) -> list[str]: """获取主语列表""" return self.content.to_subject_list() @@ -354,11 +355,11 @@ class MemoryChunk: """更新相关度评分""" self.metadata.update_relevance(new_score) - def should_forget(self, current_time: Optional[float] = None) -> bool: + def should_forget(self, current_time: float | None = None) -> bool: """判断是否应该遗忘""" return self.metadata.should_forget(current_time) - def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool: + def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool: """判断是否处于休眠状态(长期未激活)""" return self.metadata.is_dormant(current_time, inactive_days) @@ -386,7 +387,7 @@ class MemoryChunk: if memory_id and memory_id not in self.related_memories: self.related_memories.append(memory_id) - def set_embedding(self, embedding: List[float]): + def set_embedding(self, embedding: list[float]): """设置语义向量""" self.embedding = embedding self._generate_semantic_hash() @@ -415,7 +416,7 @@ class MemoryChunk: logger.warning(f"计算记忆相似度失败: {e}") return 0.0 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为完整的字典格式""" return { "metadata": self.metadata.to_dict(), @@ -431,7 +432,7 @@ class MemoryChunk: } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "MemoryChunk": + def from_dict(cls, data: dict[str, Any]) -> "MemoryChunk": """从字典创建实例""" metadata = MemoryMetadata.from_dict(data.get("metadata", {})) content = ContentStructure.from_dict(data.get("content", {})) @@ -541,7 +542,7 @@ class MemoryChunk: return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})" -def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str: +def _build_display_text(subjects: Iterable[str], predicate: str, obj: str | dict) -> str: """根据主谓宾生成自然语言描述""" subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)] subject_part = "、".join(subjects_clean) if subjects_clean else "对话参与者" @@ -569,15 +570,15 @@ def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, def create_memory_chunk( user_id: str, - subject: Union[str, List[str]], + subject: str | list[str], predicate: str, - obj: Union[str, Dict], + obj: str | dict, memory_type: MemoryType, - chat_id: Optional[str] = None, - source_context: Optional[str] = None, + chat_id: str | None = None, + source_context: str | None = None, importance: ImportanceLevel = ImportanceLevel.NORMAL, confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM, - display: Optional[str] = None, + display: str | None = None, **kwargs, ) -> MemoryChunk: """便捷的内存块创建函数""" @@ -593,10 +594,10 @@ def create_memory_chunk( source_context=source_context, ) - subjects: List[str] + subjects: list[str] if isinstance(subject, list): subjects = [s for s in subject if isinstance(s, str) and s.strip()] - subject_payload: Union[str, List[str]] = subjects + subject_payload: str | list[str] = subjects else: cleaned = subject.strip() if isinstance(subject, str) else "" subjects = [cleaned] if cleaned else [] diff --git a/src/chat/memory_system/memory_forgetting_engine.py b/src/chat/memory_system/memory_forgetting_engine.py index 3e243e433..e41d1149c 100644 --- a/src/chat/memory_system/memory_forgetting_engine.py +++ b/src/chat/memory_system/memory_forgetting_engine.py @@ -1,17 +1,15 @@ -# -*- coding: utf-8 -*- """ 智能记忆遗忘引擎 基于重要程度、置信度和激活频率的智能遗忘机制 """ -import time import asyncio -from typing import List, Dict, Optional, Tuple -from datetime import datetime +import time from dataclasses import dataclass +from datetime import datetime +from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import MemoryChunk, ImportanceLevel, ConfidenceLevel logger = get_logger(__name__) @@ -65,7 +63,7 @@ class ForgettingConfig: class MemoryForgettingEngine: """智能记忆遗忘引擎""" - def __init__(self, config: Optional[ForgettingConfig] = None): + def __init__(self, config: ForgettingConfig | None = None): self.config = config or ForgettingConfig() self.stats = ForgettingStats() self._last_forgetting_check = 0.0 @@ -116,7 +114,7 @@ class MemoryForgettingEngine: # 确保在合理范围内 return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days)) - def should_forget_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool: + def should_forget_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool: """ 判断记忆是否应该被遗忘 @@ -155,7 +153,7 @@ class MemoryForgettingEngine: return should_forget - def is_dormant_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool: + def is_dormant_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool: """ 判断记忆是否处于休眠状态 @@ -168,7 +166,7 @@ class MemoryForgettingEngine: """ return memory.is_dormant(current_time, self.config.dormant_threshold_days) - def should_force_forget_dormant(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool: + def should_force_forget_dormant(self, memory: MemoryChunk, current_time: float | None = None) -> bool: """ 判断是否应该强制遗忘休眠记忆 @@ -189,7 +187,7 @@ class MemoryForgettingEngine: days_since_last_access = (current_time - memory.metadata.last_accessed) / 86400 return days_since_last_access > self.config.force_forget_dormant_days - async def check_memories_for_forgetting(self, memories: List[MemoryChunk]) -> Tuple[List[str], List[str]]: + async def check_memories_for_forgetting(self, memories: list[MemoryChunk]) -> tuple[list[str], list[str]]: """ 检查记忆列表,识别需要遗忘的记忆 @@ -241,7 +239,7 @@ class MemoryForgettingEngine: return normal_forgetting_ids, force_forgetting_ids - async def perform_forgetting_check(self, memories: List[MemoryChunk]) -> Dict[str, any]: + async def perform_forgetting_check(self, memories: list[MemoryChunk]) -> dict[str, any]: """ 执行完整的遗忘检查流程 @@ -314,7 +312,7 @@ class MemoryForgettingEngine: except Exception as e: logger.error(f"定期遗忘检查失败: {e}", exc_info=True) - def get_forgetting_stats(self) -> Dict[str, any]: + def get_forgetting_stats(self) -> dict[str, any]: """获取遗忘统计信息""" return { "total_checked": self.stats.total_checked, diff --git a/src/chat/memory_system/memory_fusion.py b/src/chat/memory_system/memory_fusion.py index 3ecc4cd71..59f36ed93 100644 --- a/src/chat/memory_system/memory_fusion.py +++ b/src/chat/memory_system/memory_fusion.py @@ -1,16 +1,14 @@ -# -*- coding: utf-8 -*- """ 记忆融合与去重机制 避免记忆碎片化,确保长期记忆库的高质量 """ import time -from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass +from typing import Any - +from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk from src.common.logger import get_logger -from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel logger = get_logger(__name__) @@ -22,9 +20,9 @@ class FusionResult: original_count: int fused_count: int removed_duplicates: int - merged_memories: List[MemoryChunk] + merged_memories: list[MemoryChunk] fusion_time: float - details: List[str] + details: list[str] @dataclass @@ -32,9 +30,9 @@ class DuplicateGroup: """重复记忆组""" group_id: str - memories: List[MemoryChunk] - similarity_matrix: List[List[float]] - representative_memory: Optional[MemoryChunk] = None + memories: list[MemoryChunk] + similarity_matrix: list[list[float]] + representative_memory: MemoryChunk | None = None class MemoryFusionEngine: @@ -59,8 +57,8 @@ class MemoryFusionEngine: } async def fuse_memories( - self, new_memories: List[MemoryChunk], existing_memories: Optional[List[MemoryChunk]] = None - ) -> List[MemoryChunk]: + self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk] | None = None + ) -> list[MemoryChunk]: """融合记忆列表""" start_time = time.time() @@ -106,8 +104,8 @@ class MemoryFusionEngine: return new_memories # 失败时返回原始记忆 async def _detect_duplicate_groups( - self, new_memories: List[MemoryChunk], existing_memories: List[MemoryChunk] - ) -> List[DuplicateGroup]: + self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk] + ) -> list[DuplicateGroup]: """检测重复记忆组""" all_memories = new_memories + existing_memories new_memory_ids = {memory.memory_id for memory in new_memories} @@ -212,7 +210,7 @@ class MemoryFusionEngine: jaccard_similarity = len(intersection) / len(union) return jaccard_similarity - def _calculate_keyword_similarity(self, keywords1: List[str], keywords2: List[str]) -> float: + def _calculate_keyword_similarity(self, keywords1: list[str], keywords2: list[str]) -> float: """计算关键词相似度""" if not keywords1 or not keywords2: return 0.0 @@ -302,7 +300,7 @@ class MemoryFusionEngine: return best_memory - async def _fuse_memory_group(self, group: DuplicateGroup) -> Optional[MemoryChunk]: + async def _fuse_memory_group(self, group: DuplicateGroup) -> MemoryChunk | None: """融合记忆组""" if not group.memories: return None @@ -328,7 +326,7 @@ class MemoryFusionEngine: # 返回置信度最高的记忆 return max(group.memories, key=lambda m: m.metadata.confidence.value) - async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: List[MemoryChunk]) -> MemoryChunk: + async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: list[MemoryChunk]) -> MemoryChunk: """合并记忆属性""" # 创建基础记忆的深拷贝 fused_memory = MemoryChunk.from_dict(base_memory.to_dict()) @@ -395,7 +393,7 @@ class MemoryFusionEngine: source_ids = [m.memory_id[:8] for m in group.memories] fused_memory.metadata.source_context = f"Fused from {len(group.memories)} memories: {', '.join(source_ids)}" - def _merge_temporal_context(self, memories: List[MemoryChunk]) -> Dict[str, Any]: + def _merge_temporal_context(self, memories: list[MemoryChunk]) -> dict[str, Any]: """合并时间上下文""" contexts = [m.temporal_context for m in memories if m.temporal_context] @@ -426,8 +424,8 @@ class MemoryFusionEngine: return merged_context async def incremental_fusion( - self, new_memory: MemoryChunk, existing_memories: List[MemoryChunk] - ) -> Tuple[MemoryChunk, List[MemoryChunk]]: + self, new_memory: MemoryChunk, existing_memories: list[MemoryChunk] + ) -> tuple[MemoryChunk, list[MemoryChunk]]: """增量融合(单个新记忆与现有记忆融合)""" # 寻找相似记忆 similar_memories = [] @@ -493,7 +491,7 @@ class MemoryFusionEngine: except Exception as e: logger.error(f"❌ 记忆融合引擎维护失败: {e}", exc_info=True) - def get_fusion_stats(self) -> Dict[str, Any]: + def get_fusion_stats(self) -> dict[str, Any]: """获取融合统计信息""" return self.fusion_stats.copy() diff --git a/src/chat/memory_system/memory_manager.py b/src/chat/memory_system/memory_manager.py index 4c6b2696e..1ba79fe59 100644 --- a/src/chat/memory_system/memory_manager.py +++ b/src/chat/memory_system/memory_manager.py @@ -1,17 +1,15 @@ -# -*- coding: utf-8 -*- """ 记忆系统管理器 替代原有的 Hippocampus 和 instant_memory 系统 """ import re -from typing import Dict, List, Optional, Any, Tuple from dataclasses import dataclass +from typing import Any -from src.common.logger import get_logger -from src.chat.memory_system.memory_system import MemorySystem from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType -from src.chat.memory_system.memory_system import initialize_memory_system +from src.chat.memory_system.memory_system import MemorySystem, initialize_memory_system +from src.common.logger import get_logger logger = get_logger(__name__) @@ -27,14 +25,14 @@ class MemoryResult: timestamp: float source: str = "memory" relevance_score: float = 0.0 - structure: Dict[str, Any] | None = None + structure: dict[str, Any] | None = None class MemoryManager: """记忆系统管理器 - 替代原有的 HippocampusManager""" def __init__(self): - self.memory_system: Optional[MemorySystem] = None + self.memory_system: MemorySystem | None = None self.is_initialized = False self.user_cache = {} # 用户记忆缓存 @@ -63,8 +61,8 @@ class MemoryManager: logger.info("正在初始化记忆系统...") # 获取LLM模型 - from src.llm_models.utils_model import LLMRequest from src.config.config import model_config + from src.llm_models.utils_model import LLMRequest llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory") @@ -121,7 +119,7 @@ class MemoryManager: max_memory_length: int = 2, time_weight: float = 1.0, keyword_weight: float = 1.0, - ) -> List[Tuple[str, str]]: + ) -> list[tuple[str, str]]: """从文本获取相关记忆 - 兼容原有接口""" if not self.is_initialized or not self.memory_system: return [] @@ -152,8 +150,8 @@ class MemoryManager: return [] async def get_memory_from_topic( - self, valid_keywords: List[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 - ) -> List[Tuple[str, str]]: + self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 + ) -> list[tuple[str, str]]: """从关键词获取记忆 - 兼容原有接口""" if not self.is_initialized or not self.memory_system: return [] @@ -208,8 +206,8 @@ class MemoryManager: return [] async def process_conversation( - self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: Optional[float] = None - ) -> List[MemoryChunk]: + self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float | None = None + ) -> list[MemoryChunk]: """处理对话并构建记忆 - 新增功能""" if not self.is_initialized or not self.memory_system: return [] @@ -235,8 +233,8 @@ class MemoryManager: return [] async def get_enhanced_memory_context( - self, query_text: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 5 - ) -> List[MemoryResult]: + self, query_text: str, user_id: str, context: dict[str, Any] | None = None, limit: int = 5 + ) -> list[MemoryResult]: """获取增强记忆上下文 - 新增功能""" if not self.is_initialized or not self.memory_system: return [] @@ -267,7 +265,7 @@ class MemoryManager: logger.error(f"get_enhanced_memory_context 失败: {e}") return [] - def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]: + def _format_memory_chunk(self, memory: MemoryChunk) -> tuple[str, dict[str, Any]]: """将记忆块转换为更易读的文本描述""" structure = memory.content.to_dict() if memory.display: @@ -289,7 +287,7 @@ class MemoryManager: return formatted, structure - def _format_subject(self, subject: Optional[str], memory: MemoryChunk) -> str: + def _format_subject(self, subject: str | None, memory: MemoryChunk) -> str: if not subject: return "该用户" @@ -299,7 +297,7 @@ class MemoryManager: return "该聊天" return self._clean_text(subject) - def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> Optional[str]: + def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> str | None: predicate = (predicate or "").strip() obj_value = obj @@ -446,10 +444,10 @@ class MemoryManager: text = self._truncate(str(obj).strip()) return self._clean_text(text) - def _extract_from_object(self, obj: Any, keys: List[str]) -> Optional[str]: + def _extract_from_object(self, obj: Any, keys: list[str]) -> str | None: if isinstance(obj, dict): for key in keys: - if key in obj and obj[key]: + if obj.get(key): value = obj[key] if isinstance(value, (dict, list)): return self._clean_text(self._format_object(value)) diff --git a/src/chat/memory_system/memory_metadata_index.py b/src/chat/memory_system/memory_metadata_index.py index ad27971a6..4b405aad6 100644 --- a/src/chat/memory_system/memory_metadata_index.py +++ b/src/chat/memory_system/memory_metadata_index.py @@ -1,15 +1,15 @@ -# -*- coding: utf-8 -*- """ 记忆元数据索引管理器 使用JSON文件存储记忆元数据,支持快速模糊搜索和过滤 """ -import orjson import threading -from pathlib import Path -from typing import Dict, List, Optional, Set, Any -from dataclasses import dataclass, asdict +from dataclasses import asdict, dataclass from datetime import datetime +from pathlib import Path +from typing import Any + +import orjson from src.common.logger import get_logger @@ -25,10 +25,10 @@ class MemoryMetadataIndexEntry: # 分类信息 memory_type: str # MemoryType.value - subjects: List[str] # 主语列表 - objects: List[str] # 宾语列表 - keywords: List[str] # 关键词列表 - tags: List[str] # 标签列表 + subjects: list[str] # 主语列表 + objects: list[str] # 宾语列表 + keywords: list[str] # 关键词列表 + tags: list[str] # 标签列表 # 数值字段(用于范围过滤) importance: int # ImportanceLevel.value (1-4) @@ -37,8 +37,8 @@ class MemoryMetadataIndexEntry: access_count: int # 访问次数 # 可选字段 - chat_id: Optional[str] = None - content_preview: Optional[str] = None # 内容预览(前100字符) + chat_id: str | None = None + content_preview: str | None = None # 内容预览(前100字符) class MemoryMetadataIndex: @@ -46,13 +46,13 @@ class MemoryMetadataIndex: def __init__(self, index_file: str = "data/memory_metadata_index.json"): self.index_file = Path(index_file) - self.index: Dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry + self.index: dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry # 倒排索引(用于快速查找) - self.type_index: Dict[str, Set[str]] = {} # type -> {memory_ids} - self.subject_index: Dict[str, Set[str]] = {} # subject -> {memory_ids} - self.keyword_index: Dict[str, Set[str]] = {} # keyword -> {memory_ids} - self.tag_index: Dict[str, Set[str]] = {} # tag -> {memory_ids} + self.type_index: dict[str, set[str]] = {} # type -> {memory_ids} + self.subject_index: dict[str, set[str]] = {} # subject -> {memory_ids} + self.keyword_index: dict[str, set[str]] = {} # keyword -> {memory_ids} + self.tag_index: dict[str, set[str]] = {} # tag -> {memory_ids} self.lock = threading.RLock() @@ -178,7 +178,7 @@ class MemoryMetadataIndex: self._remove_from_inverted_indices(memory_id) del self.index[memory_id] - def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]): + def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]): """批量添加或更新""" with self.lock: for entry in entries: @@ -191,18 +191,18 @@ class MemoryMetadataIndex: def search( self, - memory_types: Optional[List[str]] = None, - subjects: Optional[List[str]] = None, - keywords: Optional[List[str]] = None, - tags: Optional[List[str]] = None, - importance_min: Optional[int] = None, - importance_max: Optional[int] = None, - created_after: Optional[float] = None, - created_before: Optional[float] = None, - user_id: Optional[str] = None, - limit: Optional[int] = None, + memory_types: list[str] | None = None, + subjects: list[str] | None = None, + keywords: list[str] | None = None, + tags: list[str] | None = None, + importance_min: int | None = None, + importance_max: int | None = None, + created_after: float | None = None, + created_before: float | None = None, + user_id: str | None = None, + limit: int | None = None, flexible_mode: bool = True, # 新增:灵活匹配模式 - ) -> List[str]: + ) -> list[str]: """ 搜索符合条件的记忆ID列表(支持模糊匹配) @@ -237,14 +237,14 @@ class MemoryMetadataIndex: def _search_flexible( self, - memory_types: Optional[List[str]] = None, - subjects: Optional[List[str]] = None, - created_after: Optional[float] = None, - created_before: Optional[float] = None, - user_id: Optional[str] = None, - limit: Optional[int] = None, + memory_types: list[str] | None = None, + subjects: list[str] | None = None, + created_after: float | None = None, + created_before: float | None = None, + user_id: str | None = None, + limit: int | None = None, **kwargs, # 接受但不使用的参数 - ) -> List[str]: + ) -> list[str]: """ 灵活搜索模式:2/4项匹配即可,支持部分匹配 @@ -374,20 +374,20 @@ class MemoryMetadataIndex: def _search_strict( self, - memory_types: Optional[List[str]] = None, - subjects: Optional[List[str]] = None, - keywords: Optional[List[str]] = None, - tags: Optional[List[str]] = None, - importance_min: Optional[int] = None, - importance_max: Optional[int] = None, - created_after: Optional[float] = None, - created_before: Optional[float] = None, - user_id: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[str]: + memory_types: list[str] | None = None, + subjects: list[str] | None = None, + keywords: list[str] | None = None, + tags: list[str] | None = None, + importance_min: int | None = None, + importance_max: int | None = None, + created_after: float | None = None, + created_before: float | None = None, + user_id: str | None = None, + limit: int | None = None, + ) -> list[str]: """严格搜索模式(原有逻辑)""" # 初始候选集(所有记忆) - candidate_ids: Optional[Set[str]] = None + candidate_ids: set[str] | None = None # 用户过滤(必选) if user_id: @@ -471,11 +471,11 @@ class MemoryMetadataIndex: return result_ids - def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]: + def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None: """获取单个索引条目""" return self.index.get(memory_id) - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: """获取索引统计信息""" with self.lock: return { diff --git a/src/chat/memory_system/memory_query_planner.py b/src/chat/memory_system/memory_query_planner.py index a8be9d951..bbedf766c 100644 --- a/src/chat/memory_system/memory_query_planner.py +++ b/src/chat/memory_system/memory_query_planner.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 -*- """记忆检索查询规划器""" from __future__ import annotations import re from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any import orjson @@ -21,16 +20,16 @@ class MemoryQueryPlan: """查询规划结果""" semantic_query: str - memory_types: List[MemoryType] = field(default_factory=list) - subject_includes: List[str] = field(default_factory=list) - object_includes: List[str] = field(default_factory=list) - required_keywords: List[str] = field(default_factory=list) - optional_keywords: List[str] = field(default_factory=list) - owner_filters: List[str] = field(default_factory=list) + memory_types: list[MemoryType] = field(default_factory=list) + subject_includes: list[str] = field(default_factory=list) + object_includes: list[str] = field(default_factory=list) + required_keywords: list[str] = field(default_factory=list) + optional_keywords: list[str] = field(default_factory=list) + owner_filters: list[str] = field(default_factory=list) recency_preference: str = "any" limit: int = 10 - emphasis: Optional[str] = None - raw_plan: Dict[str, Any] = field(default_factory=dict) + emphasis: str | None = None + raw_plan: dict[str, Any] = field(default_factory=dict) def ensure_defaults(self, fallback_query: str, default_limit: int) -> None: if not self.semantic_query: @@ -46,11 +45,11 @@ class MemoryQueryPlan: class MemoryQueryPlanner: """基于小模型的记忆检索查询规划器""" - def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10): + def __init__(self, planner_model: LLMRequest | None, default_limit: int = 10): self.model = planner_model self.default_limit = default_limit - async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan: + async def plan_query(self, query_text: str, context: dict[str, Any]) -> MemoryQueryPlan: if not self.model: logger.debug("未提供查询规划模型,使用默认规划") return self._default_plan(query_text) @@ -82,10 +81,10 @@ class MemoryQueryPlanner: def _default_plan(self, query_text: str) -> MemoryQueryPlan: return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit) - def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan: + def _parse_plan_dict(self, data: dict[str, Any], fallback_query: str) -> MemoryQueryPlan: semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query - def _collect_list(key: str) -> List[str]: + def _collect_list(key: str) -> list[str]: value = data.get(key) if isinstance(value, str): return [value] @@ -94,7 +93,7 @@ class MemoryQueryPlanner: return [] memory_type_values = _collect_list("memory_types") - memory_types: List[MemoryType] = [] + memory_types: list[MemoryType] = [] for item in memory_type_values: if not item: continue @@ -123,7 +122,7 @@ class MemoryQueryPlanner: ) return plan - def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str: + def _build_prompt(self, query_text: str, context: dict[str, Any]) -> str: participants = context.get("participants") or context.get("speaker_names") or [] if isinstance(participants, str): participants = [participants] @@ -206,7 +205,7 @@ class MemoryQueryPlanner: 请直接输出符合要求的 JSON 对象,禁止添加额外文本或 Markdown 代码块。 """ - def _extract_json_payload(self, response: str) -> Optional[str]: + def _extract_json_payload(self, response: str) -> str | None: if not response: return None diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index 4a275babd..5236da62a 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 精准记忆系统核心模块 1. 基于文档设计的高效记忆构建、存储与召回优化系统,覆盖构建、向量化与多阶段检索全流程。 @@ -6,26 +5,27 @@ """ import asyncio -import time -import orjson -import re import hashlib -from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING +import re +import time +from dataclasses import asdict, dataclass from datetime import datetime, timedelta -from dataclasses import dataclass, asdict from enum import Enum +from typing import TYPE_CHECKING, Any + +import orjson -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config, global_config -from src.chat.memory_system.memory_chunk import MemoryChunk from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError +from src.chat.memory_system.memory_chunk import MemoryChunk from src.chat.memory_system.memory_fusion import MemoryFusionEngine from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest if TYPE_CHECKING: - from src.common.data_models.database_data_model import DatabaseMessages from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine + from src.common.data_models.database_data_model import DatabaseMessages logger = get_logger(__name__) @@ -121,7 +121,7 @@ class MemorySystemConfig: class MemorySystem: """精准记忆系统核心类""" - def __init__(self, llm_model: Optional[LLMRequest] = None, config: Optional[MemorySystemConfig] = None): + def __init__(self, llm_model: LLMRequest | None = None, config: MemorySystemConfig | None = None): self.config = config or MemorySystemConfig.from_global_config() self.llm_model = llm_model self.status = MemorySystemStatus.INITIALIZING @@ -131,7 +131,7 @@ class MemorySystem: self.fusion_engine: MemoryFusionEngine = None self.unified_storage = None # 统一存储系统 self.query_planner: MemoryQueryPlanner = None - self.forgetting_engine: Optional[MemoryForgettingEngine] = None + self.forgetting_engine: MemoryForgettingEngine | None = None # LLM模型 self.value_assessment_model: LLMRequest = None @@ -143,10 +143,10 @@ class MemorySystem: self.last_retrieval_time = None # 构建节流记录 - self._last_memory_build_times: Dict[str, float] = {} + self._last_memory_build_times: dict[str, float] = {} # 记忆指纹缓存,用于快速检测重复记忆 - self._memory_fingerprints: Dict[str, str] = {} + self._memory_fingerprints: dict[str, str] = {} logger.info("MemorySystem 初始化开始") @@ -210,7 +210,7 @@ class MemorySystem: raise # 初始化遗忘引擎 - from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig + from src.chat.memory_system.memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine # 从全局配置创建遗忘引擎配置 forgetting_config = ForgettingConfig( @@ -241,7 +241,7 @@ class MemorySystem: self.forgetting_engine = MemoryForgettingEngine(forgetting_config) planner_task_config = getattr(model_config.model_task_config, "utils_small", None) - planner_model: Optional[LLMRequest] = None + planner_model: LLMRequest | None = None try: planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner") except Exception as planner_exc: @@ -261,8 +261,8 @@ class MemorySystem: raise async def retrieve_memories_for_building( - self, query_text: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5 - ) -> List[MemoryChunk]: + self, query_text: str, user_id: str | None = None, context: dict[str, Any] | None = None, limit: int = 5 + ) -> list[MemoryChunk]: """在构建记忆时检索相关记忆,使用统一存储系统 Args: @@ -302,8 +302,8 @@ class MemorySystem: return [] async def build_memory_from_conversation( - self, conversation_text: str, context: Dict[str, Any], timestamp: Optional[float] = None - ) -> List[MemoryChunk]: + self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None + ) -> list[MemoryChunk]: """从对话中构建记忆 Args: @@ -318,8 +318,8 @@ class MemorySystem: self.status = MemorySystemStatus.BUILDING start_time = time.time() - build_scope_key: Optional[str] = None - build_marker_time: Optional[float] = None + build_scope_key: str | None = None + build_marker_time: float | None = None try: normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp) @@ -408,7 +408,7 @@ class MemorySystem: logger.error(f"❌ 记忆构建失败: {e}", exc_info=True) raise - def _log_memory_preview(self, memories: List[MemoryChunk]) -> None: + def _log_memory_preview(self, memories: list[MemoryChunk]) -> None: """在控制台输出记忆预览,便于人工检查""" if not memories: logger.info("📝 本次未生成新的记忆") @@ -425,12 +425,12 @@ class MemorySystem: f"置信度={memory.metadata.confidence.name} | 内容={text}" ) - async def _collect_fusion_candidates(self, new_memories: List[MemoryChunk]) -> List[MemoryChunk]: + async def _collect_fusion_candidates(self, new_memories: list[MemoryChunk]) -> list[MemoryChunk]: """收集与新记忆相似的现有记忆,便于融合去重""" if not new_memories: return [] - candidate_ids: Set[str] = set() + candidate_ids: set[str] = set() new_memory_ids = {memory.memory_id for memory in new_memories if memory and getattr(memory, "memory_id", None)} # 基于指纹的直接匹配 @@ -493,7 +493,7 @@ class MemorySystem: continue candidate_ids.add(memory_id) - existing_candidates: List[MemoryChunk] = [] + existing_candidates: list[MemoryChunk] = [] cache = self.unified_storage.memory_cache if self.unified_storage else {} for candidate_id in candidate_ids: if candidate_id in new_memory_ids: @@ -511,7 +511,7 @@ class MemorySystem: return existing_candidates - async def process_conversation_memory(self, context: Dict[str, Any]) -> Dict[str, Any]: + async def process_conversation_memory(self, context: dict[str, Any]) -> dict[str, Any]: """对外暴露的对话记忆处理接口,仅依赖上下文信息""" start_time = time.time() @@ -559,12 +559,12 @@ class MemorySystem: async def retrieve_relevant_memories( self, - query_text: Optional[str] = None, - user_id: Optional[str] = None, - context: Optional[Dict[str, Any]] = None, + query_text: str | None = None, + user_id: str | None = None, + context: dict[str, Any] | None = None, limit: int = 5, **kwargs, - ) -> List[MemoryChunk]: + ) -> list[MemoryChunk]: """检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)""" raw_query = query_text or kwargs.get("query") if not raw_query: @@ -750,7 +750,7 @@ class MemorySystem: raise @staticmethod - def _extract_json_payload(response: str) -> Optional[str]: + def _extract_json_payload(response: str) -> str | None: """从模型响应中提取JSON部分,兼容Markdown代码块等格式""" if not response: return None @@ -773,10 +773,10 @@ class MemorySystem: return stripped if stripped.startswith("{") and stripped.endswith("}") else None def _normalize_context( - self, raw_context: Optional[Dict[str, Any]], user_id: Optional[str], timestamp: Optional[float] - ) -> Dict[str, Any]: + self, raw_context: dict[str, Any] | None, user_id: str | None, timestamp: float | None + ) -> dict[str, Any]: """标准化上下文,确保必备字段存在且格式正确""" - context: Dict[str, Any] = {} + context: dict[str, Any] = {} if raw_context: try: context = dict(raw_context) @@ -822,7 +822,7 @@ class MemorySystem: return context - async def _build_enhanced_query_context(self, raw_query: str, normalized_context: Dict[str, Any]) -> Dict[str, Any]: + async def _build_enhanced_query_context(self, raw_query: str, normalized_context: dict[str, Any]) -> dict[str, Any]: """构建包含未读消息综合上下文的增强查询上下文 Args: @@ -861,7 +861,7 @@ class MemorySystem: return enhanced_context - async def _collect_unread_messages_context(self, stream_id: str) -> Optional[Dict[str, Any]]: + async def _collect_unread_messages_context(self, stream_id: str) -> dict[str, Any] | None: """收集未读消息的综合上下文信息 Args: @@ -953,7 +953,7 @@ class MemorySystem: logger.warning(f"收集未读消息上下文失败: {e}", exc_info=True) return None - def _build_unread_context_summary(self, messages_summary: List[Dict[str, Any]]) -> str: + def _build_unread_context_summary(self, messages_summary: list[dict[str, Any]]) -> str: """构建未读消息的文本摘要 Args: @@ -974,7 +974,7 @@ class MemorySystem: return " | ".join(summary_parts) - async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str: + async def _resolve_conversation_context(self, fallback_text: str, context: dict[str, Any] | None) -> str: """使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本""" if not context: return fallback_text @@ -1043,11 +1043,11 @@ class MemorySystem: # 回退到传入文本 return fallback_text - def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]: + def _get_build_scope_key(self, context: dict[str, Any], user_id: str | None) -> str | None: """确定用于节流控制的记忆构建作用域""" return "global_scope" - def _determine_history_limit(self, context: Dict[str, Any]) -> int: + def _determine_history_limit(self, context: dict[str, Any]) -> int: """确定历史消息获取数量,限制在30-50之间""" default_limit = 40 candidate = context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit") @@ -1065,12 +1065,12 @@ class MemorySystem: return history_limit - def _format_history_messages(self, messages: List["DatabaseMessages"]) -> Optional[str]: + def _format_history_messages(self, messages: list["DatabaseMessages"]) -> str | None: """将历史消息格式化为可供LLM处理的多轮对话文本""" if not messages: return None - lines: List[str] = [] + lines: list[str] = [] for msg in messages: try: content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None) @@ -1105,7 +1105,7 @@ class MemorySystem: return "\n".join(lines) if lines else None - async def _assess_information_value(self, text: str, context: Dict[str, Any]) -> float: + async def _assess_information_value(self, text: str, context: dict[str, Any]) -> float: """评估信息价值 Args: @@ -1201,7 +1201,7 @@ class MemorySystem: logger.error(f"信息价值评估失败: {e}", exc_info=True) return 0.5 # 默认中等价值 - async def _store_memories_unified(self, memory_chunks: List[MemoryChunk]) -> int: + async def _store_memories_unified(self, memory_chunks: list[MemoryChunk]) -> int: """使用统一存储系统存储记忆块""" if not memory_chunks or not self.unified_storage: return 0 @@ -1222,7 +1222,7 @@ class MemorySystem: return 0 # 保留原有方法以兼容旧代码 - async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int: + async def _store_memories(self, memory_chunks: list[MemoryChunk]) -> int: """兼容性方法:重定向到统一存储""" return await self._store_memories_unified(memory_chunks) @@ -1271,7 +1271,7 @@ class MemorySystem: key = self._fingerprint_key(memory.user_id, fingerprint) self._memory_fingerprints[key] = memory.memory_id - def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None: + def _register_memory_fingerprints(self, memories: list[MemoryChunk]) -> None: for memory in memories: fingerprint = self._build_memory_fingerprint(memory) key = self._fingerprint_key(memory.user_id, fingerprint) @@ -1302,9 +1302,9 @@ class MemorySystem: @staticmethod def _fingerprint_key(user_id: str, fingerprint: str) -> str: - return f"{str(user_id)}:{fingerprint}" + return f"{user_id!s}:{fingerprint}" - def get_system_stats(self) -> Dict[str, Any]: + def get_system_stats(self) -> dict[str, Any]: """获取系统统计信息""" return { "status": self.status.value, @@ -1314,7 +1314,7 @@ class MemorySystem: "config": asdict(self.config), } - def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float: + def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: dict[str, Any]) -> float: """根据查询和上下文为记忆计算匹配分数""" tokens_query = self._tokenize_text(query_text) tokens_memory = self._tokenize_text(memory.text_content) @@ -1338,7 +1338,7 @@ class MemorySystem: final_score = base_score * 0.7 + keyword_overlap * 0.15 + importance_boost + confidence_boost return max(0.0, min(1.0, final_score)) - def _tokenize_text(self, text: str) -> Set[str]: + def _tokenize_text(self, text: str) -> set[str]: """简单分词,兼容中英文""" if not text: return set() @@ -1450,7 +1450,7 @@ def get_memory_system() -> MemorySystem: return memory_system -async def initialize_memory_system(llm_model: Optional[LLMRequest] = None): +async def initialize_memory_system(llm_model: LLMRequest | None = None): """初始化全局记忆系统""" global memory_system if memory_system is None: diff --git a/src/chat/memory_system/vector_memory_storage_v2.py b/src/chat/memory_system/vector_memory_storage_v2.py index 3c924ba30..7fcae93c8 100644 --- a/src/chat/memory_system/vector_memory_storage_v2.py +++ b/src/chat/memory_system/vector_memory_storage_v2.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 基于Vector DB的统一记忆存储系统 V2 使用ChromaDB作为底层存储,替代JSON存储方式 @@ -11,20 +10,21 @@ - 自动清理过期记忆 """ -import time -import orjson import asyncio import threading -from typing import Dict, List, Optional, Tuple, Any +import time from dataclasses import dataclass from datetime import datetime +from typing import Any -from src.common.logger import get_logger -from src.common.vector_db import vector_db_service -from src.chat.utils.utils import get_embedding -from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel +import orjson + +from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndex, MemoryMetadataIndexEntry +from src.chat.utils.utils import get_embedding +from src.common.logger import get_logger +from src.common.vector_db import vector_db_service logger = get_logger(__name__) @@ -32,7 +32,7 @@ logger = get_logger(__name__) _ENUM_MAPPINGS_CACHE = {} -def _build_enum_mapping(enum_class: type) -> Dict[str, Any]: +def _build_enum_mapping(enum_class: type) -> dict[str, Any]: """构建枚举类的完整映射表 Args: @@ -145,7 +145,7 @@ class VectorMemoryStorage: """基于Vector DB的记忆存储系统""" - def __init__(self, config: Optional[VectorStorageConfig] = None): + def __init__(self, config: VectorStorageConfig | None = None): # 默认从全局配置读取,如果没有传入config if config is None: try: @@ -163,15 +163,15 @@ class VectorMemoryStorage: self.vector_db_service = vector_db_service # 内存缓存 - self.memory_cache: Dict[str, MemoryChunk] = {} - self.cache_timestamps: Dict[str, float] = {} + self.memory_cache: dict[str, MemoryChunk] = {} + self.cache_timestamps: dict[str, float] = {} self._cache = self.memory_cache # 别名,兼容旧代码 # 元数据索引管理器(JSON文件索引) self.metadata_index = MemoryMetadataIndex() # 遗忘引擎 - self.forgetting_engine: Optional[MemoryForgettingEngine] = None + self.forgetting_engine: MemoryForgettingEngine | None = None if self.config.enable_forgetting: self.forgetting_engine = MemoryForgettingEngine() @@ -267,7 +267,7 @@ class VectorMemoryStorage: except Exception as e: logger.error(f"自动清理失败: {e}") - def _memory_to_vector_format(self, memory: MemoryChunk) -> Dict[str, Any]: + def _memory_to_vector_format(self, memory: MemoryChunk) -> dict[str, Any]: """将MemoryChunk转换为向量存储格式""" try: # 获取memory_id @@ -323,7 +323,7 @@ class VectorMemoryStorage: logger.error(f"转换记忆 {memory_id} 到向量格式失败: {e}", exc_info=True) raise - def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]: + def _vector_result_to_memory(self, document: str, metadata: dict[str, Any]) -> MemoryChunk | None: """将Vector DB结果转换为MemoryChunk""" try: # 从元数据中恢复完整记忆 @@ -440,7 +440,7 @@ class VectorMemoryStorage: logger.warning(f"不支持的{enum_class.__name__}值类型: {type(value)},使用默认值") return default - def _get_from_cache(self, memory_id: str) -> Optional[MemoryChunk]: + def _get_from_cache(self, memory_id: str) -> MemoryChunk | None: """从缓存获取记忆""" if not self.config.enable_caching: return None @@ -472,7 +472,7 @@ class VectorMemoryStorage: self.memory_cache[memory_id] = memory self.cache_timestamps[memory_id] = time.time() - async def store_memories(self, memories: List[MemoryChunk]) -> int: + async def store_memories(self, memories: list[MemoryChunk]) -> int: """批量存储记忆""" if not memories: return 0 @@ -603,11 +603,11 @@ class VectorMemoryStorage: self, query_text: str, limit: int = 10, - similarity_threshold: Optional[float] = None, - filters: Optional[Dict[str, Any]] = None, + similarity_threshold: float | None = None, + filters: dict[str, Any] | None = None, # 新增:元数据过滤参数(用于JSON索引粗筛) - metadata_filters: Optional[Dict[str, Any]] = None, - ) -> List[Tuple[MemoryChunk, float]]: + metadata_filters: dict[str, Any] | None = None, + ) -> list[tuple[MemoryChunk, float]]: """ 搜索相似记忆(混合索引模式) @@ -632,7 +632,7 @@ class VectorMemoryStorage: try: # === 阶段一:JSON元数据粗筛(可选) === - candidate_ids: Optional[List[str]] = None + candidate_ids: list[str] | None = None if metadata_filters: logger.debug(f"[JSON元数据粗筛] 开始,过滤条件: {metadata_filters}") candidate_ids = self.metadata_index.search( @@ -746,7 +746,7 @@ class VectorMemoryStorage: logger.error(f"搜索相似记忆失败: {e}") return [] - async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]: + async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None: """根据ID获取记忆""" # 首先尝试从缓存获取 memory = self._get_from_cache(memory_id) @@ -772,7 +772,7 @@ class VectorMemoryStorage: return None - async def get_memories_by_filters(self, filters: Dict[str, Any], limit: int = 100) -> List[MemoryChunk]: + async def get_memories_by_filters(self, filters: dict[str, Any], limit: int = 100) -> list[MemoryChunk]: """根据过滤条件获取记忆""" try: results = vector_db_service.get(collection_name=self.config.memory_collection, where=filters, limit=limit) @@ -848,7 +848,7 @@ class VectorMemoryStorage: logger.error(f"删除记忆 {memory_id} 失败: {e}") return False - async def delete_memories_by_filters(self, filters: Dict[str, Any]) -> int: + async def delete_memories_by_filters(self, filters: dict[str, Any]) -> int: """根据过滤条件批量删除记忆""" try: # 先获取要删除的记忆ID @@ -880,7 +880,7 @@ class VectorMemoryStorage: logger.error(f"批量删除记忆失败: {e}") return 0 - async def perform_forgetting_check(self) -> Dict[str, Any]: + async def perform_forgetting_check(self) -> dict[str, Any]: """执行遗忘检查""" if not self.forgetting_engine: return {"error": "遗忘引擎未启用"} @@ -925,7 +925,7 @@ class VectorMemoryStorage: logger.error(f"执行遗忘检查失败: {e}") return {"error": str(e)} - def get_storage_stats(self) -> Dict[str, Any]: + def get_storage_stats(self) -> dict[str, Any]: """获取存储统计信息""" try: current_total = vector_db_service.count(self.config.memory_collection) @@ -960,7 +960,7 @@ class VectorMemoryStorage: _global_vector_storage = None -def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> VectorMemoryStorage: +def get_vector_memory_storage(config: VectorStorageConfig | None = None) -> VectorMemoryStorage: """获取全局Vector记忆存储实例""" global _global_vector_storage @@ -974,15 +974,15 @@ def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> V class VectorMemoryStorageAdapter: """适配器类,提供与原UnifiedMemoryStorage兼容的接口""" - def __init__(self, config: Optional[VectorStorageConfig] = None): + def __init__(self, config: VectorStorageConfig | None = None): self.storage = VectorMemoryStorage(config) - async def store_memories(self, memories: List[MemoryChunk]) -> int: + async def store_memories(self, memories: list[MemoryChunk]) -> int: return await self.storage.store_memories(memories) async def search_similar_memories( - self, query_text: str, limit: int = 10, scope_id: Optional[str] = None, filters: Optional[Dict[str, Any]] = None - ) -> List[Tuple[str, float]]: + self, query_text: str, limit: int = 10, scope_id: str | None = None, filters: dict[str, Any] | None = None + ) -> list[tuple[str, float]]: results = await self.storage.search_similar_memories(query_text, limit, filters=filters) # 转换为原格式:(memory_id, similarity) return [ @@ -990,7 +990,7 @@ class VectorMemoryStorageAdapter: for memory, similarity in results ] - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: return self.storage.get_storage_stats() diff --git a/src/chat/message_manager/__init__.py b/src/chat/message_manager/__init__.py index fe5e90785..c8bd18a08 100644 --- a/src/chat/message_manager/__init__.py +++ b/src/chat/message_manager/__init__.py @@ -3,14 +3,14 @@ 提供统一的消息管理、上下文管理和流循环调度功能 """ -from .message_manager import MessageManager, message_manager from .context_manager import SingleStreamContextManager from .distribution_manager import StreamLoopManager, stream_loop_manager +from .message_manager import MessageManager, message_manager __all__ = [ "MessageManager", - "message_manager", "SingleStreamContextManager", "StreamLoopManager", + "message_manager", "stream_loop_manager", ] diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index 5f3212065..ceefa99b2 100644 --- a/src/chat/message_manager/context_manager.py +++ b/src/chat/message_manager/context_manager.py @@ -6,13 +6,14 @@ import asyncio import time -from typing import Dict, List, Optional, Any +from typing import Any +from src.chat.energy_system import energy_manager +from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.message_manager_data_model import StreamContext from src.common.logger import get_logger from src.config.config import global_config -from src.common.data_models.database_data_model import DatabaseMessages -from src.chat.energy_system import energy_manager + from .distribution_manager import stream_loop_manager logger = get_logger("context_manager") @@ -21,7 +22,7 @@ logger = get_logger("context_manager") class SingleStreamContextManager: """单流上下文管理器 - 每个实例只管理一个 stream 的上下文""" - def __init__(self, stream_id: str, context: StreamContext, max_context_size: Optional[int] = None): + def __init__(self, stream_id: str, context: StreamContext, max_context_size: int | None = None): self.stream_id = stream_id self.context = context @@ -66,7 +67,7 @@ class SingleStreamContextManager: logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True) return False - async def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool: + async def update_message(self, message_id: str, updates: dict[str, Any]) -> bool: """更新上下文中的消息 Args: @@ -84,7 +85,7 @@ class SingleStreamContextManager: logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True) return False - def get_messages(self, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]: + def get_messages(self, limit: int | None = None, include_unread: bool = True) -> list[DatabaseMessages]: """获取上下文消息 Args: @@ -117,7 +118,7 @@ class SingleStreamContextManager: logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True) return [] - def get_unread_messages(self) -> List[DatabaseMessages]: + def get_unread_messages(self) -> list[DatabaseMessages]: """获取未读消息""" try: return self.context.get_unread_messages() @@ -125,7 +126,7 @@ class SingleStreamContextManager: logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True) return [] - def mark_messages_as_read(self, message_ids: List[str]) -> bool: + def mark_messages_as_read(self, message_ids: list[str]) -> bool: """标记消息为已读""" try: if not hasattr(self.context, "mark_message_as_read"): @@ -168,7 +169,7 @@ class SingleStreamContextManager: logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True) return False - def get_statistics(self) -> Dict[str, Any]: + def get_statistics(self) -> dict[str, Any]: """获取流统计信息""" try: current_time = time.time() @@ -285,7 +286,7 @@ class SingleStreamContextManager: logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True) return False - async def update_message_async(self, message_id: str, updates: Dict[str, Any]) -> bool: + async def update_message_async(self, message_id: str, updates: dict[str, Any]) -> bool: """异步实现的 update_message:更新消息并在需要时 await 能量更新。""" try: self.context.update_message_info(message_id, **updates) @@ -327,7 +328,7 @@ class SingleStreamContextManager: """更新流能量""" try: history_messages = self.context.get_history_messages(limit=self.max_context_size) - messages: List[DatabaseMessages] = list(history_messages) + messages: list[DatabaseMessages] = list(history_messages) if include_unread: messages.extend(self.get_unread_messages()) diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index 69f3e662d..152c40362 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -5,12 +5,12 @@ import asyncio import time -from typing import Dict, Optional, Any +from typing import Any +from src.chat.chatter_manager import ChatterManager +from src.chat.energy_system import energy_manager from src.common.logger import get_logger from src.config.config import global_config -from src.chat.energy_system import energy_manager -from src.chat.chatter_manager import ChatterManager from src.plugin_system.apis.chat_api import get_chat_manager logger = get_logger("stream_loop_manager") @@ -19,13 +19,13 @@ logger = get_logger("stream_loop_manager") class StreamLoopManager: """流循环管理器 - 每个流一个独立的无限循环任务""" - def __init__(self, max_concurrent_streams: Optional[int] = None): + def __init__(self, max_concurrent_streams: int | None = None): # 流循环任务管理 - self.stream_loops: Dict[str, asyncio.Task] = {} + self.stream_loops: dict[str, asyncio.Task] = {} self.loop_lock = asyncio.Lock() # 统计信息 - self.stats: Dict[str, Any] = { + self.stats: dict[str, Any] = { "active_streams": 0, "total_loops": 0, "total_process_cycles": 0, @@ -37,13 +37,13 @@ class StreamLoopManager: self.max_concurrent_streams = max_concurrent_streams or global_config.chat.max_concurrent_distributions # 强制分发策略 - self.force_dispatch_unread_threshold: Optional[int] = getattr( + self.force_dispatch_unread_threshold: int | None = getattr( global_config.chat, "force_dispatch_unread_threshold", 20 ) self.force_dispatch_min_interval: float = getattr(global_config.chat, "force_dispatch_min_interval", 0.1) # Chatter管理器 - self.chatter_manager: Optional[ChatterManager] = None + self.chatter_manager: ChatterManager | None = None # 状态控制 self.is_running = False @@ -212,7 +212,7 @@ class StreamLoopManager: logger.info(f"流循环结束: {stream_id}") - async def _get_stream_context(self, stream_id: str) -> Optional[Any]: + async def _get_stream_context(self, stream_id: str) -> Any | None: """获取流上下文 Args: @@ -320,7 +320,7 @@ class StreamLoopManager: logger.debug(f"流 {stream_id} 使用默认间隔: {base_interval:.2f}s ({e})") return base_interval - def get_queue_status(self) -> Dict[str, Any]: + def get_queue_status(self) -> dict[str, Any]: """获取队列状态 Returns: @@ -374,14 +374,14 @@ class StreamLoopManager: except Exception: return 0 - def _needs_force_dispatch_for_context(self, context: Any, unread_count: Optional[int] = None) -> bool: + def _needs_force_dispatch_for_context(self, context: Any, unread_count: int | None = None) -> bool: if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0: return False count = unread_count if unread_count is not None else self._get_unread_count(context) return count > self.force_dispatch_unread_threshold - def get_performance_summary(self) -> Dict[str, Any]: + def get_performance_summary(self) -> dict[str, Any]: """获取性能摘要 Returns: diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index bd55bd43f..78e3363ff 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -6,19 +6,20 @@ import asyncio import random import time -from typing import Dict, Optional, Any, TYPE_CHECKING, List +from typing import TYPE_CHECKING, Any +from src.chat.chatter_manager import ChatterManager from src.chat.message_receive.chat_stream import ChatStream -from src.common.logger import get_logger +from src.chat.planner_actions.action_manager import ChatterActionManager from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats -from src.chat.chatter_manager import ChatterManager -from src.chat.planner_actions.action_manager import ChatterActionManager -from .sleep_manager.sleep_manager import SleepManager -from .sleep_manager.wakeup_manager import WakeUpManager +from src.common.logger import get_logger from src.config.config import global_config from src.plugin_system.apis.chat_api import get_chat_manager + from .distribution_manager import stream_loop_manager +from .sleep_manager.sleep_manager import SleepManager +from .sleep_manager.wakeup_manager import WakeUpManager if TYPE_CHECKING: pass @@ -32,7 +33,7 @@ class MessageManager: def __init__(self, check_interval: float = 5.0): self.check_interval = check_interval # 检查间隔(秒) self.is_running = False - self.manager_task: Optional[asyncio.Task] = None + self.manager_task: asyncio.Task | None = None # 统计信息 self.stats = MessageManagerStats() @@ -125,7 +126,7 @@ class MessageManager: except Exception as e: logger.error(f"更新消息 {message_id} 时发生错误: {e}") - async def bulk_update_messages(self, stream_id: str, updates: List[Dict[str, Any]]) -> int: + async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int: """批量更新消息信息,降低更新频率""" if not updates: return 0 @@ -214,7 +215,7 @@ class MessageManager: except Exception as e: logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}") - def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]: + def get_stream_stats(self, stream_id: str) -> StreamStats | None: """获取聊天流统计""" try: # 通过 ChatManager 获取 ChatStream @@ -243,7 +244,7 @@ class MessageManager: logger.error(f"获取聊天流 {stream_id} 统计时发生错误: {e}") return None - def get_manager_stats(self) -> Dict[str, Any]: + def get_manager_stats(self) -> dict[str, Any]: """获取管理器统计""" return { "total_streams": self.stats.total_streams, @@ -278,7 +279,7 @@ class MessageManager: except Exception as e: logger.error(f"清理不活跃聊天流时发生错误: {e}") - async def _check_and_handle_interruption(self, chat_stream: Optional[ChatStream] = None): + async def _check_and_handle_interruption(self, chat_stream: ChatStream | None = None): """检查并处理消息打断""" if not global_config.chat.interruption_enabled: return diff --git a/src/chat/message_manager/sleep_manager/sleep_manager.py b/src/chat/message_manager/sleep_manager/sleep_manager.py index b0cf79b1b..6aeab8037 100644 --- a/src/chat/message_manager/sleep_manager/sleep_manager.py +++ b/src/chat/message_manager/sleep_manager/sleep_manager.py @@ -1,12 +1,13 @@ import asyncio import random from datetime import datetime, timedelta -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from src.common.logger import get_logger from src.config.config import global_config + from .notification_sender import NotificationSender -from .sleep_state import SleepState, SleepContext +from .sleep_state import SleepContext, SleepState from .time_checker import TimeChecker if TYPE_CHECKING: @@ -92,7 +93,7 @@ class SleepManager: elif current_state == SleepState.WOKEN_UP: self._handle_woken_up(now, is_in_theoretical_sleep, wakeup_manager) - def _handle_awake_to_sleep(self, now: datetime, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]): + def _handle_awake_to_sleep(self, now: datetime, activity: str | None, wakeup_manager: Optional["WakeUpManager"]): """处理从“清醒”到“准备入睡”的状态转换。""" if activity: logger.info(f"进入理论休眠时间 '{activity}',开始进行睡眠决策...") @@ -181,7 +182,7 @@ class SleepManager: self, now: datetime, is_in_theoretical_sleep: bool, - activity: Optional[str], + activity: str | None, wakeup_manager: Optional["WakeUpManager"], ): """处理“正在睡觉”状态下的逻辑。""" diff --git a/src/chat/message_manager/sleep_manager/sleep_state.py b/src/chat/message_manager/sleep_manager/sleep_state.py index 105302169..21a9f11bb 100644 --- a/src/chat/message_manager/sleep_manager/sleep_state.py +++ b/src/chat/message_manager/sleep_manager/sleep_state.py @@ -1,6 +1,5 @@ +from datetime import date, datetime from enum import Enum, auto -from datetime import datetime, date -from typing import Optional from src.common.logger import get_logger from src.manager.local_store_manager import local_storage @@ -29,10 +28,10 @@ class SleepContext: def __init__(self): """初始化睡眠上下文,并从本地存储加载初始状态。""" self.current_state: SleepState = SleepState.AWAKE - self.sleep_buffer_end_time: Optional[datetime] = None + self.sleep_buffer_end_time: datetime | None = None self.total_delayed_minutes_today: float = 0.0 - self.last_sleep_check_date: Optional[date] = None - self.re_sleep_attempt_time: Optional[datetime] = None + self.last_sleep_check_date: date | None = None + self.re_sleep_attempt_time: datetime | None = None self.load() def save(self): diff --git a/src/chat/message_manager/sleep_manager/time_checker.py b/src/chat/message_manager/sleep_manager/time_checker.py index 773830c3a..0ea099039 100644 --- a/src/chat/message_manager/sleep_manager/time_checker.py +++ b/src/chat/message_manager/sleep_manager/time_checker.py @@ -1,6 +1,6 @@ -from datetime import datetime, time, timedelta -from typing import Optional, List, Dict, Any import random +from datetime import datetime, time, timedelta +from typing import Any from src.common.logger import get_logger from src.config.config import global_config @@ -37,11 +37,11 @@ class TimeChecker: return self._daily_sleep_offset, self._daily_wake_offset @staticmethod - def get_today_schedule() -> Optional[List[Dict[str, Any]]]: + def get_today_schedule() -> list[dict[str, Any]] | None: """从全局 ScheduleManager 获取今天的日程安排。""" return schedule_manager.today_schedule - def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]: + def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, str | None]: if global_config.sleep_system.sleep_by_schedule: if self.get_today_schedule(): return self._is_in_schedule_sleep_time(now_time) @@ -50,7 +50,7 @@ class TimeChecker: else: return self._is_in_sleep_time(now_time) - def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]: + def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, str | None]: """检查当前时间是否落在日程表的任何一个睡眠活动中""" sleep_keywords = ["休眠", "睡觉", "梦乡"] today_schedule = self.get_today_schedule() @@ -79,7 +79,7 @@ class TimeChecker: continue return False, None - def _is_in_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]: + def _is_in_sleep_time(self, now_time: time) -> tuple[bool, str | None]: """检查当前时间是否在固定的睡眠时间内(应用偏移量)""" try: start_time_str = global_config.sleep_system.fixed_sleep_time diff --git a/src/chat/message_manager/sleep_manager/wakeup_manager.py b/src/chat/message_manager/sleep_manager/wakeup_manager.py index 5fc68ff41..d390d9d3d 100644 --- a/src/chat/message_manager/sleep_manager/wakeup_manager.py +++ b/src/chat/message_manager/sleep_manager/wakeup_manager.py @@ -1,9 +1,10 @@ import asyncio import time -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING + +from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext from src.common.logger import get_logger from src.config.config import global_config -from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext if TYPE_CHECKING: from .sleep_manager import SleepManager @@ -27,9 +28,9 @@ class WakeUpManager: """ self.sleep_manager = sleep_manager self.context = WakeUpContext() # 使用新的上下文管理器 - self.angry_chat_id: Optional[str] = None + self.angry_chat_id: str | None = None self.last_decay_time = time.time() - self._decay_task: Optional[asyncio.Task] = None + self._decay_task: asyncio.Task | None = None self.is_running = False self.last_log_time = 0 self.log_interval = 30 @@ -104,9 +105,7 @@ class WakeUpManager: logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}") self.context.save() - def add_wakeup_value( - self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None - ) -> bool: + def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: str | None = None) -> bool: """ 增加唤醒度值 diff --git a/src/chat/message_receive/__init__.py b/src/chat/message_receive/__init__.py index 44b9eee36..32a3fe9f5 100644 --- a/src/chat/message_receive/__init__.py +++ b/src/chat/message_receive/__init__.py @@ -2,9 +2,8 @@ from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.storage import MessageStorage - __all__ = [ - "get_emoji_manager", - "get_chat_manager", "MessageStorage", + "get_chat_manager", + "get_emoji_manager", ] diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 47d1f26e2..2007d01ec 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -1,25 +1,24 @@ -import traceback import os import re +import traceback +from typing import Any -from typing import Dict, Any, Optional from maim_message import UserInfo -from src.common.logger import get_logger -from src.config.config import global_config -from src.mood.mood_manager import mood_manager # 导入情绪管理器 -from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream -from src.chat.message_receive.message import MessageRecv, MessageRecvS4U -from src.chat.message_receive.storage import MessageStorage -from src.chat.message_manager import message_manager -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.plugin_system.core import component_registry, event_manager, global_announcement_manager -from src.plugin_system.base import BaseCommand, EventType -from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor -from src.chat.utils.utils import is_mentioned_bot_in_message - # 导入反注入系统 from src.chat.antipromptinjector import initialize_anti_injector +from src.chat.message_manager import message_manager +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.message import MessageRecv, MessageRecvS4U +from src.chat.message_receive.storage import MessageStorage +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.chat.utils.utils import is_mentioned_bot_in_message +from src.common.logger import get_logger +from src.config.config import global_config +from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor +from src.mood.mood_manager import mood_manager # 导入情绪管理器 +from src.plugin_system.base import BaseCommand, EventType +from src.plugin_system.core import component_registry, event_manager, global_announcement_manager # 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录) PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) @@ -219,7 +218,7 @@ class ChatBot: logger.error(traceback.format_exc()) try: - await plus_command_instance.send_text(f"命令执行出错: {str(e)}") + await plus_command_instance.send_text(f"命令执行出错: {e!s}") except Exception as send_error: logger.error(f"发送错误消息失败: {send_error}") @@ -286,7 +285,7 @@ class ChatBot: logger.error(traceback.format_exc()) try: - await command_instance.send_text(f"命令执行出错: {str(e)}") + await command_instance.send_text(f"命令执行出错: {e!s}") except Exception as send_error: logger.error(f"发送错误消息失败: {send_error}") @@ -338,7 +337,7 @@ class ChatBot: except Exception as e: logger.error(f"处理适配器响应时出错: {e}") - async def do_s4u(self, message_data: Dict[str, Any]): + async def do_s4u(self, message_data: dict[str, Any]): message = MessageRecvS4U(message_data) group_info = message.message_info.group_info user_info = message.message_info.user_info @@ -359,7 +358,7 @@ class ChatBot: return - async def message_process(self, message_data: Dict[str, Any]) -> None: + async def message_process(self, message_data: dict[str, Any]) -> None: """处理转化后的统一格式消息""" try: # 首先处理可能的切片消息重组 @@ -458,7 +457,7 @@ class ChatBot: # TODO:暂不可用 # 确认从接口发来的message是否有自定义的prompt模板信息 if message.message_info.template_info and not message.message_info.template_info.template_default: - template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore + template_group_name: str | None = message.message_info.template_info.template_name # type: ignore template_items = message.message_info.template_info.template_items async with global_prompt_manager.async_message_scope(template_group_name): if isinstance(template_items, dict): diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 559490694..40833b285 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -1,17 +1,18 @@ import asyncio +import copy import hashlib import time -import copy -from typing import Dict, Optional, TYPE_CHECKING -from rich.traceback import install -from maim_message import GroupInfo, UserInfo +from typing import TYPE_CHECKING -from src.common.logger import get_logger +from maim_message import GroupInfo, UserInfo +from rich.traceback import install from sqlalchemy import select -from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.dialects.mysql import insert as mysql_insert -from src.common.database.sqlalchemy_models import ChatStreams # 新增导入 +from sqlalchemy.dialects.sqlite import insert as sqlite_insert + from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import ChatStreams # 新增导入 +from src.common.logger import get_logger from src.config.config import global_config # 新增导入 # 避免循环导入,使用TYPE_CHECKING进行类型提示 @@ -33,8 +34,8 @@ class ChatStream: stream_id: str, platform: str, user_info: UserInfo, - group_info: Optional[GroupInfo] = None, - data: Optional[dict] = None, + group_info: GroupInfo | None = None, + data: dict | None = None, ): self.stream_id = stream_id self.platform = platform @@ -47,7 +48,7 @@ class ChatStream: # 使用StreamContext替代ChatMessageContext from src.common.data_models.message_manager_data_model import StreamContext - from src.plugin_system.base.component_types import ChatType, ChatMode + from src.plugin_system.base.component_types import ChatMode, ChatType # 创建StreamContext self.stream_context: StreamContext = StreamContext( @@ -133,11 +134,11 @@ class ChatStream: # 恢复stream_context信息 if "stream_context_chat_type" in data: - from src.plugin_system.base.component_types import ChatType, ChatMode + from src.plugin_system.base.component_types import ChatMode, ChatType instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"]) if "stream_context_chat_mode" in data: - from src.plugin_system.base.component_types import ChatType, ChatMode + from src.plugin_system.base.component_types import ChatMode, ChatType instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"]) @@ -163,9 +164,10 @@ class ChatStream: def set_context(self, message: "MessageRecv"): """设置聊天消息上下文""" # 将MessageRecv转换为DatabaseMessages并设置到stream_context - from src.common.data_models.database_data_model import DatabaseMessages import json + from src.common.data_models.database_data_model import DatabaseMessages + # 安全获取message_info中的数据 message_info = getattr(message, "message_info", {}) user_info = getattr(message_info, "user_info", {}) @@ -248,7 +250,7 @@ class ChatStream: f"interest_value: {db_message.interest_value}" ) - def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]: + def _safe_get_actions(self, message: "MessageRecv") -> list | None: """安全获取消息的actions字段""" try: actions = getattr(message, "actions", None) @@ -278,7 +280,7 @@ class ChatStream: logger.warning(f"获取actions字段失败: {e}") return None - def _extract_reply_from_segment(self, segment) -> Optional[str]: + def _extract_reply_from_segment(self, segment) -> str | None: """从消息段中提取reply_to信息""" try: if hasattr(segment, "type") and segment.type == "seglist": @@ -391,8 +393,8 @@ class ChatManager: def __init__(self): if not self._initialized: - self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream - self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message + self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream + self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message # try: # async with get_db_session() as session: # db.connect(reuse_if_open=True) @@ -414,7 +416,7 @@ class ChatManager: await self.load_all_streams() logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流") except Exception as e: - logger.error(f"聊天管理器启动失败: {str(e)}") + logger.error(f"聊天管理器启动失败: {e!s}") async def _auto_save_task(self): """定期自动保存所有聊天流""" @@ -424,7 +426,7 @@ class ChatManager: await self._save_all_streams() logger.info("聊天流自动保存完成") except Exception as e: - logger.error(f"聊天流自动保存失败: {str(e)}") + logger.error(f"聊天流自动保存失败: {e!s}") def register_message(self, message: "MessageRecv"): """注册消息到聊天流""" @@ -437,9 +439,7 @@ class ChatManager: # logger.debug(f"注册消息到聊天流: {stream_id}") @staticmethod - def _generate_stream_id( - platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None - ) -> str: + def _generate_stream_id(platform: str, user_info: UserInfo | None, group_info: GroupInfo | None = None) -> str: """生成聊天流唯一ID""" if not user_info and not group_info: raise ValueError("用户信息或群组信息必须提供") @@ -462,7 +462,7 @@ class ChatManager: return hashlib.md5(key.encode()).hexdigest() async def get_or_create_stream( - self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None + self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None ) -> ChatStream: """获取或创建聊天流 @@ -572,7 +572,7 @@ class ChatManager: await self._save_stream(stream) return stream - def get_stream(self, stream_id: str) -> Optional[ChatStream]: + def get_stream(self, stream_id: str) -> ChatStream | None: """通过stream_id获取聊天流""" stream = self.streams.get(stream_id) if not stream: @@ -582,13 +582,13 @@ class ChatManager: return stream def get_stream_by_info( - self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None - ) -> Optional[ChatStream]: + self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None + ) -> ChatStream | None: """通过信息获取聊天流""" stream_id = self._generate_stream_id(platform, user_info, group_info) return self.streams.get(stream_id) - def get_stream_name(self, stream_id: str) -> Optional[str]: + def get_stream_name(self, stream_id: str) -> str | None: """根据 stream_id 获取聊天流名称""" stream = self.get_stream(stream_id) if not stream: diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index fee932b62..7953ff862 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,20 +1,19 @@ import base64 import time -from abc import abstractmethod, ABCMeta +from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Optional, Any +from typing import Any, Optional import urllib3 -from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase +from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo from rich.traceback import install +from src.chat.message_receive.chat_stream import ChatStream from src.chat.utils.utils_image import get_image_manager from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available from src.chat.utils.utils_voice import get_voice_text from src.common.logger import get_logger from src.config.config import global_config -from src.chat.message_receive.chat_stream import ChatStream - install(extra_lines=3) @@ -41,8 +40,8 @@ class Message(MessageBase, metaclass=ABCMeta): message_id: str, chat_stream: "ChatStream", user_info: UserInfo, - message_segment: Optional[Seg] = None, - timestamp: Optional[float] = None, + message_segment: Seg | None = None, + timestamp: float | None = None, reply: Optional["MessageRecv"] = None, processed_plain_text: str = "", ): @@ -264,7 +263,7 @@ class MessageRecv(Message): logger.warning("视频消息中没有base64数据") return "[收到视频消息,但数据异常]" except Exception as e: - logger.error(f"视频处理失败: {str(e)}") + logger.error(f"视频处理失败: {e!s}") import traceback logger.error(f"错误详情: {traceback.format_exc()}") @@ -278,7 +277,7 @@ class MessageRecv(Message): logger.info("未启用视频识别") return "[视频]" except Exception as e: - logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") + logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") return f"[处理失败的{segment.type}消息]" @@ -291,7 +290,7 @@ class MessageRecvS4U(MessageRecv): self.is_superchat = False self.gift_info = None self.gift_name = None - self.gift_count: Optional[str] = None + self.gift_count: str | None = None self.superchat_info = None self.superchat_price = None self.superchat_message_text = None @@ -444,7 +443,7 @@ class MessageRecvS4U(MessageRecv): logger.warning("视频消息中没有base64数据") return "[收到视频消息,但数据异常]" except Exception as e: - logger.error(f"视频处理失败: {str(e)}") + logger.error(f"视频处理失败: {e!s}") import traceback logger.error(f"错误详情: {traceback.format_exc()}") @@ -458,7 +457,7 @@ class MessageRecvS4U(MessageRecv): logger.info("未启用视频识别") return "[视频]" except Exception as e: - logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") + logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") return f"[处理失败的{segment.type}消息]" @@ -471,10 +470,10 @@ class MessageProcessBase(Message): message_id: str, chat_stream: "ChatStream", bot_user_info: UserInfo, - message_segment: Optional[Seg] = None, + message_segment: Seg | None = None, reply: Optional["MessageRecv"] = None, thinking_start_time: float = 0, - timestamp: Optional[float] = None, + timestamp: float | None = None, ): # 调用父类初始化,传递时间戳 super().__init__( @@ -533,9 +532,9 @@ class MessageProcessBase(Message): return f"[回复<{self.reply.message_info.user_info.user_nickname}> 的消息:{self.reply.processed_plain_text}]" # type: ignore return None else: - return f"[{seg.type}:{str(seg.data)}]" + return f"[{seg.type}:{seg.data!s}]" except Exception as e: - logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}") + logger.error(f"处理消息段失败: {e!s}, 类型: {seg.type}, 数据: {seg.data}") return f"[处理失败的{seg.type}消息]" def _generate_detailed_text(self) -> str: @@ -565,7 +564,7 @@ class MessageSending(MessageProcessBase): is_emoji: bool = False, thinking_start_time: float = 0, apply_set_reply_logic: bool = False, - reply_to: Optional[str] = None, + reply_to: str | None = None, ): # 调用父类初始化 super().__init__( @@ -635,11 +634,11 @@ class MessageSet: self.messages.append(message) self.messages.sort(key=lambda x: x.message_info.time) # type: ignore - def get_message_by_index(self, index: int) -> Optional[MessageSending]: + def get_message_by_index(self, index: int) -> MessageSending | None: """通过索引获取消息""" return self.messages[index] if 0 <= index < len(self.messages) else None - def get_message_by_time(self, target_time: float) -> Optional[MessageSending]: + def get_message_by_time(self, target_time: float) -> MessageSending | None: """获取最接近指定时间的消息""" if not self.messages: return None diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 5a654e867..1382adfb8 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,14 +1,15 @@ import re import traceback -import orjson -from typing import Union -from src.common.database.sqlalchemy_models import Messages, Images -from src.common.logger import get_logger -from .chat_stream import ChatStream -from .message import MessageSending, MessageRecv +import orjson +from sqlalchemy import desc, select, update + from src.common.database.sqlalchemy_database_api import get_db_session -from sqlalchemy import select, update, desc +from src.common.database.sqlalchemy_models import Images, Messages +from src.common.logger import get_logger + +from .chat_stream import ChatStream +from .message import MessageRecv, MessageSending logger = get_logger("message_storage") @@ -32,7 +33,7 @@ class MessageStorage: return [] @staticmethod - async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None: + async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None: """存储消息到数据库""" try: # 过滤敏感信息的正则模式 @@ -299,6 +300,7 @@ class MessageStorage: try: async with get_db_session() as session: from sqlalchemy import select, update + from src.common.database.sqlalchemy_models import Messages # 查找需要修复的记录:interest_value为0、null或很小的值 diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index a881549f5..bd23402e2 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -3,12 +3,11 @@ import traceback from rich.traceback import install -from src.common.message.api import get_global_api -from src.common.logger import get_logger from src.chat.message_receive.message import MessageSending from src.chat.message_receive.storage import MessageStorage -from src.chat.utils.utils import truncate_message -from src.chat.utils.utils import calculate_typing_time +from src.chat.utils.utils import calculate_typing_time, truncate_message +from src.common.logger import get_logger +from src.common.message.api import get_global_api install(extra_lines=3) @@ -27,7 +26,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool: return True except Exception as e: - logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}") + logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {e!s}") traceback.print_exc() raise e # 重新抛出其他异常 diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 21a00ee52..9adde80cb 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -1,19 +1,17 @@ import asyncio -import traceback import time -from typing import Dict, Optional, Type, Any, Tuple +import traceback +from typing import Any - -from src.chat.utils.timer_calculator import Timer -from src.person_info.person_info import get_person_info_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.utils.timer_calculator import Timer from src.common.logger import get_logger from src.config.config import global_config -from src.plugin_system.core.component_registry import component_registry -from src.plugin_system.base.component_types import ComponentType, ActionInfo +from src.person_info.person_info import get_person_info_manager +from src.plugin_system.apis import database_api, generator_api, message_api, send_api from src.plugin_system.base.base_action import BaseAction -from src.plugin_system.apis import generator_api, database_api, send_api, message_api - +from src.plugin_system.base.component_types import ActionInfo, ComponentType +from src.plugin_system.core.component_registry import component_registry logger = get_logger("action_manager") @@ -29,7 +27,7 @@ class ChatterActionManager: """初始化动作管理器""" # 当前正在使用的动作集合,默认加载默认动作 - self._using_actions: Dict[str, ActionInfo] = {} + self._using_actions: dict[str, ActionInfo] = {} # 初始化时将默认动作加载到使用中的动作 self._using_actions = component_registry.get_default_actions() @@ -48,8 +46,8 @@ class ChatterActionManager: chat_stream: ChatStream, log_prefix: str, shutting_down: bool = False, - action_message: Optional[dict] = None, - ) -> Optional[BaseAction]: + action_message: dict | None = None, + ) -> BaseAction | None: """ 创建动作处理器实例 @@ -68,7 +66,7 @@ class ChatterActionManager: """ try: # 获取组件类 - 明确指定查询Action类型 - component_class: Type[BaseAction] = component_registry.get_component_class( + component_class: type[BaseAction] = component_registry.get_component_class( action_name, ComponentType.ACTION ) # type: ignore if not component_class: @@ -107,7 +105,7 @@ class ChatterActionManager: logger.error(traceback.format_exc()) return None - def get_using_actions(self) -> Dict[str, ActionInfo]: + def get_using_actions(self) -> dict[str, ActionInfo]: """获取当前正在使用的动作集合""" return self._using_actions.copy() @@ -140,10 +138,10 @@ class ChatterActionManager: self, action_name: str, chat_id: str, - target_message: Optional[dict] = None, + target_message: dict | None = None, reasoning: str = "", - action_data: Optional[dict] = None, - thinking_id: Optional[str] = None, + action_data: dict | None = None, + thinking_id: str | None = None, log_prefix: str = "", clear_unread_messages: bool = True, ) -> Any: @@ -437,10 +435,10 @@ class ChatterActionManager: response_set, loop_start_time, action_message, - cycle_timers: Dict[str, float], + cycle_timers: dict[str, float], thinking_id, actions, - ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: + ) -> tuple[dict[str, Any], str, dict[str, float]]: """ 发送并存储回复信息 @@ -488,7 +486,7 @@ class ChatterActionManager: ) # 构建循环信息 - loop_info: Dict[str, Any] = { + loop_info: dict[str, Any] = { "loop_plan_info": { "action_result": actions, }, diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 4e144d3f4..4f3e4b099 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -1,17 +1,17 @@ -import random import asyncio import hashlib +import random import time -from typing import List, Any, Dict, TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Any +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.planner_actions.action_manager import ChatterActionManager +from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat +from src.common.data_models.message_manager_data_model import StreamContext from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.chat.message_receive.chat_stream import get_chat_manager -from src.common.data_models.message_manager_data_model import StreamContext -from src.chat.planner_actions.action_manager import ChatterActionManager -from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages -from src.plugin_system.base.component_types import ActionInfo, ActionActivationType +from src.plugin_system.base.component_types import ActionActivationType, ActionInfo from src.plugin_system.core.global_announcement_manager import global_announcement_manager if TYPE_CHECKING: @@ -59,18 +59,17 @@ class ActionModifier: """ logger.debug(f"{self.log_prefix}开始完整动作修改流程") - removals_s1: List[Tuple[str, str]] = [] - removals_s2: List[Tuple[str, str]] = [] - removals_s3: List[Tuple[str, str]] = [] + removals_s1: list[tuple[str, str]] = [] + removals_s2: list[tuple[str, str]] = [] + removals_s3: list[tuple[str, str]] = [] self.action_manager.restore_actions() all_actions = self.action_manager.get_using_actions() # === 第0阶段:根据聊天类型过滤动作 === - from src.plugin_system.base.component_types import ChatType - from src.plugin_system.core.component_registry import component_registry - from src.plugin_system.base.component_types import ComponentType from src.chat.utils.utils import get_chat_type_and_target_info + from src.plugin_system.base.component_types import ChatType, ComponentType + from src.plugin_system.core.component_registry import component_registry # 获取聊天类型 is_group_chat, _ = get_chat_type_and_target_info(self.chat_id) @@ -167,8 +166,8 @@ class ActionModifier: logger.info(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}") - def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: StreamContext): - type_mismatched_actions: List[Tuple[str, str]] = [] + def _check_action_associated_types(self, all_actions: dict[str, ActionInfo], chat_context: StreamContext): + type_mismatched_actions: list[tuple[str, str]] = [] for action_name, action_info in all_actions.items(): if action_info.associated_types and not chat_context.check_types(action_info.associated_types): associated_types_str = ", ".join(action_info.associated_types) @@ -179,9 +178,9 @@ class ActionModifier: async def _get_deactivated_actions_by_type( self, - actions_with_info: Dict[str, ActionInfo], + actions_with_info: dict[str, ActionInfo], chat_content: str = "", - ) -> List[tuple[str, str]]: + ) -> list[tuple[str, str]]: """ 根据激活类型过滤,返回需要停用的动作列表及原因 @@ -254,9 +253,9 @@ class ActionModifier: async def _process_llm_judge_actions_parallel( self, - llm_judge_actions: Dict[str, Any], + llm_judge_actions: dict[str, Any], chat_content: str = "", - ) -> Dict[str, bool]: + ) -> dict[str, bool]: """ 并行处理LLM判定actions,支持智能缓存 diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 063fc1bf1..0d4b0b574 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -3,42 +3,41 @@ 使用重构后的统一Prompt系统替换原有的复杂提示词构建逻辑 """ -import traceback -import time import asyncio import random import re - -from typing import List, Optional, Dict, Any, Tuple +import time +import traceback from datetime import datetime -from src.mais4u.mai_think import mai_thinking_manager -from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.individuality.individuality import get_individuality -from src.llm_models.utils_model import LLMRequest -from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending +from typing import Any + +from src.chat.express.expression_selector import expression_selector from src.chat.message_receive.chat_stream import ChatStream -from src.chat.utils.memory_mappings import get_memory_type_chinese_label +from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo from src.chat.message_receive.uni_message_sender import HeartFCSender -from src.chat.utils.timer_calculator import Timer -from src.chat.utils.utils import get_chat_type_and_target_info -from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import ( build_readable_messages, get_raw_msg_before_timestamp_with_chat, replace_user_references_sync, ) -from src.chat.express.expression_selector import expression_selector +from src.chat.utils.memory_mappings import get_memory_type_chinese_label + +# 导入新的统一Prompt系统 +from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager +from src.chat.utils.timer_calculator import Timer +from src.chat.utils.utils import get_chat_type_and_target_info +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.individuality.individuality import get_individuality +from src.llm_models.utils_model import LLMRequest +from src.mais4u.mai_think import mai_thinking_manager # 旧记忆系统已被移除 # 旧记忆系统已被移除 from src.mood.mood_manager import mood_manager from src.person_info.person_info import get_person_info_manager -from src.plugin_system.base.component_types import ActionInfo, EventType from src.plugin_system.apis import llm_api - -# 导入新的统一Prompt系统 -from src.chat.utils.prompt import PromptParameters +from src.plugin_system.base.component_types import ActionInfo, EventType logger = get_logger("replyer") @@ -248,12 +247,12 @@ class DefaultReplyer: self, reply_to: str = "", extra_info: str = "", - available_actions: Optional[Dict[str, ActionInfo]] = None, + available_actions: dict[str, ActionInfo] | None = None, enable_tool: bool = True, from_plugin: bool = True, - stream_id: Optional[str] = None, - reply_message: Optional[Dict[str, Any]] = None, - ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: + stream_id: str | None = None, + reply_message: dict[str, Any] | None = None, + ) -> tuple[bool, dict[str, Any] | None, str | None]: # sourcery skip: merge-nested-ifs """ 回复器 (Replier): 负责生成回复文本的核心逻辑。 @@ -353,7 +352,7 @@ class DefaultReplyer: reason: str = "", reply_to: str = "", return_prompt: bool = False, - ) -> Tuple[bool, Optional[str], Optional[str]]: + ) -> tuple[bool, str | None, str | None]: """ 表达器 (Expressor): 负责重写和优化回复文本。 @@ -722,7 +721,7 @@ class DefaultReplyer: logger.error(f"工具信息获取失败: {e}") return "" - def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: + def _parse_reply_target(self, target_message: str) -> tuple[str, str]: """解析回复目标消息 - 使用共享工具""" from src.chat.utils.prompt import Prompt @@ -731,7 +730,7 @@ class DefaultReplyer: return "未知用户", "(无消息内容)" return Prompt.parse_reply_target(target_message) - async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str: + async def build_keywords_reaction_prompt(self, target: str | None) -> str: """构建关键词反应提示 Args: @@ -766,14 +765,14 @@ class DefaultReplyer: keywords_reaction_prompt += f"{reaction}," break except re.error as e: - logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}") + logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {e!s}") continue except Exception as e: - logger.error(f"关键词检测与反应时发生异常: {str(e)}", exc_info=True) + logger.error(f"关键词检测与反应时发生异常: {e!s}", exc_info=True) return keywords_reaction_prompt - async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]: + async def _time_and_run_task(self, coroutine, name: str) -> tuple[str, Any, float]: """计时并运行异步任务的辅助函数 Args: @@ -790,8 +789,8 @@ class DefaultReplyer: return name, result, duration async def build_s4u_chat_history_prompts( - self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str - ) -> Tuple[str, str]: + self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str + ) -> tuple[str, str]: """ 构建 s4u 风格的已读/未读历史消息 prompt @@ -907,8 +906,8 @@ class DefaultReplyer: return await self._fallback_build_chat_history_prompts(message_list_before_now, target_user_id, sender) async def _fallback_build_chat_history_prompts( - self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str - ) -> Tuple[str, str]: + self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str + ) -> tuple[str, str]: """ 回退的已读/未读历史消息构建方法 """ @@ -1000,15 +999,15 @@ class DefaultReplyer: return read_history_prompt, unread_history_prompt - async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]: + async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]: """为消息获取兴趣度评分""" interest_scores = {} try: + from src.common.data_models.database_data_model import DatabaseMessages from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ( chatter_interest_scoring_system as interest_scoring_system, ) - from src.common.data_models.database_data_model import DatabaseMessages # 转换消息格式 db_messages = [] @@ -1094,9 +1093,9 @@ class DefaultReplyer: self, reply_to: str, extra_info: str = "", - available_actions: Optional[Dict[str, ActionInfo]] = None, + available_actions: dict[str, ActionInfo] | None = None, enable_tool: bool = True, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: dict[str, Any] | None = None, ) -> str: """ 构建回复器上下文 @@ -1417,7 +1416,7 @@ class DefaultReplyer: raw_reply: str, reason: str, reply_to: str, - reply_message: Optional[Dict[str, Any]] = None, + reply_message: dict[str, Any] | None = None, ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if chat_stream = self.chat_stream chat_id = chat_stream.stream_id @@ -1553,7 +1552,7 @@ class DefaultReplyer: is_emoji: bool, thinking_start_time: float, display_message: str, - anchor_message: Optional[MessageRecv] = None, + anchor_message: MessageRecv | None = None, ) -> MessageSending: """构建单个发送消息""" @@ -1644,7 +1643,7 @@ class DefaultReplyer: logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...") return "" except Exception as e: - logger.error(f"获取知识库内容时发生异常: {str(e)}") + logger.error(f"获取知识库内容时发生异常: {e!s}") return "" async def build_relation_info(self, sender: str, target: str): @@ -1660,10 +1659,9 @@ class DefaultReplyer: # 使用AFC关系追踪器获取关系信息 try: - from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker - # 创建关系追踪器实例 from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system + from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker relationship_tracker = ChatterRelationshipTracker(chatter_interest_scoring_system) if relationship_tracker: @@ -1704,7 +1702,7 @@ class DefaultReplyer: logger.error(f"获取AFC关系信息失败: {e}") return f"你与{sender}是普通朋友关系。" - async def _store_chat_memory_async(self, reply_to: str, reply_message: Optional[Dict[str, Any]] = None): + async def _store_chat_memory_async(self, reply_to: str, reply_message: dict[str, Any] | None = None): """ 异步存储聊天记忆(从build_memory_block迁移而来) diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py index 2f64ab07f..55a422c1b 100644 --- a/src/chat/replyer/replyer_manager.py +++ b/src/chat/replyer/replyer_manager.py @@ -1,22 +1,20 @@ -from typing import Dict, Optional - -from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.replyer.default_generator import DefaultReplyer +from src.common.logger import get_logger logger = get_logger("ReplyerManager") class ReplyerManager: def __init__(self): - self._repliers: Dict[str, DefaultReplyer] = {} + self._repliers: dict[str, DefaultReplyer] = {} def get_replyer( self, - chat_stream: Optional[ChatStream] = None, - chat_id: Optional[str] = None, + chat_stream: ChatStream | None = None, + chat_id: str | None = None, request_type: str = "replyer", - ) -> Optional[DefaultReplyer]: + ) -> DefaultReplyer | None: """ 获取或创建回复器实例。 diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 8503e369a..65c123338 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -1,18 +1,19 @@ -import time # 导入 time 模块以获取当前时间 import random import re +import time # 导入 time 模块以获取当前时间 +from collections.abc import Callable +from typing import Any -from typing import List, Dict, Any, Tuple, Optional, Callable from rich.traceback import install +from sqlalchemy import and_, select -from src.config.config import global_config -from src.common.message_repository import find_messages, count_messages -from src.common.database.sqlalchemy_models import ActionRecords, Images -from src.person_info.person_info import PersonInfoManager, get_person_info_manager -from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids +from src.chat.utils.utils import assign_message_ids, translate_timestamp_to_human_readable from src.common.database.sqlalchemy_database_api import get_db_session -from sqlalchemy import select, and_ +from src.common.database.sqlalchemy_models import ActionRecords, Images from src.common.logger import get_logger +from src.common.message_repository import count_messages, find_messages +from src.config.config import global_config +from src.person_info.person_info import PersonInfoManager, get_person_info_manager logger = get_logger("chat_message_builder") @@ -22,7 +23,7 @@ install(extra_lines=3) def replace_user_references_sync( content: str, platform: str, - name_resolver: Optional[Callable[[str, str], str]] = None, + name_resolver: Callable[[str, str], str] | None = None, replace_bot_name: bool = True, ) -> str: """ @@ -100,7 +101,7 @@ def replace_user_references_sync( async def replace_user_references_async( content: str, platform: str, - name_resolver: Optional[Callable[[str, str], Any]] = None, + name_resolver: Callable[[str, str], Any] | None = None, replace_bot_name: bool = True, ) -> str: """ @@ -174,7 +175,7 @@ async def replace_user_references_async( async def get_raw_msg_by_timestamp( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 @@ -194,7 +195,7 @@ async def get_raw_msg_by_timestamp_with_chat( limit_mode: str = "latest", filter_bot=False, filter_command=False, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。 @@ -220,7 +221,7 @@ async def get_raw_msg_by_timestamp_with_chat_inclusive( limit: int = 0, limit_mode: str = "latest", filter_bot=False, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。 @@ -239,10 +240,10 @@ async def get_raw_msg_by_timestamp_with_chat_users( chat_id: str, timestamp_start: float, timestamp_end: float, - person_ids: List[str], + person_ids: list[str], limit: int = 0, limit_mode: str = "latest", -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取某些特定用户在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。 @@ -263,7 +264,7 @@ async def get_actions_by_timestamp_with_chat( timestamp_end: float = time.time(), limit: int = 0, limit_mode: str = "latest", -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表""" from src.common.logger import get_logger @@ -372,7 +373,7 @@ async def get_actions_by_timestamp_with_chat( async def get_actions_by_timestamp_with_chat_inclusive( chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表""" async with get_db_session() as session: if limit > 0: @@ -423,7 +424,7 @@ async def get_actions_by_timestamp_with_chat_inclusive( async def get_raw_msg_by_timestamp_random( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息 """ @@ -441,7 +442,7 @@ async def get_raw_msg_by_timestamp_random( async def get_raw_msg_by_timestamp_with_users( timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest" -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。 @@ -452,7 +453,7 @@ async def get_raw_msg_by_timestamp_with_users( return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> list[dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ @@ -463,7 +464,7 @@ async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List async def get_raw_msg_before_timestamp_with_chat( chat_id: str, timestamp: float, limit: int = 0 -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ @@ -474,7 +475,7 @@ async def get_raw_msg_before_timestamp_with_chat( async def get_raw_msg_before_timestamp_with_users( timestamp: float, person_ids: list, limit: int = 0 -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ @@ -483,9 +484,7 @@ async def get_raw_msg_before_timestamp_with_users( return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -async def num_new_messages_since( - chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None -) -> int: +async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float | None = None) -> int: """ 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。 @@ -517,16 +516,16 @@ async def num_new_messages_since_with_users( async def _build_readable_messages_internal( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, timestamp_mode: str = "relative", truncate: bool = False, - pic_id_mapping: Optional[Dict[str, str]] = None, + pic_id_mapping: dict[str, str] | None = None, pic_counter: int = 1, show_pic: bool = True, - message_id_list: Optional[List[Dict[str, Any]]] = None, -) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: + message_id_list: list[dict[str, Any]] | None = None, +) -> tuple[str, list[tuple[float, str, str]], dict[str, str], int]: """ 内部辅助函数,构建可读消息字符串和原始消息详情列表。 @@ -545,7 +544,7 @@ async def _build_readable_messages_internal( if not messages: return "", [], pic_id_mapping or {}, pic_counter - message_details_raw: List[Tuple[float, str, str, bool]] = [] + message_details_raw: list[tuple[float, str, str, bool]] = [] # 使用传入的映射字典,如果没有则创建新的 if pic_id_mapping is None: @@ -672,7 +671,7 @@ async def _build_readable_messages_internal( message_details_with_flags.append((timestamp, name, content, is_action)) # 应用截断逻辑 (如果 truncate 为 True) - message_details: List[Tuple[float, str, str, bool]] = [] + message_details: list[tuple[float, str, str, bool]] = [] n_messages = len(message_details_with_flags) if truncate and n_messages > 0: for i, (timestamp, name, content, is_action) in enumerate(message_details_with_flags): @@ -809,7 +808,7 @@ async def _build_readable_messages_internal( ) -async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: +async def build_pic_mapping_info(pic_id_mapping: dict[str, str]) -> str: # sourcery skip: use-contextlib-suppress """ 构建图片映射信息字符串,显示图片的具体描述内容 @@ -847,7 +846,7 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: return "\n".join(mapping_lines) -def build_readable_actions(actions: List[Dict[str, Any]]) -> str: +def build_readable_actions(actions: list[dict[str, Any]]) -> str: """ 将动作列表转换为可读的文本格式。 格式: 在()分钟前,你使用了(action_name),具体内容是:(action_prompt_display) @@ -922,12 +921,12 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str: async def build_readable_messages_with_list( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, timestamp_mode: str = "relative", truncate: bool = False, -) -> Tuple[str, List[Tuple[float, str, str]]]: +) -> tuple[str, list[tuple[float, str, str]]]: """ 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 允许通过参数控制格式化行为。 @@ -943,7 +942,7 @@ async def build_readable_messages_with_list( async def build_readable_messages_with_id( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, timestamp_mode: str = "relative", @@ -951,7 +950,7 @@ async def build_readable_messages_with_id( truncate: bool = False, show_actions: bool = False, show_pic: bool = True, -) -> Tuple[str, List[Dict[str, Any]]]: +) -> tuple[str, list[dict[str, Any]]]: """ 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 允许通过参数控制格式化行为。 @@ -980,7 +979,7 @@ async def build_readable_messages_with_id( async def build_readable_messages( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, timestamp_mode: str = "relative", @@ -988,7 +987,7 @@ async def build_readable_messages( truncate: bool = False, show_actions: bool = True, show_pic: bool = True, - message_id_list: Optional[List[Dict[str, Any]]] = None, + message_id_list: list[dict[str, Any]] | None = None, ) -> str: # sourcery skip: extract-method """ 将消息列表转换为可读的文本格式。 @@ -1148,7 +1147,7 @@ async def build_readable_messages( return "".join(result_parts) -async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: +async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str: """ 构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。 处理 回复 和 @ 字段,将bbb映射为匿名占位符。 @@ -1261,7 +1260,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: return formatted_string -async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: +async def get_person_id_list(messages: list[dict[str, Any]]) -> list[str]: """ 从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。 diff --git a/src/chat/utils/memory_mappings.py b/src/chat/utils/memory_mappings.py index 4da20fdb5..b82771f8e 100644 --- a/src/chat/utils/memory_mappings.py +++ b/src/chat/utils/memory_mappings.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 记忆系统相关的映射表和工具函数 提供记忆类型、置信度、重要性等的中文标签映射 diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index baf77a143..d869ec7a2 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -3,19 +3,20 @@ 将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类 """ -import re import asyncio -import time import contextvars -from dataclasses import dataclass, field -from typing import Dict, Any, Optional, List, Literal, Tuple +import re +import time from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from typing import Any, Literal, Optional from rich.traceback import install + +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.chat_message_builder import build_readable_messages from src.common.logger import get_logger from src.config.config import global_config -from src.chat.utils.chat_message_builder import build_readable_messages -from src.chat.message_receive.chat_stream import get_chat_manager from src.person_info.person_info import get_person_info_manager install(extra_lines=3) @@ -50,11 +51,11 @@ class PromptParameters: debug_mode: bool = False # 聊天历史和上下文 - chat_target_info: Optional[Dict[str, Any]] = None - message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list) - message_list_before_short: List[Dict[str, Any]] = field(default_factory=list) + chat_target_info: dict[str, Any] | None = None + message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list) + message_list_before_short: list[dict[str, Any]] = field(default_factory=list) chat_talking_prompt_short: str = "" - target_user_info: Optional[Dict[str, Any]] = None + target_user_info: dict[str, Any] | None = None # 已构建的内容块 expression_habits_block: str = "" @@ -77,12 +78,12 @@ class PromptParameters: action_descriptions: str = "" # 可用动作信息 - available_actions: Optional[Dict[str, Any]] = None + available_actions: dict[str, Any] | None = None # 动态生成的聊天场景提示 chat_scene: str = "" - def validate(self) -> List[str]: + def validate(self) -> list[str]: """参数验证""" errors = [] if not self.chat_id: @@ -98,22 +99,22 @@ class PromptContext: """提示词上下文管理器""" def __init__(self): - self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} + self._context_prompts: dict[str, dict[str, "Prompt"]] = {} self._current_context_var = contextvars.ContextVar("current_context", default=None) self._context_lock = asyncio.Lock() @property - def _current_context(self) -> Optional[str]: + def _current_context(self) -> str | None: """获取当前协程的上下文ID""" return self._current_context_var.get() @_current_context.setter - def _current_context(self, value: Optional[str]): + def _current_context(self, value: str | None): """设置当前协程的上下文ID""" self._current_context_var.set(value) # type: ignore @asynccontextmanager - async def async_scope(self, context_id: Optional[str] = None): + async def async_scope(self, context_id: str | None = None): """创建一个异步的临时提示模板作用域""" if context_id is not None: try: @@ -159,7 +160,7 @@ class PromptContext: return self._context_prompts[current_context][name] return None - async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None: + async def register_async(self, prompt: "Prompt", context_id: str | None = None) -> None: """异步注册提示模板到指定作用域""" async with self._context_lock: if target_context := context_id or self._current_context: @@ -177,7 +178,7 @@ class PromptManager: self._lock = asyncio.Lock() @asynccontextmanager - async def async_message_scope(self, message_id: Optional[str] = None): + async def async_message_scope(self, message_id: str | None = None): """为消息处理创建异步临时作用域""" async with self._context.async_scope(message_id): yield self @@ -236,8 +237,8 @@ class Prompt: def __init__( self, template: str, - name: Optional[str] = None, - parameters: Optional[PromptParameters] = None, + name: str | None = None, + parameters: PromptParameters | None = None, should_register: bool = True, ): """ @@ -277,7 +278,7 @@ class Prompt: """将临时标记还原为实际的花括号字符""" return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}") - def _parse_template_args(self, template: str) -> List[str]: + def _parse_template_args(self, template: str) -> list[str]: """解析模板参数""" template_args = [] processed_template = self._process_escaped_braces(template) @@ -321,7 +322,7 @@ class Prompt: logger.error(f"构建Prompt失败: {e}") raise RuntimeError(f"构建Prompt失败: {e}") from e - async def _build_context_data(self) -> Dict[str, Any]: + async def _build_context_data(self) -> dict[str, Any]: """构建智能上下文数据""" # 并行执行所有构建任务 start_time = time.time() @@ -401,7 +402,7 @@ class Prompt: default_result = self._get_default_result_for_task(task_name) results.append(default_result) except Exception as e: - logger.error(f"构建任务{task_name}失败: {str(e)}") + logger.error(f"构建任务{task_name}失败: {e!s}") default_result = self._get_default_result_for_task(task_name) results.append(default_result) @@ -411,7 +412,7 @@ class Prompt: task_name = task_names[i] if i < len(task_names) else f"task_{i}" if isinstance(result, Exception): - logger.error(f"构建任务{task_name}失败: {str(result)}") + logger.error(f"构建任务{task_name}失败: {result!s}") elif isinstance(result, dict): context_data.update(result) @@ -453,7 +454,7 @@ class Prompt: return context_data - async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None: + async def _build_s4u_chat_context(self, context_data: dict[str, Any]) -> None: """构建S4U模式的聊天上下文""" if not self.parameters.message_list_before_now_long: return @@ -468,7 +469,7 @@ class Prompt: context_data["read_history_prompt"] = read_history_prompt context_data["unread_history_prompt"] = unread_history_prompt - async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None: + async def _build_normal_chat_context(self, context_data: dict[str, Any]) -> None: """构建normal模式的聊天上下文""" if not self.parameters.chat_talking_prompt_short: return @@ -477,8 +478,8 @@ class Prompt: {self.parameters.chat_talking_prompt_short}""" async def _build_s4u_chat_history_prompts( - self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str - ) -> Tuple[str, str]: + self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str + ) -> tuple[str, str]: """构建S4U风格的已读/未读历史消息prompt""" try: # 动态导入default_generator以避免循环导入 @@ -492,7 +493,7 @@ class Prompt: except Exception as e: logger.error(f"构建S4U历史消息prompt失败: {e}") - async def _build_expression_habits(self) -> Dict[str, Any]: + async def _build_expression_habits(self) -> dict[str, Any]: """构建表达习惯""" use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id) if not use_expression: @@ -533,7 +534,7 @@ class Prompt: logger.error(f"构建表达习惯失败: {e}") return {"expression_habits_block": ""} - async def _build_memory_block(self) -> Dict[str, Any]: + async def _build_memory_block(self) -> dict[str, Any]: """构建记忆块""" if not global_config.memory.enable_memory: return {"memory_block": ""} @@ -653,7 +654,7 @@ class Prompt: logger.error(f"构建记忆块失败: {e}") return {"memory_block": ""} - async def _build_memory_block_fast(self) -> Dict[str, Any]: + async def _build_memory_block_fast(self) -> dict[str, Any]: """快速构建记忆块(简化版本,用于未预构建时的后备方案)""" if not global_config.memory.enable_memory: return {"memory_block": ""} @@ -677,7 +678,7 @@ class Prompt: logger.warning(f"快速构建记忆块失败: {e}") return {"memory_block": ""} - async def _build_relation_info(self) -> Dict[str, Any]: + async def _build_relation_info(self) -> dict[str, Any]: """构建关系信息""" try: relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to) @@ -686,7 +687,7 @@ class Prompt: logger.error(f"构建关系信息失败: {e}") return {"relation_info_block": ""} - async def _build_tool_info(self) -> Dict[str, Any]: + async def _build_tool_info(self) -> dict[str, Any]: """构建工具信息""" if not global_config.tool.enable_tool: return {"tool_info_block": ""} @@ -734,7 +735,7 @@ class Prompt: logger.error(f"构建工具信息失败: {e}") return {"tool_info_block": ""} - async def _build_knowledge_info(self) -> Dict[str, Any]: + async def _build_knowledge_info(self) -> dict[str, Any]: """构建知识信息""" if not global_config.lpmm_knowledge.enable: return {"knowledge_prompt": ""} @@ -783,7 +784,7 @@ class Prompt: logger.error(f"构建知识信息失败: {e}") return {"knowledge_prompt": ""} - async def _build_cross_context(self) -> Dict[str, Any]: + async def _build_cross_context(self) -> dict[str, Any]: """构建跨群上下文""" try: cross_context = await Prompt.build_cross_context( @@ -794,7 +795,7 @@ class Prompt: logger.error(f"构建跨群上下文失败: {e}") return {"cross_context_block": ""} - async def _format_with_context(self, context_data: Dict[str, Any]) -> str: + async def _format_with_context(self, context_data: dict[str, Any]) -> str: """使用上下文数据格式化模板""" if self.parameters.prompt_mode == "s4u": params = self._prepare_s4u_params(context_data) @@ -805,7 +806,7 @@ class Prompt: return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params) - def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: + def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]: """准备S4U模式的参数""" return { **context_data, @@ -834,7 +835,7 @@ class Prompt: or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", } - def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: + def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]: """准备Normal模式的参数""" return { **context_data, @@ -862,7 +863,7 @@ class Prompt: or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", } - def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: + def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]: """准备默认模式的参数""" return { "expression_habits_block": context_data.get("expression_habits_block", ""), @@ -905,7 +906,7 @@ class Prompt: result = self._restore_escaped_braces(processed_template) return result except (IndexError, KeyError) as e: - raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e + raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {e!s}") from e def __str__(self) -> str: """返回格式化后的结果或原始模板""" @@ -922,7 +923,7 @@ class Prompt: # ============================================================================= @staticmethod - def parse_reply_target(target_message: str) -> Tuple[str, str]: + def parse_reply_target(target_message: str) -> tuple[str, str]: """ 解析回复目标消息 - 统一实现 @@ -981,7 +982,7 @@ class Prompt: return await relationship_fetcher.build_relation_info(person_id, points_num=5) - def _get_default_result_for_task(self, task_name: str) -> Dict[str, Any]: + def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]: """ 为超时的任务提供默认结果 @@ -1008,7 +1009,7 @@ class Prompt: return {} @staticmethod - async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str: + async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: dict[str, Any] | None) -> str: """ 构建跨群聊上下文 - 统一实现 @@ -1071,7 +1072,7 @@ class Prompt: # 工厂函数 def create_prompt( - template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs + template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs ) -> Prompt: """快速创建Prompt实例的工厂函数""" if parameters is None: @@ -1080,7 +1081,7 @@ def create_prompt( async def create_prompt_async( - template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs + template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs ) -> Prompt: """异步创建Prompt实例""" prompt = create_prompt(template, name, parameters, **kwargs) diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 1c879a01b..96433d21a 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -1,11 +1,11 @@ import asyncio from collections import defaultdict from datetime import datetime, timedelta -from typing import Any, Dict, Tuple, List +from typing import Any +from src.common.database.sqlalchemy_database_api import db_get, db_query, db_save +from src.common.database.sqlalchemy_models import LLMUsage, Messages, OnlineTime from src.common.logger import get_logger -from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages -from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get from src.manager.async_task_manager import AsyncTask from src.manager.local_store_manager import local_storage @@ -150,7 +150,7 @@ class StatisticOutputTask(AsyncTask): # 延迟300秒启动,运行间隔300秒 super().__init__(task_name="Statistics Data Output Task", wait_before_start=0, run_interval=300) - self.name_mapping: Dict[str, Tuple[str, float]] = {} + self.name_mapping: dict[str, tuple[str, float]] = {} """ 联系人/群聊名称映射 {聊天ID: (联系人/群聊名称, 记录时间(timestamp))} 注:设计记录时间的目的是方便更新名称,使联系人/群聊名称保持最新 @@ -170,7 +170,7 @@ class StatisticOutputTask(AsyncTask): deploy_time = datetime(2000, 1, 1) local_storage["deploy_time"] = now.timestamp() - self.stat_period: List[Tuple[str, timedelta, str]] = [ + self.stat_period: list[tuple[str, timedelta, str]] = [ ("all_time", now - deploy_time, "自部署以来"), # 必须保留"all_time" ("last_7_days", timedelta(days=7), "最近7天"), ("last_24_hours", timedelta(days=1), "最近24小时"), @@ -181,7 +181,7 @@ class StatisticOutputTask(AsyncTask): 统计时间段 [(统计名称, 统计时间段, 统计描述), ...] """ - def _statistic_console_output(self, stats: Dict[str, Any], now: datetime): + def _statistic_console_output(self, stats: dict[str, Any], now: datetime): """ 输出统计数据到控制台 :param stats: 统计数据 @@ -239,7 +239,7 @@ class StatisticOutputTask(AsyncTask): # -- 以下为统计数据收集方法 -- @staticmethod - async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: + async def _collect_model_request_for_period(collect_period: list[tuple[str, datetime]]) -> dict[str, Any]: """ 收集指定时间段的LLM请求统计数据 @@ -393,8 +393,8 @@ class StatisticOutputTask(AsyncTask): @staticmethod async def _collect_online_time_for_period( - collect_period: List[Tuple[str, datetime]], now: datetime - ) -> Dict[str, Any]: + collect_period: list[tuple[str, datetime]], now: datetime + ) -> dict[str, Any]: """ 收集指定时间段的在线时间统计数据 @@ -452,7 +452,7 @@ class StatisticOutputTask(AsyncTask): break return stats - async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: + async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]: """ 收集指定时间段的消息统计数据 @@ -523,7 +523,7 @@ class StatisticOutputTask(AsyncTask): break return stats - async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: + async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]: """ 收集各时间段的统计数据 :param now: 基准当前时间 @@ -533,7 +533,7 @@ class StatisticOutputTask(AsyncTask): if "last_full_statistics" in local_storage: # 如果存在上次完整统计数据,则使用该数据进行增量统计 - last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore + last_stat: dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射 last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据 @@ -620,7 +620,7 @@ class StatisticOutputTask(AsyncTask): # -- 以下为统计数据格式化方法 -- @staticmethod - def _format_total_stat(stats: Dict[str, Any]) -> str: + def _format_total_stat(stats: dict[str, Any]) -> str: """ 格式化总统计数据 """ @@ -636,7 +636,7 @@ class StatisticOutputTask(AsyncTask): return "\n".join(output) @staticmethod - def _format_model_classified_stat(stats: Dict[str, Any]) -> str: + def _format_model_classified_stat(stats: dict[str, Any]) -> str: """ 格式化按模型分类的统计数据 """ @@ -662,7 +662,7 @@ class StatisticOutputTask(AsyncTask): output.append("") return "\n".join(output) - def _format_chat_stat(self, stats: Dict[str, Any]) -> str: + def _format_chat_stat(self, stats: dict[str, Any]) -> str: """ 格式化聊天统计数据 """ @@ -1007,7 +1007,7 @@ class StatisticOutputTask(AsyncTask): async def _generate_chart_data(self, stat: dict[str, Any]) -> dict: """生成图表数据 (异步)""" now = datetime.now() - chart_data: Dict[str, Any] = {} + chart_data: dict[str, Any] = {} time_ranges = [ ("6h", 6, 10), @@ -1023,16 +1023,16 @@ class StatisticOutputTask(AsyncTask): async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: start_time = now - timedelta(hours=hours) - time_points: List[datetime] = [] + time_points: list[datetime] = [] current_time = start_time while current_time <= now: time_points.append(current_time) current_time += timedelta(minutes=interval_minutes) total_cost_data = [0.0] * len(time_points) - cost_by_model: Dict[str, List[float]] = {} - cost_by_module: Dict[str, List[float]] = {} - message_by_chat: Dict[str, List[int]] = {} + cost_by_model: dict[str, list[float]] = {} + cost_by_module: dict[str, list[float]] = {} + message_by_chat: dict[str, list[int]] = {} time_labels = [t.strftime("%H:%M") for t in time_points] interval_seconds = interval_minutes * 60 diff --git a/src/chat/utils/timer_calculator.py b/src/chat/utils/timer_calculator.py index d9479af16..acdadc956 100644 --- a/src/chat/utils/timer_calculator.py +++ b/src/chat/utils/timer_calculator.py @@ -1,8 +1,8 @@ import asyncio - -from time import perf_counter +from collections.abc import Callable from functools import wraps -from typing import Optional, Dict, Callable +from time import perf_counter + from rich.traceback import install install(extra_lines=3) @@ -75,12 +75,12 @@ class Timer: 3. 直接实例化:如果不调用 __enter__,打印对象时将显示当前 perf_counter 的值 """ - __slots__ = ("name", "storage", "elapsed", "auto_unit", "start") + __slots__ = ("auto_unit", "elapsed", "name", "start", "storage") def __init__( self, - name: Optional[str] = None, - storage: Optional[Dict[str, float]] = None, + name: str | None = None, + storage: dict[str, float] | None = None, auto_unit: bool = True, do_type_check: bool = False, ): @@ -103,7 +103,7 @@ class Timer: if storage is not None and not isinstance(storage, dict): raise TimerTypeError("storage", "Optional[dict]", type(storage)) - def __call__(self, func: Optional[Callable] = None) -> Callable: + def __call__(self, func: Callable | None = None) -> Callable: """装饰器模式""" if func is None: return lambda f: Timer(name=self.name or f.__name__, storage=self.storage, auto_unit=self.auto_unit)(f) diff --git a/src/chat/utils/typo_generator.py b/src/chat/utils/typo_generator.py index 9c3718b2b..1852679a3 100644 --- a/src/chat/utils/typo_generator.py +++ b/src/chat/utils/typo_generator.py @@ -2,15 +2,15 @@ 错别字生成器 - 基于拼音和字频的中文错别字生成工具 """ -import orjson import math import os import random import time -import jieba - from collections import defaultdict from pathlib import Path + +import jieba +import orjson from pypinyin import Style, pinyin from src.common.logger import get_logger @@ -51,7 +51,7 @@ class ChineseTypoGenerator: # 如果缓存文件存在,直接加载 if cache_file.exists(): - with open(cache_file, "r", encoding="utf-8") as f: + with open(cache_file, encoding="utf-8") as f: return orjson.loads(f.read()) # 使用内置的词频文件 @@ -59,7 +59,7 @@ class ChineseTypoGenerator: dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt") # 读取jieba的词典文件 - with open(dict_path, "r", encoding="utf-8") as f: + with open(dict_path, encoding="utf-8") as f: for line in f: word, freq = line.strip().split()[:2] # 对词中的每个字进行频率累加 @@ -254,7 +254,7 @@ class ChineseTypoGenerator: # 获取jieba词典和词频信息 dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt") valid_words = {} # 改用字典存储词语及其频率 - with open(dict_path, "r", encoding="utf-8") as f: + with open(dict_path, encoding="utf-8") as f: for line in f: parts = line.strip().split() if len(parts) >= 2: diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index ea3bdc89f..8659b3539 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -3,20 +3,21 @@ import random import re import string import time +from collections import Counter +from typing import Any + import jieba import numpy as np - -from collections import Counter from maim_message import UserInfo -from typing import Optional, Tuple, Dict, List, Any -from src.common.logger import get_logger -from src.common.message_repository import find_messages, count_messages -from src.config.config import global_config, model_config -from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.message import MessageRecv +from src.common.logger import get_logger +from src.common.message_repository import count_messages, find_messages +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import PersonInfoManager, get_person_info_manager + from .typo_generator import ChineseTypoGenerator logger = get_logger("chat_utils") @@ -86,9 +87,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: if not is_mentioned: # 判断是否被回复 if re.match( - rf"\[回复 (.+?)\({str(global_config.bot.qq_account)}\):(.+?)\],说:", message.processed_plain_text + rf"\[回复 (.+?)\({global_config.bot.qq_account!s}\):(.+?)\],说:", message.processed_plain_text ) or re.match( - rf"\[回复<(.+?)(?=:{str(global_config.bot.qq_account)}>)\:{str(global_config.bot.qq_account)}>:(.+?)\],说:", + rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>:(.+?)\],说:", message.processed_plain_text, ): is_mentioned = True @@ -110,14 +111,14 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: return is_mentioned, reply_probability -async def get_embedding(text, request_type="embedding") -> Optional[List[float]]: +async def get_embedding(text, request_type="embedding") -> list[float] | None: """获取文本的embedding向量""" # 每次都创建新的LLMRequest实例以避免事件循环冲突 llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type) try: embedding, _ = await llm.get_embedding(text) except Exception as e: - logger.error(f"获取embedding失败: {str(e)}") + logger.error(f"获取embedding失败: {e!s}") embedding = None return embedding @@ -621,7 +622,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal" return time.strftime("%H:%M:%S", time.localtime(timestamp)) -def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: +def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None]: """ 获取聊天类型(是否群聊)和私聊对象信息。 @@ -670,7 +671,6 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: if loop.is_running(): # 如果事件循环在运行,从其他线程提交并等待结果 try: - fut = asyncio.run_coroutine_threadsafe( person_info_manager.get_value(person_id, "person_name"), loop ) @@ -706,7 +706,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: return is_group_chat, chat_target_info -def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]: +def assign_message_ids(messages: list[Any]) -> list[dict[str, Any]]: """ 为消息列表中的每个消息分配唯一的简短随机ID diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index ab0915842..29a918d87 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -1,29 +1,27 @@ import base64 +import hashlib +import io import os import time -import hashlib import uuid -import io -import numpy as np +from typing import Any -from typing import Optional, Tuple, Dict, Any +import numpy as np from PIL import Image from rich.traceback import install +from sqlalchemy import and_, select +from src.common.database.sqlalchemy_models import ImageDescriptions, Images, get_db_session from src.common.logger import get_logger -from src.common.database.sqlalchemy_models import Images, ImageDescriptions from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.common.database.sqlalchemy_models import get_db_session - -from sqlalchemy import select, and_ install(extra_lines=3) logger = get_logger("chat_image") -def is_image_message(message: Dict[str, Any]) -> bool: +def is_image_message(message: dict[str, Any]) -> bool: """ 判断消息是否为图片消息 @@ -69,7 +67,7 @@ class ImageManager: os.makedirs(self.IMAGE_DIR, exist_ok=True) @staticmethod - async def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: + async def _get_description_from_db(image_hash: str, description_type: str) -> str | None: """从数据库获取图片描述 Args: @@ -93,7 +91,7 @@ class ImageManager: ).scalar() return record.description if record else None except Exception as e: - logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}") + logger.error(f"从数据库获取描述失败 (SQLAlchemy): {e!s}") return None @staticmethod @@ -136,7 +134,7 @@ class ImageManager: await session.commit() # 会在上下文管理器中自动调用 except Exception as e: - logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}") + logger.error(f"保存描述到数据库失败 (SQLAlchemy): {e!s}") @staticmethod async def get_emoji_tag(image_base64: str) -> str: @@ -287,10 +285,10 @@ class ImageManager: session.add(new_img) await session.commit() except Exception as e: - logger.error(f"保存到Images表失败: {str(e)}") + logger.error(f"保存到Images表失败: {e!s}") except Exception as e: - logger.error(f"保存表情包文件或元数据失败: {str(e)}") + logger.error(f"保存表情包文件或元数据失败: {e!s}") else: logger.debug("偷取表情包功能已关闭,跳过保存。") @@ -300,7 +298,7 @@ class ImageManager: return f"[表情包:{final_emotion}]" except Exception as e: - logger.error(f"获取表情包描述失败: {str(e)}") + logger.error(f"获取表情包描述失败: {e!s}") return "[表情包(处理失败)]" async def get_image_description(self, image_base64: str) -> str: @@ -391,11 +389,11 @@ class ImageManager: logger.info(f"[VLM完成] 图片描述生成: {description}...") return f"[图片:{description}]" except Exception as e: - logger.error(f"获取图片描述失败: {str(e)}") + logger.error(f"获取图片描述失败: {e!s}") return "[图片(处理失败)]" @staticmethod - def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]: + def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> str | None: # sourcery skip: use-contextlib-suppress """将GIF转换为水平拼接的静态图像, 跳过相似的帧 @@ -512,10 +510,10 @@ class ImageManager: logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多") return None # 内存不够啦 except Exception as e: - logger.error(f"GIF转换失败: {str(e)}", exc_info=True) # 记录详细错误信息 + logger.error(f"GIF转换失败: {e!s}", exc_info=True) # 记录详细错误信息 return None # 其他错误也返回None - async def process_image(self, image_base64: str) -> Tuple[str, str]: + async def process_image(self, image_base64: str) -> tuple[str, str]: # sourcery skip: hoist-if-from-if """处理图片并返回图片ID和描述 @@ -604,7 +602,7 @@ class ImageManager: return image_id, f"[picid:{image_id}]" except Exception as e: - logger.error(f"处理图片失败: {str(e)}") + logger.error(f"处理图片失败: {e!s}") return "", "[图片]" @@ -637,4 +635,4 @@ def image_path_to_base64(image_path: str) -> str: if image_data := f.read(): return base64.b64encode(image_data).decode("utf-8") else: - raise IOError(f"读取图片文件失败: {image_path}") + raise OSError(f"读取图片文件失败: {image_path}") diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 19ec72cb6..6a6fc6245 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -1,29 +1,28 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ 视频分析器模块 - Rust优化版本 集成了Rust视频关键帧提取模块,提供高性能的视频分析功能 支持SIMD优化、多线程处理和智能关键帧检测 """ -import os -import tempfile import asyncio import base64 import hashlib +import io +import os +import tempfile import time +from pathlib import Path + import numpy as np from PIL import Image -from pathlib import Path -from typing import List, Tuple, Optional, Dict -import io - -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger -from src.common.database.sqlalchemy_models import get_db_session, Videos from sqlalchemy import select +from src.common.database.sqlalchemy_models import Videos, get_db_session +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest + logger = get_logger("utils_video") # Rust模块可用性检测 @@ -203,7 +202,7 @@ class VideoAnalyzer: hash_obj.update(video_data) return hash_obj.hexdigest() - async def _check_video_exists(self, video_hash: str) -> Optional[Videos]: + async def _check_video_exists(self, video_hash: str) -> Videos | None: """检查视频是否已经分析过""" try: async with get_db_session() as session: @@ -220,8 +219,8 @@ class VideoAnalyzer: return None async def _store_video_result( - self, video_hash: str, description: str, metadata: Optional[Dict] = None - ) -> Optional[Videos]: + self, video_hash: str, description: str, metadata: dict | None = None + ) -> Videos | None: """存储视频分析结果到数据库""" # 检查描述是否为错误信息,如果是则不保存 if description.startswith("❌"): @@ -281,7 +280,7 @@ class VideoAnalyzer: else: logger.warning(f"无效的分析模式: {mode}") - async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]: + async def extract_frames(self, video_path: str) -> list[tuple[str, float]]: """提取视频帧 - 智能选择最佳实现""" # 检查是否应该使用Rust实现 if RUST_VIDEO_AVAILABLE and self.frame_extraction_mode == "keyframe": @@ -303,7 +302,7 @@ class VideoAnalyzer: logger.info(f"🔄 抽帧模式为 {self.frame_extraction_mode},使用Python抽帧实现") return await self._extract_frames_python_fallback(video_path) - async def _extract_frames_rust_advanced(self, video_path: str) -> List[Tuple[str, float]]: + async def _extract_frames_rust_advanced(self, video_path: str) -> list[tuple[str, float]]: """使用 Rust 高级接口的帧提取""" try: logger.info("🔄 使用 Rust 高级接口提取关键帧...") @@ -387,7 +386,7 @@ class VideoAnalyzer: logger.info("回退到基础 Rust 方法") return await self._extract_frames_rust(video_path) - async def _extract_frames_rust(self, video_path: str) -> List[Tuple[str, float]]: + async def _extract_frames_rust(self, video_path: str) -> list[tuple[str, float]]: """使用 Rust 实现的帧提取""" try: logger.info("🔄 使用 Rust 模块提取关键帧...") @@ -463,7 +462,7 @@ class VideoAnalyzer: logger.error(f"❌ Rust 帧提取失败: {e}") raise e - async def _extract_frames_python_fallback(self, video_path: str) -> List[Tuple[str, float]]: + async def _extract_frames_python_fallback(self, video_path: str) -> list[tuple[str, float]]: """Python降级抽帧实现 - 支持多种抽帧模式""" try: # 导入旧版本分析器 @@ -490,7 +489,7 @@ class VideoAnalyzer: logger.error(f"❌ Python降级抽帧失败: {e}") return [] - async def analyze_frames_batch(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: + async def analyze_frames_batch(self, frames: list[tuple[str, float]], user_question: str = None) -> str: """批量分析所有帧""" logger.info(f"开始批量分析{len(frames)}帧") @@ -526,7 +525,7 @@ class VideoAnalyzer: logger.error(f"❌ 视频识别失败: {e}") raise e - async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str: + async def _analyze_multiple_frames(self, frames: list[tuple[str, float]], prompt: str) -> str: """使用多图片分析方法""" logger.info(f"开始构建包含{len(frames)}帧的分析请求") @@ -566,7 +565,7 @@ class VideoAnalyzer: logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ") return api_response.content or "❌ 未获得响应内容" - async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: + async def analyze_frames_sequential(self, frames: list[tuple[str, float]], user_question: str = None) -> str: """逐帧分析并汇总""" logger.info(f"开始逐帧分析{len(frames)}帧") @@ -624,7 +623,7 @@ class VideoAnalyzer: # 如果汇总失败,返回各帧分析结果 return f"视频逐帧分析结果:\n\n{chr(10).join(frame_analyses)}" - async def analyze_video(self, video_path: str, user_question: str = None) -> Tuple[bool, str]: + async def analyze_video(self, video_path: str, user_question: str = None) -> tuple[bool, str]: """分析视频的主要方法 Returns: @@ -662,13 +661,13 @@ class VideoAnalyzer: return (True, result) except Exception as e: - error_msg = f"❌ 视频分析失败: {str(e)}" + error_msg = f"❌ 视频分析失败: {e!s}" logger.error(error_msg) return (False, error_msg) async def analyze_video_from_bytes( self, video_bytes: bytes, filename: str = None, user_question: str = None, prompt: str = None - ) -> Dict[str, str]: + ) -> dict[str, str]: """从字节数据分析视频 Args: @@ -778,7 +777,7 @@ class VideoAnalyzer: return {"summary": result} except Exception as e: - error_msg = f"❌ 从字节数据分析视频失败: {str(e)}" + error_msg = f"❌ 从字节数据分析视频失败: {e!s}" logger.error(error_msg) # 不保存错误信息到数据库,允许后续重试 @@ -802,7 +801,7 @@ class VideoAnalyzer: supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"} return Path(file_path).suffix.lower() in supported_formats - def get_processing_capabilities(self) -> Dict[str, any]: + def get_processing_capabilities(self) -> dict[str, any]: """获取处理能力信息""" if not RUST_VIDEO_AVAILABLE: return {"error": "Rust视频处理模块不可用", "available": False, "reason": "rust_video模块未安装或加载失败"} @@ -832,7 +831,7 @@ class VideoAnalyzer: logger.error(f"获取处理能力信息失败: {e}") return {"error": str(e), "available": False} - def _get_recommended_settings(self, cpu_features: Dict[str, bool]) -> Dict[str, any]: + def _get_recommended_settings(self, cpu_features: dict[str, bool]) -> dict[str, any]: """根据CPU特性推荐最佳设置""" settings = { "use_simd": any(cpu_features.values()), @@ -882,7 +881,7 @@ def is_video_analysis_available() -> bool: return False -def get_video_analysis_status() -> Dict[str, any]: +def get_video_analysis_status() -> dict[str, any]: """获取视频分析功能的详细状态信息 Returns: diff --git a/src/chat/utils/utils_video_legacy.py b/src/chat/utils/utils_video_legacy.py index 77ca88142..46eb13857 100644 --- a/src/chat/utils/utils_video_legacy.py +++ b/src/chat/utils/utils_video_legacy.py @@ -1,25 +1,25 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ 视频分析器模块 - 旧版本兼容模块 支持多种分析模式:批处理、逐帧、自动选择 包含Python原生的抽帧功能,作为Rust模块的降级方案 """ -import os -import cv2 import asyncio import base64 +import io +import os +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Any + +import cv2 import numpy as np from PIL import Image -from pathlib import Path -from typing import List, Tuple, Optional, Any -import io -from concurrent.futures import ThreadPoolExecutor -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("utils_video_legacy") @@ -30,7 +30,7 @@ def _extract_frames_worker( frame_quality: int, max_image_size: int, frame_extraction_mode: str, - frame_interval_seconds: Optional[float], + frame_interval_seconds: float | None, ) -> list[Any] | list[tuple[str, str]]: """线程池中提取视频帧的工作函数""" frames = [] @@ -221,7 +221,7 @@ class LegacyVideoAnalyzer: f"✅ 旧版本视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}" ) - async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]: + async def extract_frames(self, video_path: str) -> list[tuple[str, float]]: """提取视频帧 - 支持多进程和单线程模式""" # 先获取视频信息 cap = cv2.VideoCapture(video_path) @@ -247,7 +247,7 @@ class LegacyVideoAnalyzer: else: return await self._extract_frames_fallback(video_path) - async def _extract_frames_multiprocess(self, video_path: str) -> List[Tuple[str, float]]: + async def _extract_frames_multiprocess(self, video_path: str) -> list[tuple[str, float]]: """线程池版本的帧提取""" loop = asyncio.get_event_loop() @@ -282,7 +282,7 @@ class LegacyVideoAnalyzer: logger.info("🔄 降级到单线程模式...") return await self._extract_frames_fallback(video_path) - async def _extract_frames_fallback(self, video_path: str) -> List[Tuple[str, float]]: + async def _extract_frames_fallback(self, video_path: str) -> list[tuple[str, float]]: """帧提取的降级方法 - 原始异步版本""" frames = [] extracted_count = 0 @@ -389,7 +389,7 @@ class LegacyVideoAnalyzer: logger.info(f"✅ 成功提取{len(frames)}帧") return frames - async def analyze_frames_batch(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: + async def analyze_frames_batch(self, frames: list[tuple[str, float]], user_question: str = None) -> str: """批量分析所有帧""" logger.info(f"开始批量分析{len(frames)}帧") @@ -441,7 +441,7 @@ class LegacyVideoAnalyzer: logger.error(f"❌ 降级分析也失败: {fallback_e}") raise - async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str: + async def _analyze_multiple_frames(self, frames: list[tuple[str, float]], prompt: str) -> str: """使用多图片分析方法""" logger.info(f"开始构建包含{len(frames)}帧的分析请求") @@ -481,7 +481,7 @@ class LegacyVideoAnalyzer: logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ") return api_response.content or "❌ 未获得响应内容" - async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: + async def analyze_frames_sequential(self, frames: list[tuple[str, float]], user_question: str = None) -> str: """逐帧分析并汇总""" logger.info(f"开始逐帧分析{len(frames)}帧") @@ -567,7 +567,7 @@ class LegacyVideoAnalyzer: return result except Exception as e: - error_msg = f"❌ 视频分析失败: {str(e)}" + error_msg = f"❌ 视频分析失败: {e!s}" logger.error(error_msg) return error_msg diff --git a/src/chat/utils/utils_voice.py b/src/chat/utils/utils_voice.py index 49ec10794..eae96e5f3 100644 --- a/src/chat/utils/utils_voice.py +++ b/src/chat/utils/utils_voice.py @@ -1,8 +1,8 @@ -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from rich.traceback import install from src.common.logger import get_logger -from rich.traceback import install +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest install(extra_lines=3) @@ -25,5 +25,5 @@ async def get_voice_text(voice_base64: str) -> str: return f"[语音:{text}]" except Exception as e: - logger.error(f"语音转文字失败: {str(e)}") + logger.error(f"语音转文字失败: {e!s}") return "[语音]" diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index c77d9e8bd..9afb70dcc 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -1,17 +1,19 @@ -import time -import orjson import hashlib +import time from pathlib import Path -import numpy as np +from typing import Any + import faiss -from typing import Any, Dict, Optional, Union -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config +import numpy as np +import orjson + from src.common.config_helpers import resolve_embedding_dimension -from src.common.database.sqlalchemy_models import CacheEntries from src.common.database.sqlalchemy_database_api import db_query, db_save +from src.common.database.sqlalchemy_models import CacheEntries +from src.common.logger import get_logger from src.common.vector_db import vector_db_service +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("cache_manager") @@ -40,14 +42,14 @@ class CacheManager: self.semantic_cache_collection_name = "semantic_cache" # L1 缓存 (内存) - self.l1_kv_cache: Dict[str, Dict[str, Any]] = {} + self.l1_kv_cache: dict[str, dict[str, Any]] = {} embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension) if not embedding_dim: embedding_dim = global_config.lpmm_knowledge.embedding_dimension self.embedding_dimension = embedding_dim self.l1_vector_index = faiss.IndexFlatIP(embedding_dim) - self.l1_vector_id_to_key: Dict[int, str] = {} + self.l1_vector_id_to_key: dict[int, str] = {} # L2 向量缓存 (使用新的服务) vector_db_service.get_or_create_collection(self.semantic_cache_collection_name) @@ -59,7 +61,7 @@ class CacheManager: logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)") @staticmethod - def _validate_embedding(embedding_result: Any) -> Optional[np.ndarray]: + def _validate_embedding(embedding_result: Any) -> np.ndarray | None: """ 验证和标准化嵌入向量格式 """ @@ -100,7 +102,7 @@ class CacheManager: return None @staticmethod - def _generate_key(tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path]) -> str: + def _generate_key(tool_name: str, function_args: dict[str, Any], tool_file_path: str | Path) -> str: """生成确定性的缓存键,包含文件修改时间以实现自动失效。""" try: tool_file_path = Path(tool_file_path) @@ -124,10 +126,10 @@ class CacheManager: async def get( self, tool_name: str, - function_args: Dict[str, Any], - tool_file_path: Union[str, Path], - semantic_query: Optional[str] = None, - ) -> Optional[Any]: + function_args: dict[str, Any], + tool_file_path: str | Path, + semantic_query: str | None = None, + ) -> Any | None: """ 从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。 """ @@ -251,11 +253,11 @@ class CacheManager: async def set( self, tool_name: str, - function_args: Dict[str, Any], - tool_file_path: Union[str, Path], + function_args: dict[str, Any], + tool_file_path: str | Path, data: Any, - ttl: Optional[int] = None, - semantic_query: Optional[str] = None, + ttl: int | None = None, + semantic_query: str | None = None, ): """将结果存入所有缓存层。""" if ttl is None: diff --git a/src/common/config_helpers.py b/src/common/config_helpers.py index 5a2134fe1..f5460fece 100644 --- a/src/common/config_helpers.py +++ b/src/common/config_helpers.py @@ -1,11 +1,9 @@ from __future__ import annotations -from typing import Optional - from src.config.config import global_config, model_config -def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: bool = True) -> Optional[int]: +def resolve_embedding_dimension(fallback: int | None = None, *, sync_global: bool = True) -> int | None: """获取当前配置的嵌入向量维度。 优先顺序: @@ -14,7 +12,7 @@ def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: 3. 调用方提供的 fallback """ - candidates: list[Optional[int]] = [] + candidates: list[int | None] = [] try: embedding_task = getattr(model_config.model_task_config, "embedding", None) @@ -30,7 +28,7 @@ def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: candidates.append(fallback) - resolved: Optional[int] = next((int(dim) for dim in candidates if dim and int(dim) > 0), None) + resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None) if resolved and sync_global: try: diff --git a/src/common/data_models/bot_interest_data_model.py b/src/common/data_models/bot_interest_data_model.py index 819b50a8f..fe152ca2e 100644 --- a/src/common/data_models/bot_interest_data_model.py +++ b/src/common/data_models/bot_interest_data_model.py @@ -4,8 +4,8 @@ """ from dataclasses import dataclass, field -from typing import List, Dict, Optional, Any from datetime import datetime +from typing import Any from . import BaseDataModel @@ -16,12 +16,12 @@ class BotInterestTag(BaseDataModel): tag_name: str weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0) - embedding: Optional[List[float]] = None # 标签的embedding向量 + embedding: list[float] | None = None # 标签的embedding向量 created_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now) is_active: bool = True - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典格式""" return { "tag_name": self.tag_name, @@ -33,7 +33,7 @@ class BotInterestTag(BaseDataModel): } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag": + def from_dict(cls, data: dict[str, Any]) -> "BotInterestTag": """从字典创建对象""" return cls( tag_name=data["tag_name"], @@ -51,16 +51,16 @@ class BotPersonalityInterests(BaseDataModel): personality_id: str personality_description: str # 人设描述文本 - interest_tags: List[BotInterestTag] = field(default_factory=list) + interest_tags: list[BotInterestTag] = field(default_factory=list) embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型 last_updated: datetime = field(default_factory=datetime.now) version: int = 1 # 版本号,用于追踪更新 - def get_active_tags(self) -> List[BotInterestTag]: + def get_active_tags(self) -> list[BotInterestTag]: """获取活跃的兴趣标签""" return [tag for tag in self.interest_tags if tag.is_active] - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典格式""" return { "personality_id": self.personality_id, @@ -72,7 +72,7 @@ class BotPersonalityInterests(BaseDataModel): } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests": + def from_dict(cls, data: dict[str, Any]) -> "BotPersonalityInterests": """从字典创建对象""" return cls( personality_id=data["personality_id"], @@ -89,14 +89,14 @@ class InterestMatchResult(BaseDataModel): """兴趣匹配结果""" message_id: str - matched_tags: List[str] = field(default_factory=list) - match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score + matched_tags: list[str] = field(default_factory=list) + match_scores: dict[str, float] = field(default_factory=dict) # tag_name -> score overall_score: float = 0.0 - top_tag: Optional[str] = None + top_tag: str | None = None confidence: float = 0.0 # 匹配置信度 (0.0-1.0) - matched_keywords: List[str] = field(default_factory=list) + matched_keywords: list[str] = field(default_factory=list) - def add_match(self, tag_name: str, score: float, keywords: List[str] = None): + def add_match(self, tag_name: str, score: float, keywords: list[str] = None): """添加匹配结果""" self.matched_tags.append(tag_name) self.match_scores[tag_name] = score @@ -131,7 +131,7 @@ class InterestMatchResult(BaseDataModel): else: self.confidence = 0.0 - def get_top_matches(self, top_n: int = 3) -> List[tuple]: + def get_top_matches(self, top_n: int = 3) -> list[tuple]: """获取前N个最佳匹配""" sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True) return sorted_matches[:top_n] diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 4578d1481..f1bc0ef67 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -1,6 +1,6 @@ import json -from typing import Optional, Any, Dict from dataclasses import dataclass, field +from typing import Any from . import BaseDataModel @@ -10,7 +10,7 @@ class DatabaseUserInfo(BaseDataModel): platform: str = field(default_factory=str) user_id: str = field(default_factory=str) user_nickname: str = field(default_factory=str) - user_cardname: Optional[str] = None + user_cardname: str | None = None # def __post_init__(self): # assert isinstance(self.platform, str), "platform must be a string" @@ -25,7 +25,7 @@ class DatabaseUserInfo(BaseDataModel): class DatabaseGroupInfo(BaseDataModel): group_id: str = field(default_factory=str) group_name: str = field(default_factory=str) - group_platform: Optional[str] = None + group_platform: str | None = None # def __post_init__(self): # assert isinstance(self.group_id, str), "group_id must be a string" @@ -42,7 +42,7 @@ class DatabaseChatInfo(BaseDataModel): create_time: float = field(default_factory=float) last_active_time: float = field(default_factory=float) user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo) - group_info: Optional[DatabaseGroupInfo] = None + group_info: DatabaseGroupInfo | None = None # def __post_init__(self): # assert isinstance(self.stream_id, str), "stream_id must be a string" @@ -62,41 +62,41 @@ class DatabaseMessages(BaseDataModel): message_id: str = "", time: float = 0.0, chat_id: str = "", - reply_to: Optional[str] = None, - interest_value: Optional[float] = None, - key_words: Optional[str] = None, - key_words_lite: Optional[str] = None, - is_mentioned: Optional[bool] = None, - is_at: Optional[bool] = None, - reply_probability_boost: Optional[float] = None, - processed_plain_text: Optional[str] = None, - display_message: Optional[str] = None, - priority_mode: Optional[str] = None, - priority_info: Optional[str] = None, - additional_config: Optional[str] = None, + reply_to: str | None = None, + interest_value: float | None = None, + key_words: str | None = None, + key_words_lite: str | None = None, + is_mentioned: bool | None = None, + is_at: bool | None = None, + reply_probability_boost: float | None = None, + processed_plain_text: str | None = None, + display_message: str | None = None, + priority_mode: str | None = None, + priority_info: str | None = None, + additional_config: str | None = None, is_emoji: bool = False, is_picid: bool = False, is_command: bool = False, is_notify: bool = False, - selected_expressions: Optional[str] = None, + selected_expressions: str | None = None, is_read: bool = False, user_id: str = "", user_nickname: str = "", - user_cardname: Optional[str] = None, + user_cardname: str | None = None, user_platform: str = "", - chat_info_group_id: Optional[str] = None, - chat_info_group_name: Optional[str] = None, - chat_info_group_platform: Optional[str] = None, + chat_info_group_id: str | None = None, + chat_info_group_name: str | None = None, + chat_info_group_platform: str | None = None, chat_info_user_id: str = "", chat_info_user_nickname: str = "", - chat_info_user_cardname: Optional[str] = None, + chat_info_user_cardname: str | None = None, chat_info_user_platform: str = "", chat_info_stream_id: str = "", chat_info_platform: str = "", chat_info_create_time: float = 0.0, chat_info_last_active_time: float = 0.0, # 新增字段 - actions: Optional[list] = None, + actions: list | None = None, should_reply: bool = False, **kwargs: Any, ): @@ -132,7 +132,7 @@ class DatabaseMessages(BaseDataModel): self.selected_expressions = selected_expressions self.is_read = is_read - self.group_info: Optional[DatabaseGroupInfo] = None + self.group_info: DatabaseGroupInfo | None = None self.user_info = DatabaseUserInfo( user_id=user_id, user_nickname=user_nickname, @@ -172,7 +172,7 @@ class DatabaseMessages(BaseDataModel): # assert isinstance(self.interest_value, float) or self.interest_value is None, ( # "interest_value must be a float or None" # ) - def flatten(self) -> Dict[str, Any]: + def flatten(self) -> dict[str, Any]: """ 将消息数据模型转换为字典格式,便于存储或传输 """ @@ -255,7 +255,7 @@ class DatabaseMessages(BaseDataModel): """ return self.actions or [] - def get_message_summary(self) -> Dict[str, Any]: + def get_message_summary(self) -> dict[str, Any]: """ 获取消息摘要信息 diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index ba45ab3c4..e9ed04162 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -1,30 +1,32 @@ from dataclasses import dataclass, field -from typing import Optional, Dict, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from src.plugin_system.base.component_types import ChatType + from . import BaseDataModel if TYPE_CHECKING: - from .database_data_model import DatabaseMessages from src.plugin_system.base.component_types import ActionInfo, ChatMode + from .database_data_model import DatabaseMessages + @dataclass class TargetPersonInfo(BaseDataModel): platform: str = field(default_factory=str) user_id: str = field(default_factory=str) user_nickname: str = field(default_factory=str) - person_id: Optional[str] = None - person_name: Optional[str] = None + person_id: str | None = None + person_name: str | None = None @dataclass class ActionPlannerInfo(BaseDataModel): action_type: str = field(default_factory=str) - reasoning: Optional[str] = None - action_data: Optional[Dict] = None + reasoning: str | None = None + action_data: dict | None = None action_message: Optional["DatabaseMessages"] = None - available_actions: Optional[Dict[str, "ActionInfo"]] = None + available_actions: dict[str, "ActionInfo"] | None = None @dataclass @@ -36,7 +38,7 @@ class InterestScore(BaseDataModel): interest_match_score: float relationship_score: float mentioned_score: float - details: Dict[str, str] + details: dict[str, str] @dataclass @@ -50,10 +52,10 @@ class Plan(BaseDataModel): chat_type: "ChatType" # Generator 填充 - available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict) - chat_history: List["DatabaseMessages"] = field(default_factory=list) - target_info: Optional[TargetPersonInfo] = None + available_actions: dict[str, "ActionInfo"] = field(default_factory=dict) + chat_history: list["DatabaseMessages"] = field(default_factory=list) + target_info: TargetPersonInfo | None = None # Filter 填充 - llm_prompt: Optional[str] = None - decided_actions: Optional[List[ActionPlannerInfo]] = None + llm_prompt: str | None = None + decided_actions: list[ActionPlannerInfo] | None = None diff --git a/src/common/data_models/llm_data_model.py b/src/common/data_models/llm_data_model.py index a59b65391..147c2b22b 100644 --- a/src/common/data_models/llm_data_model.py +++ b/src/common/data_models/llm_data_model.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional, List, Tuple, TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any from . import BaseDataModel @@ -9,10 +9,10 @@ if TYPE_CHECKING: @dataclass class LLMGenerationDataModel(BaseDataModel): - content: Optional[str] = None - reasoning: Optional[str] = None - model: Optional[str] = None - tool_calls: Optional[List["ToolCall"]] = None - prompt: Optional[str] = None - selected_expressions: Optional[List[int]] = None - reply_set: Optional[List[Tuple[str, Any]]] = None + content: str | None = None + reasoning: str | None = None + model: str | None = None + tool_calls: list["ToolCall"] | None = None + prompt: str | None = None + selected_expressions: list[int] | None = None + reply_set: list[tuple[str, Any]] | None = None diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index a72b7564c..b836101cc 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -7,11 +7,12 @@ import asyncio import time from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional + +from src.common.logger import get_logger +from src.plugin_system.base.component_types import ChatMode, ChatType from . import BaseDataModel -from src.plugin_system.base.component_types import ChatMode, ChatType -from src.common.logger import get_logger if TYPE_CHECKING: from .database_data_model import DatabaseMessages @@ -34,11 +35,11 @@ class StreamContext(BaseDataModel): stream_id: str chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊 chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式 - unread_messages: List["DatabaseMessages"] = field(default_factory=list) - history_messages: List["DatabaseMessages"] = field(default_factory=list) + unread_messages: list["DatabaseMessages"] = field(default_factory=list) + history_messages: list["DatabaseMessages"] = field(default_factory=list) last_check_time: float = field(default_factory=time.time) is_active: bool = True - processing_task: Optional[asyncio.Task] = None + processing_task: asyncio.Task | None = None interruption_count: int = 0 # 打断计数器 last_interruption_time: float = 0.0 # 上次打断时间 afc_threshold_adjustment: float = 0.0 # afc阈值调整量 @@ -49,8 +50,8 @@ class StreamContext(BaseDataModel): # 新增字段以替代ChatMessageContext功能 current_message: Optional["DatabaseMessages"] = None - priority_mode: Optional[str] = None - priority_info: Optional[dict] = None + priority_mode: str | None = None + priority_info: dict | None = None def add_message(self, message: "DatabaseMessages"): """添加消息到上下文""" @@ -150,11 +151,11 @@ class StreamContext(BaseDataModel): self.unread_messages.remove(msg) break - def get_unread_messages(self) -> List["DatabaseMessages"]: + def get_unread_messages(self) -> list["DatabaseMessages"]: """获取未读消息""" return [msg for msg in self.unread_messages if not msg.is_read] - def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]: + def get_history_messages(self, limit: int = 20) -> list["DatabaseMessages"]: """获取历史消息""" # 优先返回最近的历史消息和所有未读消息 recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages @@ -230,7 +231,7 @@ class StreamContext(BaseDataModel): """设置当前消息""" self.current_message = message - def get_template_name(self) -> Optional[str]: + def get_template_name(self) -> str | None: """获取模板名称""" if ( self.current_message @@ -336,11 +337,11 @@ class StreamContext(BaseDataModel): return False return True - def get_priority_mode(self) -> Optional[str]: + def get_priority_mode(self) -> str | None: """获取优先级模式""" return self.priority_mode - def get_priority_info(self) -> Optional[dict]: + def get_priority_info(self) -> dict | None: """获取优先级信息""" return self.priority_info diff --git a/src/common/database/database.py b/src/common/database/database.py index 92c851edb..63f632aa5 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -1,10 +1,11 @@ import os + from rich.traceback import install -from src.common.logger import get_logger # SQLAlchemy相关导入 from src.common.database.sqlalchemy_init import initialize_database_compat -from src.common.database.sqlalchemy_models import get_engine, get_db_session +from src.common.database.sqlalchemy_models import get_db_session, get_engine +from src.common.logger import get_logger install(extra_lines=3) diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 330846983..38c972236 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -6,31 +6,31 @@ import time import traceback -from typing import Dict, List, Any, Union, Optional +from typing import Any -from sqlalchemy import desc, asc, func, and_, select +from sqlalchemy import and_, asc, desc, func, select from sqlalchemy.exc import SQLAlchemyError from src.common.database.sqlalchemy_models import ( - get_db_session, - Messages, ActionRecords, - PersonInfo, - ChatStreams, - LLMUsage, - Emoji, - Images, - ImageDescriptions, - OnlineTime, - Memory, - Expression, - ThinkingLog, - GraphNodes, - GraphEdges, - Schedule, - MaiZoneScheduleStatus, CacheEntries, + ChatStreams, + Emoji, + Expression, + GraphEdges, + GraphNodes, + ImageDescriptions, + Images, + LLMUsage, + MaiZoneScheduleStatus, + Memory, + Messages, + OnlineTime, + PersonInfo, + Schedule, + ThinkingLog, UserRelationships, + get_db_session, ) from src.common.logger import get_logger @@ -59,7 +59,7 @@ MODEL_MAPPING = { } -async def build_filters(model_class, filters: Dict[str, Any]): +async def build_filters(model_class, filters: dict[str, Any]): """构建查询过滤条件""" conditions = [] @@ -98,13 +98,13 @@ async def build_filters(model_class, filters: Dict[str, Any]): async def db_query( model_class, - data: Optional[Dict[str, Any]] = None, - query_type: Optional[str] = "get", - filters: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None, - order_by: Optional[List[str]] = None, - single_result: Optional[bool] = False, -) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: + data: dict[str, Any] | None = None, + query_type: str | None = "get", + filters: dict[str, Any] | None = None, + limit: int | None = None, + order_by: list[str] | None = None, + single_result: bool | None = False, +) -> list[dict[str, Any]] | dict[str, Any] | None: """执行异步数据库查询操作 Args: @@ -263,8 +263,8 @@ async def db_query( async def db_save( - model_class, data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None -) -> Optional[Dict[str, Any]]: + model_class, data: dict[str, Any], key_field: str | None = None, key_value: Any | None = None +) -> dict[str, Any] | None: """异步保存数据到数据库(创建或更新) Args: @@ -325,11 +325,11 @@ async def db_save( async def db_get( model_class, - filters: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None, - order_by: Optional[str] = None, - single_result: Optional[bool] = False, -) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: + filters: dict[str, Any] | None = None, + limit: int | None = None, + order_by: str | None = None, + single_result: bool | None = False, +) -> list[dict[str, Any]] | dict[str, Any] | None: """异步从数据库获取记录 Args: @@ -359,9 +359,9 @@ async def store_action_info( action_prompt_display: str = "", action_done: bool = True, thinking_id: str = "", - action_data: Optional[dict] = None, + action_data: dict | None = None, action_name: str = "", -) -> Optional[Dict[str, Any]]: +) -> dict[str, Any] | None: """异步存储动作信息到数据库 Args: diff --git a/src/common/database/sqlalchemy_init.py b/src/common/database/sqlalchemy_init.py index 7d3f97136..daf61f3a5 100644 --- a/src/common/database/sqlalchemy_init.py +++ b/src/common/database/sqlalchemy_init.py @@ -4,10 +4,10 @@ 提供统一的异步数据库初始化接口 """ -from typing import Optional from sqlalchemy.exc import SQLAlchemyError -from src.common.logger import get_logger + from src.common.database.sqlalchemy_models import Base, get_engine, initialize_database +from src.common.logger import get_logger logger = get_logger("sqlalchemy_init") @@ -71,7 +71,7 @@ async def create_all_tables() -> bool: return False -async def get_database_info() -> Optional[dict]: +async def get_database_info() -> dict | None: """ 异步获取数据库信息 diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 2f78e56d0..c89848ee3 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -6,11 +6,12 @@ import datetime import os import time +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import Optional, Any, Dict, AsyncGenerator +from typing import Any -from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime, text -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker +from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Mapped, mapped_column @@ -423,7 +424,7 @@ class Expression(Base): last_active_time: Mapped[float] = mapped_column(Float, nullable=False) chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) type: Mapped[str] = mapped_column(Text, nullable=False) - create_date: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + create_date: Mapped[float | None] = mapped_column(Float, nullable=True) __table_args__ = (Index("idx_expression_chat_id", "chat_id"),) @@ -710,7 +711,7 @@ async def initialize_database(): config = global_config.database # 配置引擎参数 - engine_kwargs: Dict[str, Any] = { + engine_kwargs: dict[str, Any] = { "echo": False, # 生产环境关闭SQL日志 "future": True, } @@ -759,12 +760,12 @@ async def initialize_database(): @asynccontextmanager -async def get_db_session() -> AsyncGenerator[Optional[AsyncSession], None]: +async def get_db_session() -> AsyncGenerator[AsyncSession | None, None]: """ 异步数据库会话上下文管理器。 在初始化失败时会yield None,调用方需要检查会话是否为None。 """ - session: Optional[AsyncSession] = None + session: AsyncSession | None = None SessionLocal = None try: _, SessionLocal = await initialize_database() diff --git a/src/common/logger.py b/src/common/logger.py index 2830c127d..a28628a46 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -1,16 +1,16 @@ # 使用基于时间戳的文件处理器,简单的轮转份数限制 import logging -import orjson import threading import time +from collections.abc import Callable +from datetime import datetime, timedelta +from pathlib import Path + +import orjson import structlog import tomlkit -from pathlib import Path -from typing import Callable, Optional -from datetime import datetime, timedelta - # 创建logs目录 LOG_DIR = Path("logs") LOG_DIR.mkdir(exist_ok=True) @@ -212,7 +212,7 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress try: if config_path.exists(): - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: config = tomlkit.load(f) return config.get("log", default_config) except Exception as e: @@ -942,7 +942,7 @@ raw_logger: structlog.stdlib.BoundLogger = structlog.get_logger() binds: dict[str, Callable] = {} -def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger: +def get_logger(name: str | None) -> structlog.stdlib.BoundLogger: """获取logger实例,支持按名称绑定""" if name is None: return raw_logger diff --git a/src/common/message/__init__.py b/src/common/message/__init__.py index 160456b0f..79f346c04 100644 --- a/src/common/message/__init__.py +++ b/src/common/message/__init__.py @@ -4,7 +4,6 @@ __version__ = "0.1.0" from .api import get_global_api - __all__ = [ "get_global_api", ] diff --git a/src/common/message/api.py b/src/common/message/api.py index 37b7a7ddc..2d797a5a8 100644 --- a/src/common/message/api.py +++ b/src/common/message/api.py @@ -1,10 +1,12 @@ -from src.common.server import get_global_server import importlib.metadata -from maim_message import MessageServer -from src.common.logger import get_logger -from src.config.config import global_config import os +from maim_message import MessageServer + +from src.common.logger import get_logger +from src.common.server import get_global_server +from src.config.config import global_config + global_api = None diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 57f179c36..f9a874859 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,15 +1,15 @@ import traceback +from typing import Any -from typing import List, Optional, Any, Dict -from sqlalchemy import not_, select, func - +from sqlalchemy import func, not_, select from sqlalchemy.orm import DeclarativeBase -from src.config.config import global_config + +from src.common.database.sqlalchemy_database_api import get_db_session # from src.common.database.database_model import Messages from src.common.database.sqlalchemy_models import Messages -from src.common.database.sqlalchemy_database_api import get_db_session from src.common.logger import get_logger +from src.config.config import global_config logger = get_logger(__name__) @@ -18,7 +18,7 @@ class Base(DeclarativeBase): pass -def _model_to_dict(instance: Base) -> Dict[str, Any]: +def _model_to_dict(instance: Base) -> dict[str, Any]: """ 将 SQLAlchemy 模型实例转换为字典。 """ @@ -32,12 +32,12 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]: async def find_messages( message_filter: dict[str, Any], - sort: Optional[List[tuple[str, int]]] = None, + sort: list[tuple[str, int]] | None = None, limit: int = 0, limit_mode: str = "latest", filter_bot=False, filter_command=False, -) -> List[dict[str, Any]]: +) -> list[dict[str, Any]]: """ 根据提供的过滤器、排序和限制条件查找消息。 diff --git a/src/common/remote.py b/src/common/remote.py index 95202f810..f6396a037 100644 --- a/src/common/remote.py +++ b/src/common/remote.py @@ -1,13 +1,13 @@ import asyncio import base64 import json +import platform +from datetime import datetime, timezone import aiohttp -import platform - -from datetime import datetime, timezone from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding, rsa + from src.common.logger import get_logger from src.common.tcp_connector import get_tcp_connector from src.config.config import global_config diff --git a/src/common/server.py b/src/common/server.py index 64299274b..ec6ff932a 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -1,20 +1,20 @@ import os -from typing import Optional -from fastapi import FastAPI, APIRouter +from fastapi import APIRouter, FastAPI from fastapi.middleware.cors import CORSMiddleware # 新增导入 from rich.traceback import install -from uvicorn import Config, Server as UvicornServer +from uvicorn import Config +from uvicorn import Server as UvicornServer install(extra_lines=3) class Server: - def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"): + def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MaiMCore"): self.app = FastAPI(title=app_name) self._host: str = "127.0.0.1" self._port: int = 8080 - self._server: Optional[UvicornServer] = None + self._server: UvicornServer | None = None self.set_address(host, port) # 配置 CORS @@ -57,7 +57,7 @@ class Server: """ self.app.include_router(router, prefix=prefix) - def set_address(self, host: Optional[str] = None, port: Optional[int] = None): + def set_address(self, host: str | None = None, port: int | None = None): """设置服务器地址和端口""" if host: self._host = host @@ -76,7 +76,7 @@ class Server: raise except Exception as e: await self.shutdown() - raise RuntimeError(f"服务器运行错误: {str(e)}") from e + raise RuntimeError(f"服务器运行错误: {e!s}") from e finally: await self.shutdown() diff --git a/src/common/tcp_connector.py b/src/common/tcp_connector.py index dd966e648..868b0c3f2 100644 --- a/src/common/tcp_connector.py +++ b/src/common/tcp_connector.py @@ -1,6 +1,7 @@ import ssl -import certifi + import aiohttp +import certifi ssl_context = ssl.create_default_context(cafile=certifi.where()) diff --git a/src/common/vector_db/__init__.py b/src/common/vector_db/__init__.py index a913c2232..65e0a8025 100644 --- a/src/common/vector_db/__init__.py +++ b/src/common/vector_db/__init__.py @@ -18,4 +18,4 @@ def get_vector_db_service() -> VectorDBBase: # 全局向量数据库服务实例 vector_db_service: VectorDBBase = get_vector_db_service() -__all__ = ["vector_db_service", "VectorDBBase"] +__all__ = ["VectorDBBase", "vector_db_service"] diff --git a/src/common/vector_db/base.py b/src/common/vector_db/base.py index 132ea15cb..04449e24a 100644 --- a/src/common/vector_db/base.py +++ b/src/common/vector_db/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any class VectorDBBase(ABC): @@ -36,10 +36,10 @@ class VectorDBBase(ABC): def add( self, collection_name: str, - embeddings: List[List[float]], - documents: Optional[List[str]] = None, - metadatas: Optional[List[Dict[str, Any]]] = None, - ids: Optional[List[str]] = None, + embeddings: list[list[float]], + documents: list[str] | None = None, + metadatas: list[dict[str, Any]] | None = None, + ids: list[str] | None = None, ) -> None: """ 向指定集合中添加数据。 @@ -57,11 +57,11 @@ class VectorDBBase(ABC): def query( self, collection_name: str, - query_embeddings: List[List[float]], + query_embeddings: list[list[float]], n_results: int = 1, - where: Optional[Dict[str, Any]] = None, + where: dict[str, Any] | None = None, **kwargs: Any, - ) -> Dict[str, List[Any]]: + ) -> dict[str, list[Any]]: """ 在指定集合中查询相似向量。 @@ -81,8 +81,8 @@ class VectorDBBase(ABC): def delete( self, collection_name: str, - ids: Optional[List[str]] = None, - where: Optional[Dict[str, Any]] = None, + ids: list[str] | None = None, + where: dict[str, Any] | None = None, ) -> None: """ 从指定集合中删除数据。 @@ -98,13 +98,13 @@ class VectorDBBase(ABC): def get( self, collection_name: str, - ids: Optional[List[str]] = None, - where: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - where_document: Optional[Dict[str, Any]] = None, - include: Optional[List[str]] = None, - ) -> Dict[str, Any]: + ids: list[str] | None = None, + where: dict[str, Any] | None = None, + limit: int | None = None, + offset: int | None = None, + where_document: dict[str, Any] | None = None, + include: list[str] | None = None, + ) -> dict[str, Any]: """ 根据条件从集合中获取数据。 diff --git a/src/common/vector_db/chromadb_impl.py b/src/common/vector_db/chromadb_impl.py index a0267dfed..1934c812e 100644 --- a/src/common/vector_db/chromadb_impl.py +++ b/src/common/vector_db/chromadb_impl.py @@ -1,12 +1,13 @@ import threading -from typing import Any, Dict, List, Optional +from typing import Any import chromadb from chromadb.config import Settings -from .base import VectorDBBase from src.common.logger import get_logger +from .base import VectorDBBase + logger = get_logger("chromadb_impl") @@ -38,7 +39,7 @@ class ChromaDBImpl(VectorDBBase): self.client = chromadb.PersistentClient( path=path, settings=Settings(anonymized_telemetry=False) ) - self._collections: Dict[str, Any] = {} + self._collections: dict[str, Any] = {} self._initialized = True logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}") except Exception as e: @@ -65,10 +66,10 @@ class ChromaDBImpl(VectorDBBase): def add( self, collection_name: str, - embeddings: List[List[float]], - documents: Optional[List[str]] = None, - metadatas: Optional[List[Dict[str, Any]]] = None, - ids: Optional[List[str]] = None, + embeddings: list[list[float]], + documents: list[str] | None = None, + metadatas: list[dict[str, Any]] | None = None, + ids: list[str] | None = None, ) -> None: collection = self.get_or_create_collection(collection_name) if collection: @@ -85,11 +86,11 @@ class ChromaDBImpl(VectorDBBase): def query( self, collection_name: str, - query_embeddings: List[List[float]], + query_embeddings: list[list[float]], n_results: int = 1, - where: Optional[Dict[str, Any]] = None, + where: dict[str, Any] | None = None, **kwargs: Any, - ) -> Dict[str, List[Any]]: + ) -> dict[str, list[Any]]: collection = self.get_or_create_collection(collection_name) if collection: try: @@ -120,7 +121,7 @@ class ChromaDBImpl(VectorDBBase): logger.error(f"回退查询也失败: {fallback_e}") return {} - def _process_where_condition(self, where: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def _process_where_condition(self, where: dict[str, Any]) -> dict[str, Any] | None: """ 处理where条件,转换为ChromaDB支持的格式 ChromaDB支持的格式: @@ -174,13 +175,13 @@ class ChromaDBImpl(VectorDBBase): def get( self, collection_name: str, - ids: Optional[List[str]] = None, - where: Optional[Dict[str, Any]] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - where_document: Optional[Dict[str, Any]] = None, - include: Optional[List[str]] = None, - ) -> Dict[str, Any]: + ids: list[str] | None = None, + where: dict[str, Any] | None = None, + limit: int | None = None, + offset: int | None = None, + where_document: dict[str, Any] | None = None, + include: list[str] | None = None, + ) -> dict[str, Any]: """根据条件从集合中获取数据""" collection = self.get_or_create_collection(collection_name) if collection: @@ -217,8 +218,8 @@ class ChromaDBImpl(VectorDBBase): def delete( self, collection_name: str, - ids: Optional[List[str]] = None, - where: Optional[Dict[str, Any]] = None, + ids: list[str] | None = None, + where: dict[str, Any] | None = None, ) -> None: collection = self.get_or_create_collection(collection_name) if collection: diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index eb5d1a1f1..de7479efb 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -1,6 +1,7 @@ -from typing import List, Dict, Any, Literal, Union, Optional -from pydantic import Field from threading import Lock +from typing import Any, Literal + +from pydantic import Field from src.config.config_base import ValidatedConfigBase @@ -10,7 +11,7 @@ class APIProvider(ValidatedConfigBase): name: str = Field(..., min_length=1, description="API提供商名称") base_url: str = Field(..., description="API基础URL") - api_key: Union[str, List[str]] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询") + api_key: str | list[str] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询") client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field( default="openai", description="客户端类型(如openai/google等,默认为openai)" ) @@ -70,7 +71,7 @@ class ModelInfo(ValidatedConfigBase): price_in: float = Field(default=0.0, ge=0, description="每M token输入价格") price_out: float = Field(default=0.0, ge=0, description="每M token输出价格") force_stream_mode: bool = Field(default=False, description="是否强制使用流式输出模式") - extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数(用于API调用时的额外配置)") + extra_params: dict[str, Any] = Field(default_factory=dict, description="额外参数(用于API调用时的额外配置)") anti_truncation: bool = Field(default=False, description="是否启用反截断功能,防止模型输出被截断") @classmethod @@ -101,11 +102,11 @@ class ModelInfo(ValidatedConfigBase): class TaskConfig(ValidatedConfigBase): """任务配置类""" - model_list: List[str] = Field(..., description="任务使用的模型列表") + model_list: list[str] = Field(..., description="任务使用的模型列表") max_tokens: int = Field(default=800, description="任务最大输出token数") temperature: float = Field(default=0.7, description="模型温度") concurrency_count: int = Field(default=1, description="并发请求数量") - embedding_dimension: Optional[int] = Field( + embedding_dimension: int | None = Field( default=None, description="嵌入模型输出向量维度,仅在嵌入任务中使用", ge=1, @@ -168,9 +169,9 @@ class ModelTaskConfig(ValidatedConfigBase): class APIAdapterConfig(ValidatedConfigBase): """API Adapter配置类""" - models: List[ModelInfo] = Field(..., min_length=1, description="模型列表") + models: list[ModelInfo] = Field(..., min_length=1, description="模型列表") model_task_config: ModelTaskConfig = Field(..., description="模型任务配置") - api_providers: List[APIProvider] = Field(..., min_length=1, description="API提供商列表") + api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表") def __init__(self, **data): super().__init__(**data) diff --git a/src/config/config.py b/src/config/config.py index 375d513df..846643477 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,60 +1,58 @@ import os -import tomlkit import shutil import sys - from datetime import datetime -from tomlkit import TOMLDocument -from tomlkit.items import Table, KeyType -from rich.traceback import install -from typing import List, Optional + +import tomlkit from pydantic import Field +from rich.traceback import install +from tomlkit import TOMLDocument +from tomlkit.items import KeyType, Table from src.common.logger import get_logger from src.config.config_base import ValidatedConfigBase from src.config.official_configs import ( - DatabaseConfig, + AffinityFlowConfig, + AntiPromptInjectionConfig, BotConfig, - PersonalityConfig, - ExpressionConfig, ChatConfig, - NormalChatConfig, - EmojiConfig, - MemoryConfig, - MoodConfig, - KeywordReactionConfig, ChineseTypoConfig, + CommandConfig, + CrossContextConfig, + CustomPromptConfig, + DatabaseConfig, + DebugConfig, + DependencyManagementConfig, + EmojiConfig, + ExperimentalConfig, + ExpressionConfig, + KeywordReactionConfig, + LPMMKnowledgeConfig, + MaimMessageConfig, + MemoryConfig, + MessageReceiveConfig, + MoodConfig, + NormalChatConfig, + PermissionConfig, + PersonalityConfig, + PlanningSystemConfig, + ProactiveThinkingConfig, + RelationshipConfig, ResponsePostProcessConfig, ResponseSplitterConfig, - ExperimentalConfig, - MessageReceiveConfig, - MaimMessageConfig, - LPMMKnowledgeConfig, - RelationshipConfig, - ToolConfig, - VoiceConfig, - DebugConfig, - CustomPromptConfig, - VideoAnalysisConfig, - DependencyManagementConfig, - WebSearchConfig, - AntiPromptInjectionConfig, SleepSystemConfig, - CrossContextConfig, - PermissionConfig, - CommandConfig, - PlanningSystemConfig, - AffinityFlowConfig, - ProactiveThinkingConfig, + ToolConfig, + VideoAnalysisConfig, + VoiceConfig, + WebSearchConfig, ) from .api_ada_configs import ( - ModelTaskConfig, - ModelInfo, APIProvider, + ModelInfo, + ModelTaskConfig, ) - install(extra_lines=3) @@ -148,11 +146,11 @@ def compare_default_values(new, old, path=None, logs=None, changes=None): return logs, changes -def _get_version_from_toml(toml_path) -> Optional[str]: +def _get_version_from_toml(toml_path) -> str | None: """从TOML文件中获取版本号""" if not os.path.exists(toml_path): return None - with open(toml_path, "r", encoding="utf-8") as f: + with open(toml_path, encoding="utf-8") as f: doc = tomlkit.load(f) if "inner" in doc and "version" in doc["inner"]: # type: ignore return doc["inner"]["version"] # type: ignore @@ -264,17 +262,17 @@ def _update_config_generic(config_name: str, template_name: str): # 先读取 compare 下的模板(如果有),用于默认值变动检测 if os.path.exists(compare_path): - with open(compare_path, "r", encoding="utf-8") as f: + with open(compare_path, encoding="utf-8") as f: compare_config = tomlkit.load(f) # 读取当前模板 - with open(template_path, "r", encoding="utf-8") as f: + with open(template_path, encoding="utf-8") as f: new_config = tomlkit.load(f) # 检查默认值变化并处理(只有 compare_config 存在时才做) if compare_config: # 读取旧配置 - with open(old_config_path, "r", encoding="utf-8") as f: + with open(old_config_path, encoding="utf-8") as f: old_config = tomlkit.load(f) logs, changes = compare_default_values(new_config, compare_config) if logs: @@ -304,7 +302,7 @@ def _update_config_generic(config_name: str, template_name: str): # 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次) if old_config is None: - with open(old_config_path, "r", encoding="utf-8") as f: + with open(old_config_path, encoding="utf-8") as f: old_config = tomlkit.load(f) # new_config 已经读取 @@ -350,7 +348,7 @@ def _update_config_generic(config_name: str, template_name: str): # 移除在新模板中已不存在的旧配置项 logger.info(f"开始移除{config_name}中已废弃的配置项...") - with open(template_path, "r", encoding="utf-8") as f: + with open(template_path, encoding="utf-8") as f: template_doc = tomlkit.load(f) _remove_obsolete_keys(new_config, template_doc) logger.info(f"已移除{config_name}中已废弃的配置项") @@ -428,9 +426,9 @@ class Config(ValidatedConfigBase): class APIAdapterConfig(ValidatedConfigBase): """API Adapter配置类""" - models: List[ModelInfo] = Field(..., min_items=1, description="模型列表") + models: list[ModelInfo] = Field(..., min_items=1, description="模型列表") model_task_config: ModelTaskConfig = Field(..., description="模型任务配置") - api_providers: List[APIProvider] = Field(..., min_items=1, description="API提供商列表") + api_providers: list[APIProvider] = Field(..., min_items=1, description="API提供商列表") def __init__(self, **data): super().__init__(**data) @@ -494,7 +492,7 @@ def load_config(config_path: str) -> Config: Config对象 """ # 读取配置文件 - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: config_data = tomlkit.load(f) # 创建Config对象(各个配置类会自动进行 Pydantic 验证) @@ -517,7 +515,7 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig: APIAdapterConfig对象 """ # 读取配置文件 - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: config_data = tomlkit.load(f) config_dict = dict(config_data) diff --git a/src/config/config_base.py b/src/config/config_base.py index 764ec5b71..a80740a46 100644 --- a/src/config/config_base.py +++ b/src/config/config_base.py @@ -1,6 +1,8 @@ -from dataclasses import dataclass, fields, MISSING -from typing import TypeVar, Type, Any, get_origin, get_args, Literal +from dataclasses import MISSING, dataclass, fields +from typing import Any, Literal, TypeVar, get_args, get_origin + from pydantic import BaseModel, ValidationError +from typing_extensions import Self T = TypeVar("T", bound="ConfigBase") @@ -19,7 +21,7 @@ class ConfigBase: """配置类的基类""" @classmethod - def from_dict(cls: Type[T], data: dict[str, Any]) -> T: + def from_dict(cls, data: dict[str, Any]) -> Self: """从字典加载配置字段""" if not isinstance(data, dict): raise TypeError(f"Expected a dictionary, got {type(data).__name__}") @@ -53,7 +55,7 @@ class ConfigBase: return cls() @classmethod - def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any: + def _convert_field(cls, value: Any, field_type: type[Any]) -> Any: """ 转换字段值为指定类型 diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 6a1613baa..ecdb5d5b5 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1,4 +1,5 @@ -from typing import Literal, Optional, List +from typing import Literal + from pydantic import Field from src.config.config_base import ValidatedConfigBase @@ -42,7 +43,7 @@ class BotConfig(ValidatedConfigBase): platform: str = Field(..., description="平台") qq_account: int = Field(..., description="QQ账号") nickname: str = Field(..., description="昵称") - alias_names: List[str] = Field(default_factory=list, description="别名列表") + alias_names: list[str] = Field(default_factory=list, description="别名列表") class PersonalityConfig(ValidatedConfigBase): @@ -54,7 +55,7 @@ class PersonalityConfig(ValidatedConfigBase): background_story: str = Field( default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述" ) - safety_guidelines: List[str] = Field( + safety_guidelines: list[str] = Field( default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则" ) reply_style: str = Field(default="", description="表达风格") @@ -63,7 +64,7 @@ class PersonalityConfig(ValidatedConfigBase): compress_identity: bool = Field(default=True, description="是否压缩身份") # 回复规则配置 - reply_targeting_rules: List[str] = Field( + reply_targeting_rules: list[str] = Field( default_factory=lambda: [ "拒绝任何包含骚扰、冒犯、暴力、色情或危险内容的请求。", "在拒绝时,请使用符合你人设的、坚定的语气。", @@ -72,7 +73,7 @@ class PersonalityConfig(ValidatedConfigBase): description="安全与互动底线规则,Bot在任何情况下都必须遵守的原则", ) - message_targeting_analysis: List[str] = Field( + message_targeting_analysis: list[str] = Field( default_factory=lambda: [ "**直接针对你**:@你、回复你、明确询问你 → 必须回应", "**间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与", @@ -82,7 +83,7 @@ class PersonalityConfig(ValidatedConfigBase): description="消息针对性分析规则,用于判断是否需要回复", ) - reply_principles: List[str] = Field( + reply_principles: list[str] = Field( default_factory=lambda: [ "明确回应目标消息,而不是宽泛地评论。", "可以分享你的看法、提出相关问题,或者开个合适的玩笑。", @@ -111,7 +112,7 @@ class ChatConfig(ValidatedConfigBase): at_bot_inevitable_reply: bool = Field(default=False, description="@机器人的必然回复") allow_reply_self: bool = Field(default=False, description="是否允许回复自己说的话") focus_value: float = Field(default=1.0, description="专注值") - focus_mode_quiet_groups: List[str] = Field( + focus_mode_quiet_groups: list[str] = Field( default_factory=list, description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]', ) @@ -140,8 +141,8 @@ class ChatConfig(ValidatedConfigBase): class MessageReceiveConfig(ValidatedConfigBase): """消息接收配置类""" - ban_words: List[str] = Field(default_factory=lambda: list(), description="禁用词列表") - ban_msgs_regex: List[str] = Field(default_factory=lambda: list(), description="禁用消息正则列表") + ban_words: list[str] = Field(default_factory=lambda: list(), description="禁用词列表") + ban_msgs_regex: list[str] = Field(default_factory=lambda: list(), description="禁用消息正则列表") class NormalChatConfig(ValidatedConfigBase): @@ -155,16 +156,16 @@ class ExpressionRule(ValidatedConfigBase): use_expression: bool = Field(default=True, description="是否使用学到的表达") learn_expression: bool = Field(default=True, description="是否学习表达") learning_strength: float = Field(default=1.0, description="学习强度") - group: Optional[str] = Field(default=None, description="表达共享组") + group: str | None = Field(default=None, description="表达共享组") class ExpressionConfig(ValidatedConfigBase): """表达配置类""" - rules: List[ExpressionRule] = Field(default_factory=list, description="表达学习规则") + rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则") @staticmethod - def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]: + def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None: """ 解析流配置字符串并生成对应的 chat_id @@ -199,7 +200,7 @@ class ExpressionConfig(ValidatedConfigBase): except (ValueError, IndexError): return None - def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, float]: + def get_expression_config_for_chat(self, chat_stream_id: str | None = None) -> tuple[bool, bool, float]: """ 根据聊天流ID获取表达配置 @@ -362,7 +363,7 @@ class KeywordRuleConfig(ValidatedConfigBase): try: re.compile(pattern) except re.error as e: - raise ValueError(f"无效的正则表达式 '{pattern}': {str(e)}") from e + raise ValueError(f"无效的正则表达式 '{pattern}': {e!s}") from e class KeywordReactionConfig(ValidatedConfigBase): @@ -561,10 +562,10 @@ class SleepSystemConfig(ValidatedConfigBase): # --- 失眠机制相关参数 --- enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统") - insomnia_trigger_delay_minutes: List[int] = Field( + insomnia_trigger_delay_minutes: list[int] = Field( default_factory=lambda: [30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)" ) - insomnia_duration_minutes: List[int] = Field( + insomnia_duration_minutes: list[int] = Field( default_factory=lambda: [15, 45], description="单次失眠状态的持续时间范围(分钟)" ) sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值") @@ -590,7 +591,7 @@ class ContextGroup(ValidatedConfigBase): """上下文共享组配置""" name: str = Field(..., description="共享组的名称") - chat_ids: List[List[str]] = Field( + chat_ids: list[list[str]] = Field( ..., description='属于该组的聊天ID列表,格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]', ) @@ -600,20 +601,20 @@ class CrossContextConfig(ValidatedConfigBase): """跨群聊上下文共享配置""" enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能") - groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表") + groups: list[ContextGroup] = Field(default_factory=list, description="上下文共享组列表") class CommandConfig(ValidatedConfigBase): """命令系统配置类""" - command_prefixes: List[str] = Field(default_factory=lambda: ["/", "!", ".", "#"], description="支持的命令前缀列表") + command_prefixes: list[str] = Field(default_factory=lambda: ["/", "!", ".", "#"], description="支持的命令前缀列表") class PermissionConfig(ValidatedConfigBase): """权限系统配置类""" # Master用户配置(拥有最高权限,无视所有权限节点) - master_users: List[List[str]] = Field( + master_users: list[list[str]] = Field( default_factory=list, description="Master用户列表,格式: [[platform, user_id], ...]" ) @@ -668,10 +669,10 @@ class ProactiveThinkingConfig(ValidatedConfigBase): # --- 作用范围 --- enable_in_private: bool = Field(default=True, description="是否允许在私聊中主动发起对话") enable_in_group: bool = Field(default=True, description="是否允许在群聊中主动发起对话") - enabled_private_chats: List[str] = Field( + enabled_private_chats: list[str] = Field( default_factory=list, description='私聊白名单,为空则对所有私聊生效。格式: ["platform:user_id", ...]' ) - enabled_group_chats: List[str] = Field( + enabled_group_chats: list[str] = Field( default_factory=list, description='群聊白名单,为空则对所有群聊生效。格式: ["platform:group_id", ...]' ) diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index 4716921f9..83c24d4f6 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -1,13 +1,14 @@ -import orjson -import os import hashlib +import os import time +import orjson +from rich.traceback import install + from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import get_person_info_manager -from rich.traceback import install install(extra_lines=3) @@ -193,9 +194,9 @@ class Individuality: """从JSON文件中加载元信息""" if os.path.exists(self.meta_info_file_path): try: - with open(self.meta_info_file_path, "r", encoding="utf-8") as f: + with open(self.meta_info_file_path, encoding="utf-8") as f: return orjson.loads(f.read()) - except (orjson.JSONDecodeError, IOError) as e: + except (OSError, orjson.JSONDecodeError) as e: logger.error(f"读取meta_info文件失败: {e}, 将创建新文件。") return {} return {} @@ -206,16 +207,16 @@ class Individuality: os.makedirs(os.path.dirname(self.meta_info_file_path), exist_ok=True) with open(self.meta_info_file_path, "w", encoding="utf-8") as f: f.write(orjson.dumps(meta_info, option=orjson.OPT_INDENT_2).decode("utf-8")) - except IOError as e: + except OSError as e: logger.error(f"保存meta_info文件失败: {e}") def _load_personality_data(self) -> dict: """从JSON文件中加载personality数据""" if os.path.exists(self.personality_data_file_path): try: - with open(self.personality_data_file_path, "r", encoding="utf-8") as f: + with open(self.personality_data_file_path, encoding="utf-8") as f: return orjson.loads(f.read()) - except (orjson.JSONDecodeError, IOError) as e: + except (OSError, orjson.JSONDecodeError) as e: logger.error(f"读取personality_data文件失败: {e}, 将创建新文件。") return {} return {} @@ -227,7 +228,7 @@ class Individuality: with open(self.personality_data_file_path, "w", encoding="utf-8") as f: f.write(orjson.dumps(personality_data, option=orjson.OPT_INDENT_2).decode("utf-8")) logger.debug(f"已保存personality数据到文件: {self.personality_data_file_path}") - except IOError as e: + except OSError as e: logger.error(f"保存personality_data文件失败: {e}") def _get_personality_from_file(self) -> tuple[str, str]: diff --git a/src/individuality/not_using/offline_llm.py b/src/individuality/not_using/offline_llm.py index 2bafb69aa..752293ab8 100644 --- a/src/individuality/not_using/offline_llm.py +++ b/src/individuality/not_using/offline_llm.py @@ -1,13 +1,13 @@ import asyncio import os import time -from typing import Tuple, Union import aiohttp import requests +from rich.traceback import install + from src.common.logger import get_logger from src.common.tcp_connector import get_tcp_connector -from rich.traceback import install install(extra_lines=3) @@ -26,7 +26,7 @@ class LLMRequestOff: # logger.info(f"API URL: {self.base_url}") # 使用 logger 记录 base_url - def generate_response(self, prompt: str) -> Union[str, Tuple[str, str]]: + def generate_response(self, prompt: str) -> str | tuple[str, str]: """根据输入的提示生成模型的响应""" headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} @@ -67,16 +67,16 @@ class LLMRequestOff: except Exception as e: if retry < max_retries - 1: # 如果还有重试机会 wait_time = base_wait_time * (2**retry) - logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {e!s}") time.sleep(wait_time) else: - logger.error(f"请求失败: {str(e)}") - return f"请求失败: {str(e)}", "" + logger.error(f"请求失败: {e!s}") + return f"请求失败: {e!s}", "" logger.error("达到最大重试次数,请求仍然失败") return "达到最大重试次数,请求仍然失败", "" - async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]: + async def generate_response_async(self, prompt: str) -> str | tuple[str, str]: """异步方式根据输入的提示生成模型的响应""" headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} @@ -117,11 +117,11 @@ class LLMRequestOff: except Exception as e: if retry < max_retries - 1: # 如果还有重试机会 wait_time = base_wait_time * (2**retry) - logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {e!s}") await asyncio.sleep(wait_time) else: - logger.error(f"请求失败: {str(e)}") - return f"请求失败: {str(e)}", "" + logger.error(f"请求失败: {e!s}") + return f"请求失败: {e!s}", "" logger.error("达到最大重试次数,请求仍然失败") return "达到最大重试次数,请求仍然失败", "" diff --git a/src/individuality/not_using/per_bf_gen.py b/src/individuality/not_using/per_bf_gen.py index 9e4d0291f..4aea7e7de 100644 --- a/src/individuality/not_using/per_bf_gen.py +++ b/src/individuality/not_using/per_bf_gen.py @@ -1,10 +1,10 @@ -from typing import Dict, List -import orjson import os -from dotenv import load_dotenv -import sys -import toml import random +import sys + +import orjson +import toml +from dotenv import load_dotenv from tqdm import tqdm # 添加项目根目录到 Python 路径 @@ -13,13 +13,13 @@ sys.path.append(root_path) # 加载配置文件 config_path = os.path.join(root_path, "config", "bot_config.toml") -with open(config_path, "r", encoding="utf-8") as f: +with open(config_path, encoding="utf-8") as f: config = toml.load(f) # 现在可以导入src模块 from individuality.not_using.scene import get_scene_by_factor, PERSONALITY_SCENES # noqa E402 -from individuality.not_using.questionnaire import FACTOR_DESCRIPTIONS # noqa E402 -from individuality.not_using.offline_llm import LLMRequestOff # noqa E402 +from individuality.not_using.questionnaire import FACTOR_DESCRIPTIONS +from individuality.not_using.offline_llm import LLMRequestOff # 加载环境变量 env_path = os.path.join(root_path, ".env") @@ -75,7 +75,7 @@ def adapt_scene(scene: str) -> str: return adapted_scene except Exception as e: - print(f"场景改编过程出错:{str(e)},将使用原始场景") + print(f"场景改编过程出错:{e!s},将使用原始场景") return scene @@ -83,8 +83,8 @@ class PersonalityEvaluatorDirect: def __init__(self): self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} self.scenarios = [] - self.final_scores: Dict[str, float] = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} - self.dimension_counts = {trait: 0 for trait in self.final_scores} + self.final_scores: dict[str, float] = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} + self.dimension_counts = dict.fromkeys(self.final_scores, 0) # 为每个人格特质获取对应的场景 for trait in PERSONALITY_SCENES: @@ -112,7 +112,7 @@ class PersonalityEvaluatorDirect: self.llm = LLMRequestOff() - def evaluate_response(self, scenario: str, response: str, dimensions: List[str]) -> Dict[str, float]: + def evaluate_response(self, scenario: str, response: str, dimensions: list[str]) -> dict[str, float]: """ 使用 DeepSeek AI 评估用户对特定场景的反应 """ @@ -163,10 +163,10 @@ class PersonalityEvaluatorDirect: return {k: max(1, min(6, float(v))) for k, v in scores.items()} else: print("AI响应格式不正确,使用默认评分") - return {dim: 3.5 for dim in dimensions} + return dict.fromkeys(dimensions, 3.5) except Exception as e: - print(f"评估过程出错:{str(e)}") - return {dim: 3.5 for dim in dimensions} + print(f"评估过程出错:{e!s}") + return dict.fromkeys(dimensions, 3.5) def run_evaluation(self): """ diff --git a/src/individuality/not_using/scene.py b/src/individuality/not_using/scene.py index 929a9c426..9c16358e6 100644 --- a/src/individuality/not_using/scene.py +++ b/src/individuality/not_using/scene.py @@ -1,7 +1,8 @@ -import orjson import os from typing import Any +import orjson + def load_scenes() -> dict[str, Any]: """ @@ -13,7 +14,7 @@ def load_scenes() -> dict[str, Any]: current_dir = os.path.dirname(os.path.abspath(__file__)) json_path = os.path.join(current_dir, "template_scene.json") - with open(json_path, "r", encoding="utf-8") as f: + with open(json_path, encoding="utf-8") as f: return orjson.loads(f.read()) diff --git a/src/llm_models/exceptions.py b/src/llm_models/exceptions.py index 5b04f58c6..ad2b9a69d 100644 --- a/src/llm_models/exceptions.py +++ b/src/llm_models/exceptions.py @@ -1,6 +1,5 @@ from typing import Any - # 常见Error Code Mapping (以OpenAI API为例) error_code_mapping = { 400: "参数不正确", diff --git a/src/llm_models/model_client/aiohttp_gemini_client.py b/src/llm_models/model_client/aiohttp_gemini_client.py index 3d4dd8ca1..84470fb60 100644 --- a/src/llm_models/model_client/aiohttp_gemini_client.py +++ b/src/llm_models/model_client/aiohttp_gemini_client.py @@ -1,21 +1,24 @@ import asyncio -import orjson import io -from typing import Callable, Any, Coroutine, Optional -import aiohttp +from collections.abc import Callable, Coroutine +from typing import Any + +import aiohttp +import orjson -from src.config.api_ada_configs import ModelInfo, APIProvider from src.common.logger import get_logger -from .base_client import APIResponse, UsageRecord, BaseClient, client_registry +from src.config.api_ada_configs import APIProvider, ModelInfo + from ..exceptions import ( - RespParseException, NetworkConnectionError, - RespNotOkException, ReqAbortException, + RespNotOkException, + RespParseException, ) from ..payload_content.message import Message, RoleType from ..payload_content.resp_format import RespFormat, RespFormatType -from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall +from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam +from .base_client import APIResponse, BaseClient, UsageRecord, client_registry logger = get_logger("AioHTTP-Gemini客户端") @@ -210,7 +213,7 @@ class AiohttpGeminiStreamParser: chunk_data = orjson.loads(chunk_text) # 解析候选项 - if "candidates" in chunk_data and chunk_data["candidates"]: + if chunk_data.get("candidates"): candidate = chunk_data["candidates"][0] # 解析内容 @@ -266,7 +269,7 @@ class AiohttpGeminiStreamParser: async def _default_stream_response_handler( response: aiohttp.ClientResponse, interrupt_flag: asyncio.Event | None, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: +) -> tuple[APIResponse, tuple[int, int, int] | None]: """默认流式响应处理器""" parser = AiohttpGeminiStreamParser() @@ -290,13 +293,13 @@ async def _default_stream_response_handler( def _default_normal_response_parser( response_data: dict, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: +) -> tuple[APIResponse, tuple[int, int, int] | None]: """默认普通响应解析器""" api_response = APIResponse() try: # 解析候选项 - if "candidates" in response_data and response_data["candidates"]: + if response_data.get("candidates"): candidate = response_data["candidates"][0] # 解析文本内容 @@ -419,13 +422,12 @@ class AiohttpGeminiClient(BaseClient): max_tokens: int = 1024, temperature: float = 0.7, response_format: RespFormat | None = None, - stream_response_handler: Optional[ - Callable[ - [aiohttp.ClientResponse, asyncio.Event | None], - Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]], - ] - ] = None, - async_response_parser: Optional[Callable[[dict], tuple[APIResponse, Optional[tuple[int, int, int]]]]] = None, + stream_response_handler: Callable[ + [aiohttp.ClientResponse, asyncio.Event | None], + Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]], + ] + | None = None, + async_response_parser: Callable[[dict], tuple[APIResponse, tuple[int, int, int] | None]] | None = None, interrupt_flag: asyncio.Event | None = None, extra_params: dict[str, Any] | None = None, ) -> APIResponse: diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index eb74b0dfe..88f8601d6 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -1,12 +1,14 @@ import asyncio -from dataclasses import dataclass from abc import ABC, abstractmethod -from typing import Callable, Any, Optional +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from src.config.api_ada_configs import APIProvider, ModelInfo -from src.config.api_ada_configs import ModelInfo, APIProvider from ..payload_content.message import Message from ..payload_content.resp_format import RespFormat -from ..payload_content.tool_option import ToolOption, ToolCall +from ..payload_content.tool_option import ToolCall, ToolOption @dataclass @@ -75,9 +77,8 @@ class BaseClient(ABC): max_tokens: int = 1024, temperature: float = 0.7, response_format: RespFormat | None = None, - stream_response_handler: Optional[ - Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]] - ] = None, + stream_response_handler: Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]] + | None = None, async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None, interrupt_flag: asyncio.Event | None = None, extra_params: dict[str, Any] | None = None, diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 0ef79a89b..8005affaa 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -1,17 +1,17 @@ import asyncio -import io -import orjson -import re import base64 -from collections.abc import Iterable -from typing import Callable, Any, Coroutine, Optional -from json_repair import repair_json +import io +import re +from collections.abc import Callable, Coroutine, Iterable +from typing import Any +import orjson +from json_repair import repair_json from openai import ( - AsyncOpenAI, + NOT_GIVEN, APIConnectionError, APIStatusError, - NOT_GIVEN, + AsyncOpenAI, AsyncStream, ) from openai.types.chat import ( @@ -22,18 +22,19 @@ from openai.types.chat import ( ) from openai.types.chat.chat_completion_chunk import ChoiceDelta -from src.config.api_ada_configs import ModelInfo, APIProvider from src.common.logger import get_logger -from .base_client import APIResponse, UsageRecord, BaseClient, client_registry +from src.config.api_ada_configs import APIProvider, ModelInfo + from ..exceptions import ( - RespParseException, NetworkConnectionError, - RespNotOkException, ReqAbortException, + RespNotOkException, + RespParseException, ) from ..payload_content.message import Message, RoleType from ..payload_content.resp_format import RespFormat -from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall +from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam +from .base_client import APIResponse, BaseClient, UsageRecord, client_registry logger = get_logger("OpenAI客户端") @@ -241,7 +242,7 @@ def _build_stream_api_resp( async def _default_stream_response_handler( resp_stream: AsyncStream[ChatCompletionChunk], interrupt_flag: asyncio.Event | None, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: +) -> tuple[APIResponse, tuple[int, int, int] | None]: """ 流式响应处理函数 - 处理OpenAI API的流式响应 :param resp_stream: 流式响应对象 @@ -315,7 +316,7 @@ pattern = re.compile( def _default_normal_response_parser( resp: ChatCompletion, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: +) -> tuple[APIResponse, tuple[int, int, int] | None]: """ 解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象 :param resp: 响应对象 @@ -391,15 +392,13 @@ class OpenaiClient(BaseClient): max_tokens: int = 1024, temperature: float = 0.7, response_format: RespFormat | None = None, - stream_response_handler: Optional[ - Callable[ - [AsyncStream[ChatCompletionChunk], asyncio.Event | None], - Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]], - ] - ] = None, - async_response_parser: Optional[ - Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]] - ] = None, + stream_response_handler: Callable[ + [AsyncStream[ChatCompletionChunk], asyncio.Event | None], + Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]], + ] + | None = None, + async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int] | None]] + | None = None, interrupt_flag: asyncio.Event | None = None, extra_params: dict[str, Any] | None = None, ) -> APIResponse: @@ -514,17 +513,17 @@ class OpenaiClient(BaseClient): ) except APIConnectionError as e: # 添加详细的错误信息以便调试 - logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}") + logger.error(f"OpenAI API连接错误(嵌入模型): {e!s}") logger.error(f"错误类型: {type(e)}") if hasattr(e, "__cause__") and e.__cause__: - logger.error(f"底层错误: {str(e.__cause__)}") + logger.error(f"底层错误: {e.__cause__!s}") raise NetworkConnectionError() from e except APIStatusError as e: # 重封装APIError为RespNotOkException raise RespNotOkException(e.status_code) from e except Exception as e: # 添加通用异常处理和日志记录 - logger.error(f"获取嵌入时发生未知错误: {str(e)}") + logger.error(f"获取嵌入时发生未知错误: {e!s}") logger.error(f"错误类型: {type(e)}") raise diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py index 17d1fa30b..7a34349a3 100644 --- a/src/llm_models/payload_content/message.py +++ b/src/llm_models/payload_content/message.py @@ -1,6 +1,5 @@ from enum import Enum - # 设计这系列类的目的是为未来可能的扩展做准备 diff --git a/src/llm_models/payload_content/resp_format.py b/src/llm_models/payload_content/resp_format.py index e1baa3742..342fbf327 100644 --- a/src/llm_models/payload_content/resp_format.py +++ b/src/llm_models/payload_content/resp_format.py @@ -1,8 +1,8 @@ from enum import Enum -from typing import Optional, Any +from typing import Any from pydantic import BaseModel -from typing_extensions import TypedDict, Required +from typing_extensions import Required, TypedDict class RespFormatType(Enum): @@ -20,7 +20,7 @@ class JsonSchema(TypedDict, total=False): of 64. """ - description: Optional[str] + description: str | None """ A description of what the response format is for, used by the model to determine how to respond in the format. @@ -32,7 +32,7 @@ class JsonSchema(TypedDict, total=False): to build JSON schemas [here](https://json-schema.org/). """ - strict: Optional[bool] + strict: bool | None """ Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` @@ -100,7 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]: # 如果当前Schema是列表,则遍历每个元素 for i in range(len(sub_schema)): if isinstance(sub_schema[i], dict): - sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs) + sub_schema[i] = link_definitions_recursive(f"{path}/{i!s}", sub_schema[i], defs) else: # 否则为字典 if "$defs" in sub_schema: @@ -140,8 +140,7 @@ def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]: schema[idx] = _remove_title(item) elif isinstance(schema, dict): # 是字典,移除title字段,并对所有dict/list子元素递归调用 - if "$defs" in schema: - del schema["$defs"] + schema.pop("$defs", None) for key, value in schema.items(): if isinstance(value, (dict, list)): schema[key] = _remove_title(value) diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index c322e2ffb..bcac832f1 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -1,14 +1,15 @@ import base64 import io - -from PIL import Image from datetime import datetime -from src.common.logger import get_logger +from PIL import Image + from src.common.database.sqlalchemy_models import LLMUsage, get_db_session +from src.common.logger import get_logger from src.config.api_ada_configs import ModelInfo -from .payload_content.message import Message, MessageBuilder + from .model_client.base_client import UsageRecord +from .payload_content.message import Message, MessageBuilder logger = get_logger("消息压缩工具") @@ -38,7 +39,7 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * return image_data except Exception as e: - logger.error(f"图片转换格式失败: {str(e)}") + logger.error(f"图片转换格式失败: {e!s}") return image_data def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]: @@ -87,7 +88,7 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * return output_buffer.getvalue(), original_size, new_size except Exception as e: - logger.error(f"图片缩放失败: {str(e)}") + logger.error(f"图片缩放失败: {e!s}") import traceback logger.error(traceback.format_exc()) @@ -188,7 +189,7 @@ class LLMUsageRecorder: f"总计: {model_usage.total_tokens}" ) except Exception as e: - logger.error(f"记录token使用情况失败: {str(e)}") + logger.error(f"记录token使用情况失败: {e!s}") llm_usage_recorder = LLMUsageRecorder() diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index a8a68c2fb..afb2f13ed 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ @desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。 它被设计为一个高度容错和可扩展的系统,包含以下主要组件: @@ -19,24 +18,26 @@ 作为模块的统一入口(Facade),为上层业务逻辑提供了简洁的接口来发起文本、图像、语音等不同类型的LLM请求。 """ -import re import asyncio -import time import random +import re import string - +import time +from collections.abc import Callable, Coroutine from enum import Enum +from typing import Any + from rich.traceback import install -from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine from src.common.logger import get_logger -from src.config.config import model_config from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig -from .payload_content.message import MessageBuilder, Message -from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder -from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord -from .utils import compress_messages, llm_usage_recorder +from src.config.config import model_config + from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException +from .model_client.base_client import APIResponse, BaseClient, UsageRecord, client_registry +from .payload_content.message import Message, MessageBuilder +from .payload_content.tool_option import ToolCall, ToolOption, ToolOptionBuilder +from .utils import compress_messages, llm_usage_recorder install(extra_lines=3) @@ -139,7 +140,7 @@ class _ModelSelector: CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数 DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量 - def __init__(self, model_list: List[str], model_usage: Dict[str, Tuple[int, int, int]]): + def __init__(self, model_list: list[str], model_usage: dict[str, tuple[int, int, int]]): """ 初始化模型选择器。 @@ -153,7 +154,7 @@ class _ModelSelector: def select_best_available_model( self, failed_models_in_this_request: set, request_type: str - ) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: + ) -> tuple[ModelInfo, APIProvider, BaseClient] | None: """ 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。 @@ -306,7 +307,7 @@ class _PromptProcessor: return processed_prompt - def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]: + def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]: """ 处理响应内容,提取思维链并检查截断。 @@ -393,7 +394,7 @@ class _PromptProcessor: return " ".join(result) @staticmethod - def _extract_reasoning(content: str) -> Tuple[str, str]: + def _extract_reasoning(content: str) -> tuple[str, str]: """ 从模型返回的完整内容中提取被...标签包裹的思考过程, 并返回清理后的内容和思考过程。 @@ -462,7 +463,7 @@ class _RequestExecutor: RuntimeError: 如果达到最大重试次数。 """ retry_remain = api_provider.max_retry - compressed_messages: Optional[List[Message]] = None + compressed_messages: list[Message] | None = None while retry_remain > 0: try: @@ -487,7 +488,7 @@ class _RequestExecutor: return await client.get_audio_transcriptions(model_info=model_info, **kwargs) except Exception as e: - logger.debug(f"请求失败: {str(e)}") + logger.debug(f"请求失败: {e!s}") # 记录失败并更新模型的惩罚值 self.model_selector.update_failure_penalty(model_info.name, e) @@ -514,7 +515,7 @@ class _RequestExecutor: def _handle_exception( self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info - ) -> Tuple[int, Optional[List[Message]]]: + ) -> tuple[int, list[Message] | None]: """ 默认异常处理函数,决定是否重试。 @@ -532,12 +533,12 @@ class _RequestExecutor: logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误 - {e.message}") return -1, None else: - logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {str(e)}") + logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {e!s}") return -1, None def _handle_resp_not_ok( self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info - ) -> Tuple[int, Optional[List[Message]]]: + ) -> tuple[int, list[Message] | None]: """ 处理非200的HTTP响应异常。 @@ -583,7 +584,7 @@ class _RequestExecutor: logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}") return -1, None - def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> Tuple[int, None]: + def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]: """ 辅助函数,根据剩余次数决定是否进行下一次重试。 @@ -620,7 +621,7 @@ class _RequestStrategy: model_selector: _ModelSelector, prompt_processor: _PromptProcessor, executor: _RequestExecutor, - model_list: List[str], + model_list: list[str], task_name: str, ): """ @@ -644,13 +645,13 @@ class _RequestStrategy: request_type: RequestType, raise_when_empty: bool = True, **kwargs, - ) -> Tuple[APIResponse, ModelInfo]: + ) -> tuple[APIResponse, ModelInfo]: """ 执行请求,动态选择最佳可用模型,并在模型失败时进行故障转移。 """ failed_models_in_this_request = set() max_attempts = len(self.model_list) - last_exception: Optional[Exception] = None + last_exception: Exception | None = None for attempt in range(max_attempts): selection_result = self.model_selector.select_best_available_model( @@ -787,9 +788,7 @@ class LLMRequest: """ self.task_name = request_type self.model_for_task = model_set - self.model_usage: Dict[str, Tuple[int, int, int]] = { - model: (0, 0, 0) for model in self.model_for_task.model_list - } + self.model_usage: dict[str, tuple[int, int, int]] = dict.fromkeys(self.model_for_task.model_list, (0, 0, 0)) """模型使用量记录,(total_tokens, penalty, usage_penalty)""" # 初始化辅助类 @@ -805,9 +804,9 @@ class LLMRequest: prompt: str, image_base64: str, image_format: str, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + temperature: float | None = None, + max_tokens: int | None = None, + ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]: """ 为图像生成响应。 @@ -855,7 +854,7 @@ class LLMRequest: return content, (reasoning, model_info.name, response.tool_calls) - async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: + async def generate_response_for_voice(self, voice_base64: str) -> str | None: """ 为语音生成响应(语音转文字)。 使用故障转移策略来确保即使主模型失败也能获得结果。 @@ -872,11 +871,11 @@ class LLMRequest: async def generate_response_async( self, prompt: str, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, + temperature: float | None = None, + max_tokens: int | None = None, + tools: list[dict[str, Any]] | None = None, raise_when_empty: bool = True, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]: """ 异步生成响应,支持并发请求。 @@ -914,11 +913,11 @@ class LLMRequest: async def _execute_single_text_request( self, prompt: str, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, + temperature: float | None = None, + max_tokens: int | None = None, + tools: list[dict[str, Any]] | None = None, raise_when_empty: bool = True, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]: """ 执行单次文本生成请求的内部方法。 这是 `generate_response_async` 的核心实现,处理单个请求的完整生命周期, @@ -956,7 +955,7 @@ class LLMRequest: return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls) - async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: + async def get_embedding(self, embedding_input: str) -> tuple[list[float], str]: """ 获取嵌入向量。 @@ -978,7 +977,7 @@ class LLMRequest: return response.embedding, model_info.name - async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str): + async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str): """ 记录模型使用情况。 @@ -1009,7 +1008,7 @@ class LLMRequest: ) @staticmethod - def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: + def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None: """ 根据输入的字典列表构建并验证 `ToolOption` 对象列表。 @@ -1028,7 +1027,7 @@ class LLMRequest: if not tools: return None - tool_options: List[ToolOption] = [] + tool_options: list[ToolOption] = [] # 遍历每个工具定义 for tool in tools: try: diff --git a/src/main.py b/src/main.py index 4e91f1419..914647508 100644 --- a/src/main.py +++ b/src/main.py @@ -1,40 +1,40 @@ # 再用这个就写一行注释来混提交的我直接全部🌿飞😡 import asyncio -import time import signal import sys -from functools import partial +import time import traceback -from typing import Dict, Any +from functools import partial +from typing import Any from maim_message import MessageServer - -from src.common.remote import TelemetryHeartBeatTask -from src.manager.async_task_manager import async_task_manager -from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask -from src.chat.emoji_system.emoji_manager import get_emoji_manager -from src.chat.message_receive.chat_stream import get_chat_manager -from src.config.config import global_config -from src.chat.message_receive.bot import chat_bot -from src.common.logger import get_logger -from src.individuality.individuality import get_individuality, Individuality -from src.common.server import get_global_server, Server -from src.mood.mood_manager import mood_manager from rich.traceback import install -from src.schedule.schedule_manager import schedule_manager -from src.schedule.monthly_plan_manager import monthly_plan_manager -from src.plugin_system.core.event_manager import event_manager -from src.plugin_system.base.component_types import EventType -# from src.api.main import start_api_server -# 导入新的插件管理器 -from src.plugin_system.core.plugin_manager import plugin_manager - -# 导入消息API和traceback模块 -from src.common.message import get_global_api +from src.chat.emoji_system.emoji_manager import get_emoji_manager # 导入增强记忆系统管理器 from src.chat.memory_system.memory_manager import memory_manager +from src.chat.message_receive.bot import chat_bot +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask +from src.common.logger import get_logger + +# 导入消息API和traceback模块 +from src.common.message import get_global_api +from src.common.remote import TelemetryHeartBeatTask +from src.common.server import Server, get_global_server +from src.config.config import global_config +from src.individuality.individuality import Individuality, get_individuality +from src.manager.async_task_manager import async_task_manager +from src.mood.mood_manager import mood_manager +from src.plugin_system.base.component_types import EventType +from src.plugin_system.core.event_manager import event_manager + +# from src.api.main import start_api_server +# 导入新的插件管理器 +from src.plugin_system.core.plugin_manager import plugin_manager +from src.schedule.monthly_plan_manager import monthly_plan_manager +from src.schedule.schedule_manager import schedule_manager # 插件系统现在使用统一的插件加载器 @@ -115,8 +115,8 @@ class MainSystem: # 停止消息重组器 try: - from src.plugin_system.core.event_manager import event_manager from src.plugin_system import EventType + from src.plugin_system.core.event_manager import event_manager from src.utils.message_chunker import reassembler await event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM") @@ -151,7 +151,7 @@ class MainSystem: except Exception as e: logger.error(f"同步清理资源时出错: {e}") - async def _message_process_wrapper(self, message_data: Dict[str, Any]): + async def _message_process_wrapper(self, message_data: dict[str, Any]): """并行处理消息的包装器""" try: start_time = time.time() @@ -225,8 +225,8 @@ MoFox_Bot(第三方修改版) event_manager.init_default_events() # 初始化权限管理器 - from src.plugin_system.core.permission_manager import PermissionManager from src.plugin_system.apis.permission_api import permission_api + from src.plugin_system.core.permission_manager import PermissionManager permission_manager = PermissionManager() await permission_manager.initialize() diff --git a/src/mais4u/mai_think.py b/src/mais4u/mai_think.py index 4c34c4798..6725e43db 100644 --- a/src/mais4u/mai_think.py +++ b/src/mais4u/mai_think.py @@ -1,12 +1,13 @@ -from src.chat.message_receive.chat_stream import get_chat_manager import time -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config + +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.message import MessageRecvS4U -from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor -from src.mais4u.mais4u_chat.internal_manager import internal_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.common.logger import get_logger +from src.config.config import model_config +from src.llm_models.utils_model import LLMRequest +from src.mais4u.mais4u_chat.internal_manager import internal_manager +from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor logger = get_logger(__name__) diff --git a/src/mais4u/mais4u_chat/body_emotion_action_manager.py b/src/mais4u/mais4u_chat/body_emotion_action_manager.py index 38073baa4..423eeaf16 100644 --- a/src/mais4u/mais4u_chat/body_emotion_action_manager.py +++ b/src/mais4u/mais4u_chat/body_emotion_action_manager.py @@ -1,18 +1,18 @@ -import orjson import time +import orjson from json_repair import repair_json + from src.chat.message_receive.message import MessageRecv -from src.llm_models.utils_model import LLMRequest -from src.common.logger import get_logger from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive -from src.config.config import global_config, model_config from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest +from src.mais4u.s4u_config import s4u_config from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api -from src.mais4u.s4u_config import s4u_config - logger = get_logger("action") HEAD_CODE = { diff --git a/src/mais4u/mais4u_chat/context_web_manager.py b/src/mais4u/mais4u_chat/context_web_manager.py index 3bd107c55..422e6207b 100644 --- a/src/mais4u/mais4u_chat/context_web_manager.py +++ b/src/mais4u/mais4u_chat/context_web_manager.py @@ -1,10 +1,10 @@ import asyncio -import orjson from collections import deque from datetime import datetime -from typing import Dict, List, Optional -from aiohttp import web, WSMsgType + import aiohttp_cors +import orjson +from aiohttp import WSMsgType, web from src.chat.message_receive.message import MessageRecv from src.common.logger import get_logger @@ -57,8 +57,8 @@ class ContextWebManager: def __init__(self, max_messages: int = 10, port: int = 8765): self.max_messages = max_messages self.port = port - self.contexts: Dict[str, deque] = {} # chat_id -> deque of ContextMessage - self.websockets: List[web.WebSocketResponse] = [] + self.contexts: dict[str, deque] = {} # chat_id -> deque of ContextMessage + self.websockets: list[web.WebSocketResponse] = [] self.app = None self.runner = None self.site = None @@ -674,7 +674,7 @@ class ContextWebManager: # 全局实例 -_context_web_manager: Optional[ContextWebManager] = None +_context_web_manager: ContextWebManager | None = None def get_context_web_manager() -> ContextWebManager: diff --git a/src/mais4u/mais4u_chat/gift_manager.py b/src/mais4u/mais4u_chat/gift_manager.py index d489550c3..976476225 100644 --- a/src/mais4u/mais4u_chat/gift_manager.py +++ b/src/mais4u/mais4u_chat/gift_manager.py @@ -1,5 +1,5 @@ import asyncio -from typing import Dict, Tuple, Callable, Optional +from collections.abc import Callable from dataclasses import dataclass from src.chat.message_receive.message import MessageRecvS4U @@ -23,11 +23,11 @@ class GiftManager: def __init__(self): """初始化礼物管理器""" - self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {} + self.pending_gifts: dict[tuple[str, str], PendingGift] = {} self.debounce_timeout = 5.0 # 3秒防抖时间 async def handle_gift( - self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None + self, message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None = None ) -> bool: """处理礼物消息,返回是否应该立即处理 @@ -53,7 +53,7 @@ class GiftManager: await self._create_pending_gift(gift_key, message, callback) return False - async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None: + async def _merge_gift(self, gift_key: tuple[str, str], new_message: MessageRecvS4U) -> None: """合并礼物消息""" pending_gift = self.pending_gifts[gift_key] @@ -81,7 +81,7 @@ class GiftManager: logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}") async def _create_pending_gift( - self, gift_key: Tuple[str, str], message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] + self, gift_key: tuple[str, str], message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None ) -> None: """创建新的等待礼物""" try: @@ -100,7 +100,7 @@ class GiftManager: logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}") - async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None: + async def _gift_timeout(self, gift_key: tuple[str, str]) -> None: """礼物防抖超时处理""" try: # 等待防抖时间 diff --git a/src/mais4u/mais4u_chat/internal_manager.py b/src/mais4u/mais4u_chat/internal_manager.py index 4b3db3263..3e4a518d4 100644 --- a/src/mais4u/mais4u_chat/internal_manager.py +++ b/src/mais4u/mais4u_chat/internal_manager.py @@ -1,6 +1,6 @@ class InternalManager: def __init__(self): - self.now_internal_state = str() + self.now_internal_state = "" def set_internal_state(self, internal_state: str): self.now_internal_state = internal_state diff --git a/src/mais4u/mais4u_chat/s4u_chat.py b/src/mais4u/mais4u_chat/s4u_chat.py index 192e858b6..80bd91e22 100644 --- a/src/mais4u/mais4u_chat/s4u_chat.py +++ b/src/mais4u/mais4u_chat/s4u_chat.py @@ -1,25 +1,27 @@ import asyncio -import traceback -import time import random -from typing import Optional, Dict, Tuple, List # 导入类型提示 -from maim_message import UserInfo, Seg -from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from .s4u_stream_generator import S4UStreamGenerator -from src.chat.message_receive.message import MessageSending, MessageRecv, MessageRecvS4U -from src.config.config import global_config -from src.common.message.api import get_global_api -from src.chat.message_receive.storage import MessageStorage -from .s4u_watching_manager import watching_manager +import time +import traceback + import orjson -from .s4u_mood_manager import mood_manager -from src.person_info.relationship_builder_manager import relationship_builder_manager +from maim_message import Seg, UserInfo + +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.message import MessageRecv, MessageRecvS4U, MessageSending +from src.chat.message_receive.storage import MessageStorage +from src.common.logger import get_logger +from src.common.message.api import get_global_api +from src.config.config import global_config +from src.mais4u.constant_s4u import ENABLE_S4U from src.mais4u.s4u_config import s4u_config from src.person_info.person_info import PersonInfoManager +from src.person_info.relationship_builder_manager import relationship_builder_manager + +from .s4u_mood_manager import mood_manager +from .s4u_stream_generator import S4UStreamGenerator +from .s4u_watching_manager import watching_manager from .super_chat_manager import get_super_chat_manager from .yes_or_no import yes_or_no_head -from src.mais4u.constant_s4u import ENABLE_S4U logger = get_logger("S4U_chat") @@ -32,7 +34,7 @@ class MessageSenderContainer: self.original_message = original_message self.queue = asyncio.Queue() self.storage = MessageStorage() - self._task: Optional[asyncio.Task] = None + self._task: asyncio.Task | None = None self._paused_event = asyncio.Event() self._paused_event.set() # 默认设置为非暂停状态 @@ -158,7 +160,7 @@ class MessageSenderContainer: class S4UChatManager: def __init__(self): - self.s4u_chats: Dict[str, "S4UChat"] = {} + self.s4u_chats: dict[str, "S4UChat"] = {} def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat": if chat_stream.stream_id not in self.s4u_chats: @@ -196,16 +198,16 @@ class S4UChat: self._new_message_event = asyncio.Event() # 用于唤醒处理器 self._processing_task = asyncio.create_task(self._message_processor()) - self._current_generation_task: Optional[asyncio.Task] = None + self._current_generation_task: asyncio.Task | None = None # 当前消息的元数据:(队列类型, 优先级分数, 计数器, 消息对象) - self._current_message_being_replied: Optional[Tuple[str, float, int, MessageRecv]] = None + self._current_message_being_replied: tuple[str, float, int, MessageRecv] | None = None self._is_replying = False self.gpt = S4UStreamGenerator() self.gpt.chat_stream = self.chat_stream - self.interest_dict: Dict[str, float] = {} # 用户兴趣分 + self.interest_dict: dict[str, float] = {} # 用户兴趣分 - self.internal_message: List[MessageRecvS4U] = [] + self.internal_message: list[MessageRecvS4U] = [] self.msg_id = "" self.voice_done = "" diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index d235843d4..2031f7c56 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -1,16 +1,17 @@ import asyncio -import orjson import time +import orjson + from src.chat.message_receive.message import MessageRecv -from src.llm_models.utils_model import LLMRequest -from src.common.logger import get_logger from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive -from src.config.config import global_config, model_config from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest +from src.mais4u.constant_s4u import ENABLE_S4U from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api -from src.mais4u.constant_s4u import ENABLE_S4U """ 情绪管理系统使用说明: diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py index ba8ee54eb..c7b855394 100644 --- a/src/mais4u/mais4u_chat/s4u_msg_processor.py +++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py @@ -1,33 +1,33 @@ import asyncio import math -from typing import Tuple + +from maim_message.message_base import GroupInfo + +from src.chat.message_receive.chat_stream import get_chat_manager # 旧的Hippocampus系统已被移除,现在使用增强记忆系统 # from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager from src.chat.message_receive.message import MessageRecv, MessageRecvS4U -from maim_message.message_base import GroupInfo from src.chat.message_receive.storage import MessageStorage -from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import is_mentioned_bot_in_message from src.common.logger import get_logger from src.config.config import global_config from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager -from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager -from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager from src.mais4u.mais4u_chat.gift_manager import gift_manager +from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager +from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager from src.mais4u.mais4u_chat.screen_manager import screen_manager from .s4u_chat import get_s4u_chat_manager - # from ..message_receive.message_buffer import message_buffer logger = get_logger("chat") -async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: +async def _calculate_interest(message: MessageRecv) -> tuple[float, bool]: """计算消息的兴趣度 Args: diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index 1c8782d23..b53a8b3f6 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -1,25 +1,27 @@ -from src.config.config import global_config -from src.common.logger import get_logger -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat -import time -from src.chat.utils.utils import get_recent_group_speaker +import asyncio # 旧的Hippocampus系统已被移除,现在使用增强记忆系统 # from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager import random +import time from datetime import datetime -import asyncio -from src.mais4u.s4u_config import s4u_config -from src.chat.message_receive.message import MessageRecvS4U -from src.person_info.relationship_fetcher import relationship_fetcher_manager -from src.person_info.person_info import PersonInfoManager, get_person_info_manager -from src.chat.message_receive.chat_stream import ChatStream -from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager -from src.mais4u.mais4u_chat.screen_manager import screen_manager + from src.chat.express.expression_selector import expression_selector -from .s4u_mood_manager import mood_manager +from src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.message import MessageRecvS4U +from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.chat.utils.utils import get_recent_group_speaker +from src.common.logger import get_logger +from src.config.config import global_config from src.mais4u.mais4u_chat.internal_manager import internal_manager +from src.mais4u.mais4u_chat.screen_manager import screen_manager +from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager +from src.mais4u.s4u_config import s4u_config +from src.person_info.person_info import PersonInfoManager, get_person_info_manager +from src.person_info.relationship_fetcher import relationship_fetcher_manager + +from .s4u_mood_manager import mood_manager logger = get_logger("prompt") @@ -206,7 +208,7 @@ class PromptBuilder: limit=300, ) - talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}" + talk_type = f"{message.message_info.platform}:{message.chat_stream.user_info.user_id!s}" core_dialogue_list = [] background_dialogue_list = [] diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py index d4ec70edd..3f2ac4a80 100644 --- a/src/mais4u/mais4u_chat/s4u_stream_generator.py +++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py @@ -1,12 +1,12 @@ -from typing import AsyncGenerator -from src.mais4u.openai_client import AsyncOpenAIClient -from src.config.config import model_config -from src.chat.message_receive.message import MessageRecvS4U -from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder -from src.common.logger import get_logger import asyncio import re +from collections.abc import AsyncGenerator +from src.chat.message_receive.message import MessageRecvS4U +from src.common.logger import get_logger +from src.config.config import model_config +from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder +from src.mais4u.openai_client import AsyncOpenAIClient logger = get_logger("s4u_stream_generator") @@ -99,7 +99,7 @@ class S4UStreamGenerator: logger.info( f"{self.current_model_name}思考:{message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}" - ) # noqa: E501 + ) current_client = self.client_1 self.current_model_name = self.model_1_name diff --git a/src/mais4u/mais4u_chat/screen_manager.py b/src/mais4u/mais4u_chat/screen_manager.py index 996e63990..60a7f914d 100644 --- a/src/mais4u/mais4u_chat/screen_manager.py +++ b/src/mais4u/mais4u_chat/screen_manager.py @@ -1,6 +1,6 @@ class ScreenManager: def __init__(self): - self.now_screen = str() + self.now_screen = "" def set_screen(self, screen_str: str): self.now_screen = screen_str diff --git a/src/mais4u/mais4u_chat/super_chat_manager.py b/src/mais4u/mais4u_chat/super_chat_manager.py index 5f0ee2ac2..df6245746 100644 --- a/src/mais4u/mais4u_chat/super_chat_manager.py +++ b/src/mais4u/mais4u_chat/super_chat_manager.py @@ -1,9 +1,9 @@ import asyncio import time from dataclasses import dataclass -from typing import Dict, List, Optional -from src.common.logger import get_logger + from src.chat.message_receive.message import MessageRecvS4U +from src.common.logger import get_logger # 全局SuperChat管理器实例 from src.mais4u.constant_s4u import ENABLE_S4U @@ -23,7 +23,7 @@ class SuperChatRecord: message_text: str timestamp: float expire_time: float - group_name: Optional[str] = None + group_name: str | None = None def is_expired(self) -> bool: """检查SuperChat是否已过期""" @@ -53,8 +53,8 @@ class SuperChatManager: """SuperChat管理器,负责管理和跟踪SuperChat消息""" def __init__(self): - self.super_chats: Dict[str, List[SuperChatRecord]] = {} # chat_id -> SuperChat列表 - self._cleanup_task: Optional[asyncio.Task] = None + self.super_chats: dict[str, list[SuperChatRecord]] = {} # chat_id -> SuperChat列表 + self._cleanup_task: asyncio.Task | None = None self._is_initialized = False logger.info("SuperChat管理器已初始化") @@ -186,7 +186,7 @@ class SuperChatManager: logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}") - def get_superchats_by_chat(self, chat_id: str) -> List[SuperChatRecord]: + def get_superchats_by_chat(self, chat_id: str) -> list[SuperChatRecord]: """获取指定聊天的所有有效SuperChat""" # 确保清理任务已启动 self._ensure_cleanup_task_started() @@ -198,7 +198,7 @@ class SuperChatManager: valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()] return valid_superchats - def get_all_valid_superchats(self) -> Dict[str, List[SuperChatRecord]]: + def get_all_valid_superchats(self) -> dict[str, list[SuperChatRecord]]: """获取所有有效的SuperChat""" # 确保清理任务已启动 self._ensure_cleanup_task_started() diff --git a/src/mais4u/mais4u_chat/yes_or_no.py b/src/mais4u/mais4u_chat/yes_or_no.py index c71c160d3..51fba0416 100644 --- a/src/mais4u/mais4u_chat/yes_or_no.py +++ b/src/mais4u/mais4u_chat/yes_or_no.py @@ -1,6 +1,6 @@ -from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger from src.config.config import model_config +from src.llm_models.utils_model import LLMRequest from src.plugin_system.apis import send_api logger = get_logger(__name__) diff --git a/src/mais4u/openai_client.py b/src/mais4u/openai_client.py index 2a5873dec..6f5e0484e 100644 --- a/src/mais4u/openai_client.py +++ b/src/mais4u/openai_client.py @@ -1,5 +1,6 @@ -from typing import AsyncGenerator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator from dataclasses import dataclass + from openai import AsyncOpenAI from openai.types.chat import ChatCompletion, ChatCompletionChunk @@ -11,14 +12,14 @@ class ChatMessage: role: str content: str - def to_dict(self) -> Dict[str, str]: + def to_dict(self) -> dict[str, str]: return {"role": self.role, "content": self.content} class AsyncOpenAIClient: """异步OpenAI客户端,支持流式传输""" - def __init__(self, api_key: str, base_url: Optional[str] = None): + def __init__(self, api_key: str, base_url: str | None = None): """ 初始化客户端 @@ -34,10 +35,10 @@ class AsyncOpenAIClient: async def chat_completion( self, - messages: List[Union[ChatMessage, Dict[str, str]]], + messages: list[ChatMessage | dict[str, str]], model: str = "gpt-3.5-turbo", temperature: float = 0.7, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, **kwargs, ) -> ChatCompletion: """ @@ -81,10 +82,10 @@ class AsyncOpenAIClient: async def chat_completion_stream( self, - messages: List[Union[ChatMessage, Dict[str, str]]], + messages: list[ChatMessage | dict[str, str]], model: str = "gpt-3.5-turbo", temperature: float = 0.7, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, **kwargs, ) -> AsyncGenerator[ChatCompletionChunk, None]: """ @@ -129,10 +130,10 @@ class AsyncOpenAIClient: async def get_stream_content( self, - messages: List[Union[ChatMessage, Dict[str, str]]], + messages: list[ChatMessage | dict[str, str]], model: str = "gpt-3.5-turbo", temperature: float = 0.7, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, **kwargs, ) -> AsyncGenerator[str, None]: """ @@ -156,10 +157,10 @@ class AsyncOpenAIClient: async def collect_stream_response( self, - messages: List[Union[ChatMessage, Dict[str, str]]], + messages: list[ChatMessage | dict[str, str]], model: str = "gpt-3.5-turbo", temperature: float = 0.7, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, **kwargs, ) -> str: """ @@ -199,7 +200,7 @@ class AsyncOpenAIClient: class ConversationManager: """对话管理器,用于管理对话历史""" - def __init__(self, client: AsyncOpenAIClient, system_prompt: Optional[str] = None): + def __init__(self, client: AsyncOpenAIClient, system_prompt: str | None = None): """ 初始化对话管理器 @@ -208,7 +209,7 @@ class ConversationManager: system_prompt: 系统提示词 """ self.client = client - self.messages: List[ChatMessage] = [] + self.messages: list[ChatMessage] = [] if system_prompt: self.messages.append(ChatMessage(role="system", content=system_prompt)) @@ -281,6 +282,6 @@ class ConversationManager: """获取消息数量""" return len(self.messages) - def get_conversation_history(self) -> List[Dict[str, str]]: + def get_conversation_history(self) -> list[dict[str, str]]: """获取对话历史""" return [msg.to_dict() for msg in self.messages] diff --git a/src/mais4u/s4u_config.py b/src/mais4u/s4u_config.py index 79a8f92c4..f42e871bc 100644 --- a/src/mais4u/s4u_config.py +++ b/src/mais4u/s4u_config.py @@ -1,13 +1,16 @@ import os -import tomlkit import shutil +from dataclasses import MISSING, dataclass, field, fields from datetime import datetime +from typing import Any, Literal, TypeVar, get_args, get_origin + +import tomlkit from tomlkit import TOMLDocument from tomlkit.items import Table -from dataclasses import dataclass, fields, MISSING, field -from typing import TypeVar, Type, Any, get_origin, get_args, Literal -from src.mais4u.constant_s4u import ENABLE_S4U +from typing_extensions import Self + from src.common.logger import get_logger +from src.mais4u.constant_s4u import ENABLE_S4U logger = get_logger("s4u_config") @@ -46,7 +49,7 @@ class S4UConfigBase: """S4U配置类的基类""" @classmethod - def from_dict(cls: Type[T], data: dict[str, Any]) -> T: + def from_dict(cls, data: dict[str, Any]) -> Self: """从字典加载配置字段""" data = table_to_dict(data) # 递归转dict,兼容tomlkit Table if not is_dict_like(data): @@ -81,7 +84,7 @@ class S4UConfigBase: return cls() @classmethod - def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any: + def _convert_field(cls, value: Any, field_type: type[Any]) -> Any: """转换字段值为指定类型""" # 如果是嵌套的 dataclass,递归调用 from_dict 方法 if isinstance(field_type, type) and issubclass(field_type, S4UConfigBase): @@ -271,9 +274,9 @@ def update_s4u_config(): return # 读取旧配置文件和模板文件 - with open(CONFIG_PATH, "r", encoding="utf-8") as f: + with open(CONFIG_PATH, encoding="utf-8") as f: old_config = tomlkit.load(f) - with open(TEMPLATE_PATH, "r", encoding="utf-8") as f: + with open(TEMPLATE_PATH, encoding="utf-8") as f: new_config = tomlkit.load(f) # 检查version是否相同 @@ -344,7 +347,7 @@ def load_s4u_config(config_path: str) -> S4UGlobalConfig: :return: S4UGlobalConfig对象 """ # 读取配置文件 - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: config_data = tomlkit.load(f) # 创建S4UGlobalConfig对象 diff --git a/src/manager/async_task_manager.py b/src/manager/async_task_manager.py index 92f6675bd..157849381 100644 --- a/src/manager/async_task_manager.py +++ b/src/manager/async_task_manager.py @@ -1,8 +1,7 @@ -from abc import abstractmethod, ABCMeta - import asyncio -from asyncio import Task, Event, Lock -from typing import Callable, Dict +from abc import ABCMeta, abstractmethod +from asyncio import Event, Lock, Task +from collections.abc import Callable from src.common.logger import get_logger @@ -46,7 +45,7 @@ class AsyncTaskManager: """异步任务管理器""" def __init__(self): - self.tasks: Dict[str, Task] = {} + self.tasks: dict[str, Task] = {} """任务列表""" self.abort_flag: Event = Event() @@ -116,7 +115,7 @@ class AsyncTaskManager: self.tasks[task.task_name] = task_inst # 将任务添加到任务列表 logger.debug(f"已启动任务 '{task.task_name}'") - def get_tasks_status(self) -> Dict[str, Dict[str, str]]: + def get_tasks_status(self) -> dict[str, dict[str, str]]: """ 获取所有任务的状态 """ diff --git a/src/manager/local_store_manager.py b/src/manager/local_store_manager.py index 63d191ef1..f5b5a28ca 100644 --- a/src/manager/local_store_manager.py +++ b/src/manager/local_store_manager.py @@ -1,6 +1,7 @@ -import orjson import os +import orjson + from src.common.logger import get_logger LOCAL_STORE_FILE_PATH = "data/local_store.json" @@ -24,7 +25,7 @@ class LocalStoreManager: """获取本地存储数据""" return self.store.get(item) - def __setitem__(self, key: str, value: str | list | dict | int | float | bool): + def __setitem__(self, key: str, value: str | list | dict | float | bool): """设置本地存储数据""" self.store[key] = value self.save_local_store() @@ -48,7 +49,7 @@ class LocalStoreManager: logger.info("正在阅读记事本......我在看,我真的在看!") logger.debug(f"加载本地存储数据: {self.file_path}") try: - with open(self.file_path, "r", encoding="utf-8") as f: + with open(self.file_path, encoding="utf-8") as f: self.store = orjson.loads(f.read()) logger.info("全都记起来了!") except orjson.JSONDecodeError: diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 66fcee96f..76f8a547e 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -2,17 +2,16 @@ import math import random import time +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.message import MessageRecv +from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config, model_config -from src.chat.message_receive.message import MessageRecv -from src.common.data_models.database_data_model import DatabaseMessages -from src.chat.message_receive.chat_stream import get_chat_manager -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.llm_models.utils_model import LLMRequest from src.manager.async_task_manager import AsyncTask, async_task_manager - logger = get_logger("mood") diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 478d4c9fb..afde489dc 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -2,7 +2,8 @@ import copy import datetime import hashlib import time -from typing import Any, Callable, Dict, Union, Optional +from collections.abc import Callable +from typing import Any import orjson from json_repair import repair_json @@ -86,7 +87,7 @@ class PersonInfoManager: logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}") @staticmethod - def get_person_id(platform: str, user_id: Union[int, str]) -> str: + def get_person_id(platform: str, user_id: int | str) -> str: """获取唯一id(同步) 说明: 原来该方法为异步并在内部尝试执行数据库检查/迁移,导致在许多调用处未 await 时返回 coroutine 对象。 @@ -167,7 +168,7 @@ class PersonInfoManager: ) @staticmethod - async def create_person_info(person_id: str, data: Optional[dict] = None): + async def create_person_info(person_id: str, data: dict | None = None): """创建一个项""" if not person_id: logger.debug("创建失败,person_id不存在") @@ -228,7 +229,7 @@ class PersonInfoManager: await _db_create_async(final_data) @staticmethod - async def _safe_create_person_info(person_id: str, data: Optional[dict] = None): + async def _safe_create_person_info(person_id: str, data: dict | None = None): """安全地创建用户信息,处理竞态条件""" if not person_id: logger.debug("创建失败,person_id不存在") @@ -296,7 +297,7 @@ class PersonInfoManager: await _db_safe_create_async(final_data) - async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None): + async def update_one_field(self, person_id: str, field_name: str, value, data: dict | None = None): """更新某一个字段,会补全""" # 获取 SQLAlchemy 模型的所有字段名 model_fields = [column.name for column in PersonInfo.__table__.columns] @@ -628,7 +629,7 @@ class PersonInfoManager: async def get_specific_value_list( field_name: str, way: Callable[[Any], bool], - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ 获取满足条件的字段值字典 """ @@ -649,18 +650,18 @@ class PersonInfoManager: found_results[record.person_id] = value except Exception as e_query: logger.error( - f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True + f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {e_query!s}", exc_info=True ) return found_results try: return await _db_get_specific_async(field_name) except Exception as e: - logger.error(f"执行 get_specific_value_list 时出错: {str(e)}", exc_info=True) + logger.error(f"执行 get_specific_value_list 时出错: {e!s}", exc_info=True) return {} async def get_or_create_person( - self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None + self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: str | None = None ) -> str: """ 根据 platform 和 user_id 获取 person_id。 diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index 4dc478f6c..10f1d3d97 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -1,20 +1,21 @@ -import time -import traceback import os import pickle import random -from typing import List, Dict, Any -from src.config.config import global_config -from src.common.logger import get_logger -from src.person_info.relationship_manager import get_relationship_manager -from src.person_info.person_info import get_person_info_manager, PersonInfoManager +import time +import traceback +from typing import Any + from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import ( + get_raw_msg_before_timestamp_with_chat, get_raw_msg_by_timestamp_with_chat, get_raw_msg_by_timestamp_with_chat_inclusive, - get_raw_msg_before_timestamp_with_chat, num_new_messages_since, ) +from src.common.logger import get_logger +from src.config.config import global_config +from src.person_info.person_info import PersonInfoManager, get_person_info_manager +from src.person_info.relationship_manager import get_relationship_manager logger = get_logger("relationship_builder") @@ -45,7 +46,7 @@ class RelationshipBuilder: self.chat_id = chat_id # 新的消息段缓存结构: # {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]} - self.person_engaged_cache: Dict[str, List[Dict[str, Any]]] = {} + self.person_engaged_cache: dict[str, list[dict[str, Any]]] = {} # 持久化存储文件路径 self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl") @@ -401,7 +402,7 @@ class RelationshipBuilder: # 负责触发关系构建、整合消息段、更新用户印象 # ================================ - async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]): + async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: list[dict[str, Any]]): """基于消息段更新用户印象""" original_segment_count = len(segments) logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象") diff --git a/src/person_info/relationship_builder_manager.py b/src/person_info/relationship_builder_manager.py index f3bca25d2..61cad42e2 100644 --- a/src/person_info/relationship_builder_manager.py +++ b/src/person_info/relationship_builder_manager.py @@ -1,6 +1,7 @@ -from typing import Dict, Optional, List, Any +from typing import Any from src.common.logger import get_logger + from .relationship_builder import RelationshipBuilder logger = get_logger("relationship_builder_manager") @@ -13,7 +14,7 @@ class RelationshipBuilderManager: """ def __init__(self): - self.builders: Dict[str, RelationshipBuilder] = {} + self.builders: dict[str, RelationshipBuilder] = {} def get_or_create_builder(self, chat_id: str) -> RelationshipBuilder: """获取或创建关系构建器 @@ -30,7 +31,7 @@ class RelationshipBuilderManager: return self.builders[chat_id] - def get_builder(self, chat_id: str) -> Optional[RelationshipBuilder]: + def get_builder(self, chat_id: str) -> RelationshipBuilder | None: """获取关系构建器 Args: @@ -56,7 +57,7 @@ class RelationshipBuilderManager: return True return False - def get_all_chat_ids(self) -> List[str]: + def get_all_chat_ids(self) -> list[str]: """获取所有管理的聊天ID列表 Returns: @@ -64,7 +65,7 @@ class RelationshipBuilderManager: """ return list(self.builders.keys()) - def get_status(self) -> Dict[str, Any]: + def get_status(self) -> dict[str, Any]: """获取管理器状态 Returns: diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 90a353291..b0835fcb4 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -1,18 +1,17 @@ import time import traceback -import orjson +from typing import Any -from typing import List, Dict, Any +import orjson from json_repair import repair_json +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.message_receive.chat_stream import get_chat_manager from src.person_info.person_info import get_person_info_manager - logger = get_logger("relationship_fetcher") @@ -64,10 +63,10 @@ class RelationshipFetcher: self.chat_id = chat_id # 信息获取缓存:记录正在获取的信息请求 - self.info_fetching_cache: List[Dict[str, Any]] = [] + self.info_fetching_cache: list[dict[str, Any]] = [] # 信息结果缓存:存储已获取的信息结果,带TTL - self.info_fetched_cache: Dict[str, Dict[str, Any]] = {} + self.info_fetched_cache: dict[str, dict[str, Any]] = {} # 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}} # LLM模型配置 @@ -471,7 +470,7 @@ class RelationshipFetcherManager: """ def __init__(self): - self._fetchers: Dict[str, RelationshipFetcher] = {} + self._fetchers: dict[str, RelationshipFetcher] = {} def get_fetcher(self, chat_id: str) -> RelationshipFetcher: """获取或创建指定 chat_id 的 RelationshipFetcher @@ -499,7 +498,7 @@ class RelationshipFetcherManager: """清空所有 RelationshipFetcher""" self._fetchers.clear() - def get_active_chat_ids(self) -> List[str]: + def get_active_chat_ids(self) -> list[str]: """获取所有活跃的 chat_id 列表""" return list(self._fetchers.keys()) diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index a6ce8ab02..7792798f1 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -1,18 +1,21 @@ -from src.common.logger import get_logger -from .person_info import PersonInfoManager, get_person_info_manager -import time import random -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.chat.utils.chat_message_builder import build_readable_messages -import orjson -from json_repair import repair_json +import time from datetime import datetime from difflib import SequenceMatcher +from typing import Any + import jieba +import orjson +from json_repair import repair_json from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity -from typing import List, Dict, Any + +from src.chat.utils.chat_message_builder import build_readable_messages +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest + +from .person_info import PersonInfoManager, get_person_info_manager logger = get_logger("relation") @@ -54,7 +57,7 @@ class RelationshipManager: # person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar # ) - async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]): + async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: list[dict[str, Any]]): """更新用户印象 Args: diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index ae66a9803..9a3bb85d6 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -5,33 +5,49 @@ MaiBot 插件系统 """ # 导出主要的公共接口 +from .apis import ( + chat_api, + component_manage_api, + config_api, + database_api, + emoji_api, + generator_api, + get_logger, + llm_api, + message_api, + person_api, + plugin_manage_api, + register_plugin, + send_api, + tool_api, +) from .base import ( - BasePlugin, + ActionActivationType, + ActionInfo, BaseAction, BaseCommand, - BaseTool, - ConfigField, - ComponentType, - ActionActivationType, - ChatMode, - ComponentInfo, - ActionInfo, - CommandInfo, - PlusCommandInfo, - PluginInfo, - ToolInfo, - PythonDependency, BaseEventHandler, + BasePlugin, + BaseTool, + ChatMode, + ChatType, + CommandArgs, + CommandInfo, + ComponentInfo, + ComponentType, + ConfigField, EventHandlerInfo, EventType, MaiMessages, - ToolParamType, + PluginInfo, # 新增的增强命令系统 PlusCommand, - CommandArgs, PlusCommandAdapter, + PlusCommandInfo, + PythonDependency, + ToolInfo, + ToolParamType, create_plus_command_adapter, - ChatType, ) # 导入工具模块 @@ -41,28 +57,10 @@ from .utils import ( # validate_plugin_manifest, # generate_plugin_manifest, ) +from .utils.dependency_config import configure_dependency_settings, get_dependency_config # 导入依赖管理模块 -from .utils.dependency_manager import get_dependency_manager, configure_dependency_manager -from .utils.dependency_config import get_dependency_config, configure_dependency_settings - -from .apis import ( - chat_api, - tool_api, - component_manage_api, - config_api, - database_api, - emoji_api, - generator_api, - llm_api, - message_api, - person_api, - plugin_manage_api, - send_api, - register_plugin, - get_logger, -) - +from .utils.dependency_manager import configure_dependency_manager, get_dependency_manager __version__ = "2.0.0" diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py index c80c5942c..cc67b9348 100644 --- a/src/plugin_system/apis/__init__.py +++ b/src/plugin_system/apis/__init__.py @@ -14,14 +14,15 @@ from src.plugin_system.apis import ( generator_api, llm_api, message_api, + permission_api, person_api, plugin_manage_api, + schedule_api, send_api, tool_api, - permission_api, - schedule_api, ) from src.plugin_system.apis.chat_api import ChatManager as context_api + from .logging_api import get_logger from .plugin_register_api import register_plugin @@ -30,18 +31,18 @@ __all__ = [ "chat_api", "component_manage_api", "config_api", + "context_api", "database_api", "emoji_api", "generator_api", + "get_logger", "llm_api", "message_api", + "permission_api", "person_api", "plugin_manage_api", - "send_api", - "get_logger", "register_plugin", - "tool_api", - "permission_api", - "context_api", "schedule_api", + "send_api", + "tool_api", ] diff --git a/src/plugin_system/apis/chat_api.py b/src/plugin_system/apis/chat_api.py index 9e995d36f..47cecd2d5 100644 --- a/src/plugin_system/apis/chat_api.py +++ b/src/plugin_system/apis/chat_api.py @@ -12,11 +12,11 @@ streams = chat.get_all_group_streams() """ -from typing import List, Dict, Any, Optional from enum import Enum +from typing import Any -from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.common.logger import get_logger logger = get_logger("chat_api") @@ -31,7 +31,7 @@ class ChatManager: """聊天管理器 - 专门负责聊天信息的查询和管理""" @staticmethod - def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: + def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: # sourcery skip: for-append-to-extend """获取所有聊天流 @@ -57,7 +57,7 @@ class ChatManager: return streams @staticmethod - def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: + def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: # sourcery skip: for-append-to-extend """获取所有群聊聊天流 @@ -80,7 +80,7 @@ class ChatManager: return streams @staticmethod - def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: + def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: # sourcery skip: for-append-to-extend """获取所有私聊聊天流 @@ -107,8 +107,8 @@ class ChatManager: @staticmethod def get_group_stream_by_group_id( - group_id: str, platform: Optional[str] | SpecialTypes = "qq" - ) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast + group_id: str, platform: str | None | SpecialTypes = "qq" + ) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast """根据群ID获取聊天流 Args: @@ -144,8 +144,8 @@ class ChatManager: @staticmethod def get_private_stream_by_user_id( - user_id: str, platform: Optional[str] | SpecialTypes = "qq" - ) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast + user_id: str, platform: str | None | SpecialTypes = "qq" + ) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast """根据用户ID获取私聊流 Args: @@ -203,7 +203,7 @@ class ChatManager: return "unknown" @staticmethod - def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]: + def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]: """获取聊天流详细信息 Args: @@ -222,7 +222,7 @@ class ChatManager: raise TypeError("chat_stream 必须是 ChatStream 类型") try: - info: Dict[str, Any] = { + info: dict[str, Any] = { "stream_id": chat_stream.stream_id, "platform": chat_stream.platform, "type": ChatManager.get_stream_type(chat_stream), @@ -250,7 +250,7 @@ class ChatManager: return {} @staticmethod - def get_streams_summary() -> Dict[str, int]: + def get_streams_summary() -> dict[str, int]: """获取聊天流统计摘要 Returns: @@ -285,27 +285,27 @@ class ChatManager: # ============================================================================= -def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: +def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: """获取所有聊天流的便捷函数""" return ChatManager.get_all_streams(platform) -def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: +def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: """获取群聊聊天流的便捷函数""" return ChatManager.get_group_streams(platform) -def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: +def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: """获取私聊聊天流的便捷函数""" return ChatManager.get_private_streams(platform) -def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]: +def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None: """根据群ID获取聊天流的便捷函数""" return ChatManager.get_group_stream_by_group_id(group_id, platform) -def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]: +def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None: """根据用户ID获取私聊流的便捷函数""" return ChatManager.get_private_stream_by_user_id(user_id, platform) @@ -315,11 +315,11 @@ def get_stream_type(chat_stream: ChatStream) -> str: return ChatManager.get_stream_type(chat_stream) -def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]: +def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]: """获取聊天流信息的便捷函数""" return ChatManager.get_stream_info(chat_stream) -def get_streams_summary() -> Dict[str, int]: +def get_streams_summary() -> dict[str, int]: """获取聊天流统计摘要的便捷函数""" return ChatManager.get_streams_summary() diff --git a/src/plugin_system/apis/component_manage_api.py b/src/plugin_system/apis/component_manage_api.py index 1ffa0833e..490237188 100644 --- a/src/plugin_system/apis/component_manage_api.py +++ b/src/plugin_system/apis/component_manage_api.py @@ -1,16 +1,15 @@ -from typing import Optional, Union, Dict from src.plugin_system.base.component_types import ( - CommandInfo, ActionInfo, + CommandInfo, + ComponentType, EventHandlerInfo, PluginInfo, - ComponentType, ToolInfo, ) # === 插件信息查询 === -def get_all_plugin_info() -> Dict[str, PluginInfo]: +def get_all_plugin_info() -> dict[str, PluginInfo]: """ 获取所有插件的信息。 @@ -22,7 +21,7 @@ def get_all_plugin_info() -> Dict[str, PluginInfo]: return component_registry.get_all_plugins() -def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]: +def get_plugin_info(plugin_name: str) -> PluginInfo | None: """ 获取指定插件的信息。 @@ -40,7 +39,7 @@ def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]: # === 组件查询方法 === def get_component_info( component_name: str, component_type: ComponentType -) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]: +) -> CommandInfo | ActionInfo | EventHandlerInfo | None: """ 获取指定组件的信息。 @@ -57,7 +56,7 @@ def get_component_info( def get_components_info_by_type( component_type: ComponentType, -) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]: +) -> dict[str, CommandInfo | ActionInfo | EventHandlerInfo]: """ 获取指定类型的所有组件信息。 @@ -74,7 +73,7 @@ def get_components_info_by_type( def get_enabled_components_info_by_type( component_type: ComponentType, -) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]: +) -> dict[str, CommandInfo | ActionInfo | EventHandlerInfo]: """ 获取指定类型的所有启用的组件信息。 @@ -90,7 +89,7 @@ def get_enabled_components_info_by_type( # === Action 查询方法 === -def get_registered_action_info(action_name: str) -> Optional[ActionInfo]: +def get_registered_action_info(action_name: str) -> ActionInfo | None: """ 获取指定 Action 的注册信息。 @@ -105,7 +104,7 @@ def get_registered_action_info(action_name: str) -> Optional[ActionInfo]: return component_registry.get_registered_action_info(action_name) -def get_registered_command_info(command_name: str) -> Optional[CommandInfo]: +def get_registered_command_info(command_name: str) -> CommandInfo | None: """ 获取指定 Command 的注册信息。 @@ -120,7 +119,7 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]: return component_registry.get_registered_command_info(command_name) -def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]: +def get_registered_tool_info(tool_name: str) -> ToolInfo | None: """ 获取指定 Tool 的注册信息。 @@ -138,7 +137,7 @@ def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]: # === EventHandler 特定查询方法 === def get_registered_event_handler_info( event_handler_name: str, -) -> Optional[EventHandlerInfo]: +) -> EventHandlerInfo | None: """ 获取指定 EventHandler 的注册信息。 diff --git a/src/plugin_system/apis/config_api.py b/src/plugin_system/apis/config_api.py index 05556414e..3ec8694b2 100644 --- a/src/plugin_system/apis/config_api.py +++ b/src/plugin_system/apis/config_api.py @@ -8,6 +8,7 @@ """ from typing import Any + from src.common.logger import get_logger from src.config.config import global_config diff --git a/src/plugin_system/apis/cross_context_api.py b/src/plugin_system/apis/cross_context_api.py index 76bd45bde..3e84cc26b 100644 --- a/src/plugin_system/apis/cross_context_api.py +++ b/src/plugin_system/apis/cross_context_api.py @@ -3,20 +3,20 @@ """ import time -from typing import Dict, Any, Optional, List +from typing import Any +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.utils.chat_message_builder import ( + build_readable_messages_with_id, + get_raw_msg_before_timestamp_with_chat, +) from src.common.logger import get_logger from src.config.config import global_config -from src.chat.utils.chat_message_builder import ( - get_raw_msg_before_timestamp_with_chat, - build_readable_messages_with_id, -) -from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream logger = get_logger("cross_context_api") -def get_context_groups(chat_id: str) -> Optional[List[List[str]]]: +def get_context_groups(chat_id: str) -> list[list[str]] | None: """ 获取当前聊天所在的共享组的其他聊天ID """ @@ -41,7 +41,7 @@ def get_context_groups(chat_id: str) -> Optional[List[List[str]]]: return None -async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: List[List[str]]) -> str: +async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: list[list[str]]) -> str: """ 构建跨群聊/私聊上下文 (Normal模式) """ @@ -74,8 +74,8 @@ async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: async def build_cross_context_s4u( chat_stream: ChatStream, - other_chat_infos: List[List[str]], - target_user_info: Optional[Dict[str, Any]], + other_chat_infos: list[list[str]], + target_user_info: dict[str, Any] | None, ) -> str: """ 构建跨群聊/私聊上下文 (S4U模式) diff --git a/src/plugin_system/apis/database_api.py b/src/plugin_system/apis/database_api.py index c3195bab4..aa6714655 100644 --- a/src/plugin_system/apis/database_api.py +++ b/src/plugin_system/apis/database_api.py @@ -9,7 +9,7 @@ 注意:此模块现在使用SQLAlchemy实现,提供更好的连接管理和错误处理 """ -from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get, store_action_info, MODEL_MAPPING +from src.common.database.sqlalchemy_database_api import MODEL_MAPPING, db_get, db_query, db_save, store_action_info # 保持向后兼容性 -__all__ = ["db_query", "db_save", "db_get", "store_action_info", "MODEL_MAPPING"] +__all__ = ["MODEL_MAPPING", "db_get", "db_query", "db_save", "store_action_info"] diff --git a/src/plugin_system/apis/emoji_api.py b/src/plugin_system/apis/emoji_api.py index 4fbadb98f..a62977d66 100644 --- a/src/plugin_system/apis/emoji_api.py +++ b/src/plugin_system/apis/emoji_api.py @@ -10,10 +10,9 @@ import random -from typing import Optional, Tuple, List -from src.common.logger import get_logger from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.utils.utils_image import image_path_to_base64 +from src.common.logger import get_logger logger = get_logger("emoji_api") @@ -23,7 +22,7 @@ logger = get_logger("emoji_api") # ============================================================================= -async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]: +async def get_by_description(description: str) -> tuple[str, str, str] | None: """根据描述选择表情包 Args: @@ -65,7 +64,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]] return None -async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: +async def get_random(count: int | None = 1) -> list[tuple[str, str, str]]: """随机获取指定数量的表情包 Args: @@ -137,7 +136,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: return [] -async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]: +async def get_by_emotion(emotion: str) -> tuple[str, str, str] | None: """根据情感标签获取表情包 Args: @@ -227,7 +226,7 @@ def get_info(): return {"current_count": 0, "max_count": 0, "available_emojis": 0} -def get_emotions() -> List[str]: +def get_emotions() -> list[str]: """获取所有可用的情感标签 Returns: @@ -247,7 +246,7 @@ def get_emotions() -> List[str]: return [] -def get_descriptions() -> List[str]: +def get_descriptions() -> list[str]: """获取所有表情包描述 Returns: diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index 2a907c60b..21bc6fdde 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -9,13 +9,15 @@ """ import traceback -from typing import Tuple, Any, Dict, List, Optional +from typing import Any + from rich.traceback import install -from src.common.logger import get_logger -from src.chat.replyer.default_generator import DefaultReplyer + from src.chat.message_receive.chat_stream import ChatStream -from src.chat.utils.utils import process_llm_response +from src.chat.replyer.default_generator import DefaultReplyer from src.chat.replyer.replyer_manager import replyer_manager +from src.chat.utils.utils import process_llm_response +from src.common.logger import get_logger from src.plugin_system.base.component_types import ActionInfo install(extra_lines=3) @@ -30,10 +32,10 @@ logger = get_logger("generator_api") def get_replyer( - chat_stream: Optional[ChatStream] = None, - chat_id: Optional[str] = None, + chat_stream: ChatStream | None = None, + chat_id: str | None = None, request_type: str = "replyer", -) -> Optional[DefaultReplyer]: +) -> DefaultReplyer | None: """获取回复器对象 优先使用chat_stream,如果没有则使用chat_id直接查找。 @@ -71,13 +73,13 @@ def get_replyer( async def generate_reply( - chat_stream: Optional[ChatStream] = None, - chat_id: Optional[str] = None, - action_data: Optional[Dict[str, Any]] = None, + chat_stream: ChatStream | None = None, + chat_id: str | None = None, + action_data: dict[str, Any] | None = None, reply_to: str = "", - reply_message: Optional[Dict[str, Any]] = None, + reply_message: dict[str, Any] | None = None, extra_info: str = "", - available_actions: Optional[Dict[str, ActionInfo]] = None, + available_actions: dict[str, ActionInfo] | None = None, enable_tool: bool = False, enable_splitter: bool = True, enable_chinese_typo: bool = True, @@ -85,7 +87,7 @@ async def generate_reply( request_type: str = "generator_api", from_plugin: bool = True, read_mark: float = 0.0, -) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: +) -> tuple[bool, list[tuple[str, Any]], str | None]: """生成回复 Args: @@ -168,9 +170,9 @@ async def generate_reply( async def rewrite_reply( - chat_stream: Optional[ChatStream] = None, - reply_data: Optional[Dict[str, Any]] = None, - chat_id: Optional[str] = None, + chat_stream: ChatStream | None = None, + reply_data: dict[str, Any] | None = None, + chat_id: str | None = None, enable_splitter: bool = True, enable_chinese_typo: bool = True, raw_reply: str = "", @@ -178,7 +180,7 @@ async def rewrite_reply( reply_to: str = "", return_prompt: bool = False, request_type: str = "generator_api", -) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: +) -> tuple[bool, list[tuple[str, Any]], str | None]: """重写回复 Args: @@ -237,7 +239,7 @@ async def rewrite_reply( return False, [], None -def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]: +def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> list[tuple[str, Any]]: """将文本处理为更拟人化的文本 Args: @@ -266,11 +268,11 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: async def generate_response_custom( - chat_stream: Optional[ChatStream] = None, - chat_id: Optional[str] = None, + chat_stream: ChatStream | None = None, + chat_id: str | None = None, request_type: str = "generator_api", prompt: str = "", -) -> Optional[str]: +) -> str | None: """ 使用自定义提示生成回复 diff --git a/src/plugin_system/apis/llm_api.py b/src/plugin_system/apis/llm_api.py index debb67d7e..e868d40a2 100644 --- a/src/plugin_system/apis/llm_api.py +++ b/src/plugin_system/apis/llm_api.py @@ -7,12 +7,13 @@ success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config) """ -from typing import Tuple, Dict, List, Any, Optional +from typing import Any + from src.common.logger import get_logger +from src.config.api_ada_configs import TaskConfig +from src.config.config import model_config from src.llm_models.payload_content.tool_option import ToolCall from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config -from src.config.api_ada_configs import TaskConfig logger = get_logger("llm_api") @@ -21,7 +22,7 @@ logger = get_logger("llm_api") # ============================================================================= -def get_available_models() -> Dict[str, TaskConfig]: +def get_available_models() -> dict[str, TaskConfig]: """获取所有可用的模型配置 Returns: @@ -31,7 +32,7 @@ def get_available_models() -> Dict[str, TaskConfig]: # 自动获取所有属性并转换为字典形式 models = model_config.model_task_config attrs = dir(models) - rets: Dict[str, TaskConfig] = {} + rets: dict[str, TaskConfig] = {} for attr in attrs: if not attr.startswith("__"): try: @@ -52,9 +53,9 @@ async def generate_with_model( prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, -) -> Tuple[bool, str, str, str]: + temperature: float | None = None, + max_tokens: int | None = None, +) -> tuple[bool, str, str, str]: """使用指定模型生成内容 Args: @@ -78,7 +79,7 @@ async def generate_with_model( return True, response, reasoning_content, model_name except Exception as e: - error_msg = f"生成内容时出错: {str(e)}" + error_msg = f"生成内容时出错: {e!s}" logger.error(f"[LLMAPI] {error_msg}") return False, error_msg, "", "" @@ -86,11 +87,11 @@ async def generate_with_model( async def generate_with_model_with_tools( prompt: str, model_config: TaskConfig, - tool_options: List[Dict[str, Any]] | None = None, + tool_options: list[dict[str, Any]] | None = None, request_type: str = "plugin.generate", - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, -) -> Tuple[bool, str, str, str, List[ToolCall] | None]: + temperature: float | None = None, + max_tokens: int | None = None, +) -> tuple[bool, str, str, str, list[ToolCall] | None]: """使用指定模型和工具生成内容 Args: @@ -117,6 +118,6 @@ async def generate_with_model_with_tools( return True, response, reasoning_content, model_name, tool_call except Exception as e: - error_msg = f"生成内容时出错: {str(e)}" + error_msg = f"生成内容时出错: {e!s}" logger.error(f"[LLMAPI] {error_msg}") return False, error_msg, "", "", None diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index baf6418dd..4a9610ca2 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -8,26 +8,26 @@ readable_text = message_api.build_readable_messages(messages) """ -from typing import List, Dict, Any, Tuple, Optional -from src.config.config import global_config import time +from typing import Any + from src.chat.utils.chat_message_builder import ( - get_raw_msg_by_timestamp, - get_raw_msg_by_timestamp_with_chat, - get_raw_msg_by_timestamp_with_chat_inclusive, - get_raw_msg_by_timestamp_with_chat_users, - get_raw_msg_by_timestamp_random, - get_raw_msg_by_timestamp_with_users, - get_raw_msg_before_timestamp, - get_raw_msg_before_timestamp_with_chat, - get_raw_msg_before_timestamp_with_users, - num_new_messages_since, - num_new_messages_since_with_users, build_readable_messages, build_readable_messages_with_list, get_person_id_list, + get_raw_msg_before_timestamp, + get_raw_msg_before_timestamp_with_chat, + get_raw_msg_before_timestamp_with_users, + get_raw_msg_by_timestamp, + get_raw_msg_by_timestamp_random, + get_raw_msg_by_timestamp_with_chat, + get_raw_msg_by_timestamp_with_chat_inclusive, + get_raw_msg_by_timestamp_with_chat_users, + get_raw_msg_by_timestamp_with_users, + num_new_messages_since, + num_new_messages_since_with_users, ) - +from src.config.config import global_config # ============================================================================= # 消息查询API函数 @@ -36,7 +36,7 @@ from src.chat.utils.chat_message_builder import ( async def get_messages_by_time( start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 获取指定时间范围内的消息 @@ -70,7 +70,7 @@ async def get_messages_by_time_in_chat( limit_mode: str = "latest", filter_mai: bool = False, filter_command: bool = False, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 获取指定聊天中指定时间范围内的消息 @@ -111,7 +111,7 @@ async def get_messages_by_time_in_chat_inclusive( limit_mode: str = "latest", filter_mai: bool = False, filter_command: bool = False, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 获取指定聊天中指定时间范围内的消息(包含边界) @@ -152,10 +152,10 @@ async def get_messages_by_time_in_chat_for_users( chat_id: str, start_time: float, end_time: float, - person_ids: List[str], + person_ids: list[str], limit: int = 0, limit_mode: str = "latest", -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 获取指定聊天中指定用户在指定时间范围内的消息 @@ -186,7 +186,7 @@ async def get_messages_by_time_in_chat_for_users( async def get_random_chat_messages( start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 随机选择一个聊天,返回该聊天在指定时间范围内的消息 @@ -213,8 +213,8 @@ async def get_random_chat_messages( async def get_messages_by_time_for_users( - start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" -) -> List[Dict[str, Any]]: + start_time: float, end_time: float, person_ids: list[str], limit: int = 0, limit_mode: str = "latest" +) -> list[dict[str, Any]]: """ 获取指定用户在所有聊天中指定时间范围内的消息 @@ -238,7 +238,7 @@ async def get_messages_by_time_for_users( return await get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode) -async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]: +async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> list[dict[str, Any]]: """ 获取指定时间戳之前的消息 @@ -294,8 +294,8 @@ async def get_messages_before_time_in_chat( async def get_messages_before_time_for_users( - timestamp: float, person_ids: List[str], limit: int = 0 -) -> List[Dict[str, Any]]: + timestamp: float, person_ids: list[str], limit: int = 0 +) -> list[dict[str, Any]]: """ 获取指定用户在指定时间戳之前的消息 @@ -319,7 +319,7 @@ async def get_messages_before_time_for_users( async def get_recent_messages( chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ 获取指定聊天中最近一段时间的消息 @@ -358,7 +358,7 @@ async def get_recent_messages( # ============================================================================= -async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int: +async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: float | None = None) -> int: """ 计算指定聊天中从开始时间到结束时间的新消息数量 @@ -382,7 +382,7 @@ async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Op return await num_new_messages_since(chat_id, start_time, end_time) -async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int: +async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: list[str]) -> int: """ 计算指定聊天中指定用户从开始时间到结束时间的新消息数量 @@ -413,7 +413,7 @@ async def count_new_messages_for_users(chat_id: str, start_time: float, end_time async def build_readable_messages_to_str( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, timestamp_mode: str = "relative", @@ -442,12 +442,12 @@ async def build_readable_messages_to_str( async def build_readable_messages_with_details( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, timestamp_mode: str = "relative", truncate: bool = False, -) -> Tuple[str, List[Tuple[float, str, str]]]: +) -> tuple[str, list[tuple[float, str, str]]]: """ 将消息列表构建成可读的字符串,并返回详细信息 @@ -464,7 +464,7 @@ async def build_readable_messages_with_details( return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate) -async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]: +async def get_person_ids_from_messages(messages: list[dict[str, Any]]) -> list[str]: """ 从消息列表中提取不重复的用户ID列表 @@ -482,7 +482,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s # ============================================================================= -async def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +async def filter_mai_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """ 从消息列表中移除麦麦的消息 Args: diff --git a/src/plugin_system/apis/permission_api.py b/src/plugin_system/apis/permission_api.py index 61b4ca40f..3c42f9eab 100644 --- a/src/plugin_system/apis/permission_api.py +++ b/src/plugin_system/apis/permission_api.py @@ -1,9 +1,9 @@ """纯异步权限API定义。所有外部调用方必须使用 await。""" -from typing import Optional, List, Dict, Any +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from abc import ABC, abstractmethod +from typing import Any from src.common.logger import get_logger @@ -48,18 +48,18 @@ class IPermissionManager(ABC): async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: ... @abstractmethod - async def get_user_permissions(self, user: UserInfo) -> List[str]: ... + async def get_user_permissions(self, user: UserInfo) -> list[str]: ... @abstractmethod - async def get_all_permission_nodes(self) -> List[PermissionNode]: ... + async def get_all_permission_nodes(self) -> list[PermissionNode]: ... @abstractmethod - async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: ... + async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]: ... class PermissionAPI: def __init__(self): - self._permission_manager: Optional[IPermissionManager] = None + self._permission_manager: IPermissionManager | None = None # 需要保留的前缀(视为绝对节点名,不再自动加 plugins.. 前缀) self.RESERVED_PREFIXES: tuple[str, ...] = "system." # 系统节点列表 (name, description, default_granted) @@ -147,11 +147,11 @@ class PermissionAPI: self._ensure_manager() return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node) - async def get_user_permissions(self, platform: str, user_id: str) -> List[str]: + async def get_user_permissions(self, platform: str, user_id: str) -> list[str]: self._ensure_manager() return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id)) - async def get_all_permission_nodes(self) -> List[Dict[str, Any]]: + async def get_all_permission_nodes(self) -> list[dict[str, Any]]: self._ensure_manager() nodes = await self._permission_manager.get_all_permission_nodes() return [ @@ -164,7 +164,7 @@ class PermissionAPI: for n in nodes ] - async def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]: + async def get_plugin_permission_nodes(self, plugin_name: str) -> list[dict[str, Any]]: self._ensure_manager() nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name) return [ diff --git a/src/plugin_system/apis/person_api.py b/src/plugin_system/apis/person_api.py index e3f7be714..5c3427dff 100644 --- a/src/plugin_system/apis/person_api.py +++ b/src/plugin_system/apis/person_api.py @@ -7,9 +7,10 @@ value = await person_api.get_person_value(person_id, "nickname") """ -from typing import Any, Optional +from typing import Any + from src.common.logger import get_logger -from src.person_info.person_info import get_person_info_manager, PersonInfoManager +from src.person_info.person_info import PersonInfoManager, get_person_info_manager logger = get_logger("person_api") @@ -63,7 +64,7 @@ async def get_person_value(person_id: str, field_name: str, default: Any = None) return default -async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict: +async def get_person_values(person_id: str, field_names: list, default_dict: dict | None = None) -> dict: """批量获取用户信息字段值 Args: diff --git a/src/plugin_system/apis/plugin_manage_api.py b/src/plugin_system/apis/plugin_manage_api.py index d428eb282..d7a802b8c 100644 --- a/src/plugin_system/apis/plugin_manage_api.py +++ b/src/plugin_system/apis/plugin_manage_api.py @@ -1,7 +1,4 @@ -from typing import Tuple, List - - -def list_loaded_plugins() -> List[str]: +def list_loaded_plugins() -> list[str]: """ 列出所有当前加载的插件。 @@ -13,7 +10,7 @@ def list_loaded_plugins() -> List[str]: return plugin_manager.list_loaded_plugins() -def list_registered_plugins() -> List[str]: +def list_registered_plugins() -> list[str]: """ 列出所有已注册的插件。 @@ -80,7 +77,7 @@ async def reload_plugin(plugin_name: str) -> bool: return await plugin_manager.reload_registered_plugin(plugin_name) -def load_plugin(plugin_name: str) -> Tuple[bool, int]: +def load_plugin(plugin_name: str) -> tuple[bool, int]: """ 加载指定的插件。 @@ -109,7 +106,7 @@ def add_plugin_directory(plugin_directory: str) -> bool: return plugin_manager.add_plugin_directory(plugin_directory) -def rescan_plugin_directory() -> Tuple[int, int]: +def rescan_plugin_directory() -> tuple[int, int]: """ 重新扫描插件目录,加载新插件。 Returns: diff --git a/src/plugin_system/apis/plugin_register_api.py b/src/plugin_system/apis/plugin_register_api.py index 2e14b0c84..6741c7ea9 100644 --- a/src/plugin_system/apis/plugin_register_api.py +++ b/src/plugin_system/apis/plugin_register_api.py @@ -6,8 +6,8 @@ logger = get_logger("plugin_manager") # 复用plugin_manager名称 def register_plugin(cls): - from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.base.base_plugin import BasePlugin + from src.plugin_system.core.plugin_manager import plugin_manager """插件注册装饰器 diff --git a/src/plugin_system/apis/schedule_api.py b/src/plugin_system/apis/schedule_api.py index e3e759968..61c5d13f4 100644 --- a/src/plugin_system/apis/schedule_api.py +++ b/src/plugin_system/apis/schedule_api.py @@ -30,7 +30,7 @@ """ from datetime import datetime -from typing import List, Dict, Any, Optional +from typing import Any from src.common.database.sqlalchemy_models import MonthlyPlan from src.common.logger import get_logger @@ -44,7 +44,7 @@ class ScheduleAPI: """日程表与月度计划API - 负责日程和计划信息的查询与管理""" @staticmethod - async def get_today_schedule() -> Optional[List[Dict[str, Any]]]: + async def get_today_schedule() -> list[dict[str, Any]] | None: """(异步) 获取今天的日程安排 Returns: @@ -58,7 +58,7 @@ class ScheduleAPI: return None @staticmethod - async def get_current_activity() -> Optional[str]: + async def get_current_activity() -> str | None: """(异步) 获取当前正在进行的活动 Returns: @@ -87,7 +87,7 @@ class ScheduleAPI: return False @staticmethod - async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]: + async def get_monthly_plans(target_month: str | None = None) -> list[MonthlyPlan]: """(异步) 获取指定月份的有效月度计划 Args: @@ -106,7 +106,7 @@ class ScheduleAPI: return [] @staticmethod - async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool: + async def ensure_monthly_plans(target_month: str | None = None) -> bool: """(异步) 确保指定月份存在月度计划,如果不存在则触发生成 Args: @@ -125,7 +125,7 @@ class ScheduleAPI: return False @staticmethod - async def archive_monthly_plans(target_month: Optional[str] = None) -> bool: + async def archive_monthly_plans(target_month: str | None = None) -> bool: """(异步) 归档指定月份的月度计划 Args: @@ -150,12 +150,12 @@ class ScheduleAPI: # ============================================================================= -async def get_today_schedule() -> Optional[List[Dict[str, Any]]]: +async def get_today_schedule() -> list[dict[str, Any]] | None: """(异步) 获取今天的日程安排的便捷函数""" return await ScheduleAPI.get_today_schedule() -async def get_current_activity() -> Optional[str]: +async def get_current_activity() -> str | None: """(异步) 获取当前正在进行的活动的便捷函数""" return await ScheduleAPI.get_current_activity() @@ -165,16 +165,16 @@ async def regenerate_schedule() -> bool: return await ScheduleAPI.regenerate_schedule() -async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]: +async def get_monthly_plans(target_month: str | None = None) -> list[MonthlyPlan]: """(异步) 获取指定月份的有效月度计划的便捷函数""" return await ScheduleAPI.get_monthly_plans(target_month) -async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool: +async def ensure_monthly_plans(target_month: str | None = None) -> bool: """(异步) 确保指定月份存在月度计划的便捷函数""" return await ScheduleAPI.ensure_monthly_plans(target_month) -async def archive_monthly_plans(target_month: Optional[str] = None) -> bool: +async def archive_monthly_plans(target_month: str | None = None) -> bool: """(异步) 归档指定月份的月度计划的便捷函数""" return await ScheduleAPI.archive_monthly_plans(target_month) diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index c770db78b..d05e50355 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -28,29 +28,28 @@ """ -import traceback -import time import asyncio -from typing import Optional, Union, Dict, Any -from src.common.logger import get_logger +import time +import traceback +from typing import Any + +from maim_message import Seg, UserInfo # 导入依赖 -from src.chat.message_receive.chat_stream import get_chat_manager -from maim_message import UserInfo -from src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.message import MessageRecv, MessageSending from src.chat.message_receive.uni_message_sender import HeartFCSender -from src.chat.message_receive.message import MessageSending, MessageRecv -from maim_message import Seg +from src.common.logger import get_logger from src.config.config import global_config # 日志记录器 logger = get_logger("send_api") # 适配器命令响应等待池 -_adapter_response_pool: Dict[str, asyncio.Future] = {} +_adapter_response_pool: dict[str, asyncio.Future] = {} -def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[MessageRecv]: +def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None: """查找要回复的消息 Args: @@ -134,13 +133,13 @@ async def wait_adapter_response(request_id: str, timeout: float = 30.0) -> dict: async def _send_to_target( message_type: str, - content: Union[str, dict], + content: str | dict, stream_id: str, display_message: str = "", typing: bool = False, reply_to: str = "", set_reply: bool = False, - reply_to_message: Optional[Dict[str, Any]] = None, + reply_to_message: dict[str, Any] | None = None, storage_message: bool = True, show_log: bool = True, ) -> bool: @@ -247,7 +246,7 @@ async def text_to_stream( stream_id: str, typing: bool = False, reply_to: str = "", - reply_to_message: Optional[Dict[str, Any]] = None, + reply_to_message: dict[str, Any] | None = None, set_reply: bool = True, storage_message: bool = True, ) -> bool: @@ -313,7 +312,7 @@ async def image_to_stream( async def command_to_stream( - command: Union[str, dict], + command: str | dict, stream_id: str, storage_message: bool = True, display_message: str = "", @@ -341,7 +340,7 @@ async def custom_to_stream( display_message: str = "", typing: bool = False, reply_to: str = "", - reply_to_message: Optional[Dict[str, Any]] = None, + reply_to_message: dict[str, Any] | None = None, set_reply: bool = True, storage_message: bool = True, show_log: bool = True, @@ -377,8 +376,8 @@ async def custom_to_stream( async def adapter_command_to_stream( action: str, params: dict, - platform: Optional[str] = "qq", - stream_id: Optional[str] = None, + platform: str | None = "qq", + stream_id: str | None = None, timeout: float = 30.0, storage_message: bool = False, ) -> dict: @@ -497,4 +496,4 @@ async def adapter_command_to_stream( except Exception as e: logger.error(f"[SendAPI] 发送适配器命令时出错: {e}") traceback.print_exc() - return {"status": "error", "message": f"发送适配器命令时出错: {str(e)}"} + return {"status": "error", "message": f"发送适配器命令时出错: {e!s}"} diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py index c3472243a..6b949b2e5 100644 --- a/src/plugin_system/apis/tool_api.py +++ b/src/plugin_system/apis/tool_api.py @@ -1,13 +1,11 @@ -from typing import Optional, Type +from src.common.logger import get_logger from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.component_types import ComponentType -from src.common.logger import get_logger - logger = get_logger("tool_api") -def get_tool_instance(tool_name: str) -> Optional[BaseTool]: +def get_tool_instance(tool_name: str) -> BaseTool | None: """获取公开工具实例""" from src.plugin_system.core import component_registry @@ -18,7 +16,7 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]: else: plugin_config = None - tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore + tool_class: type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore return tool_class(plugin_config) if tool_class else None diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py index 83debab01..87f004ff5 100644 --- a/src/plugin_system/base/__init__.py +++ b/src/plugin_system/base/__init__.py @@ -4,31 +4,31 @@ 提供插件开发的基础类和类型定义 """ -from .base_plugin import BasePlugin from .base_action import BaseAction -from .base_tool import BaseTool from .base_command import BaseCommand from .base_events_handler import BaseEventHandler +from .base_plugin import BasePlugin +from .base_tool import BaseTool +from .command_args import CommandArgs from .component_types import ( - ComponentType, ActionActivationType, + ActionInfo, ChatMode, ChatType, - ComponentInfo, - ActionInfo, CommandInfo, - PlusCommandInfo, - ToolInfo, - PluginInfo, - PythonDependency, + ComponentInfo, + ComponentType, EventHandlerInfo, EventType, MaiMessages, + PluginInfo, + PlusCommandInfo, + PythonDependency, + ToolInfo, ToolParamType, ) from .config_types import ConfigField from .plus_command import PlusCommand, PlusCommandAdapter, create_plus_command_adapter -from .command_args import CommandArgs __all__ = [ "BasePlugin", diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index d3f012be5..37711794b 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -1,14 +1,11 @@ -import time import asyncio - +import time from abc import ABC, abstractmethod -from typing import Tuple, Optional, List, Dict -from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream -from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType, ChatType -from src.plugin_system.apis import send_api, database_api, message_api - +from src.common.logger import get_logger +from src.plugin_system.apis import database_api, message_api, send_api +from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ChatMode, ChatType, ComponentType logger = get_logger("base_action") @@ -39,7 +36,7 @@ class BaseAction(ABC): """是否为二步Action。如果为True,Action将分两步执行:第一步选择操作,第二步执行具体操作""" step_one_description: str = "" """第一步的描述,用于向LLM展示Action的基本功能""" - sub_actions: List[Tuple[str, str, Dict[str, str]]] = [] + sub_actions: list[tuple[str, str, dict[str, str]]] = [] """子Action列表,格式为[(子Action名, 子Action描述, 子Action参数)]。仅在二步Action中使用""" def __init__( @@ -50,8 +47,8 @@ class BaseAction(ABC): thinking_id: str, chat_stream: ChatStream, log_prefix: str = "", - plugin_config: Optional[dict] = None, - action_message: Optional[dict] = None, + plugin_config: dict | None = None, + action_message: dict | None = None, **kwargs, ): # sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs @@ -109,8 +106,8 @@ class BaseAction(ABC): # 二步Action相关实例属性 self.is_two_step_action: bool = getattr(self.__class__, "is_two_step_action", False) self.step_one_description: str = getattr(self.__class__, "step_one_description", "") - self.sub_actions: List[Tuple[str, str, Dict[str, str]]] = getattr(self.__class__, "sub_actions", []).copy() - self._selected_sub_action: Optional[str] = None + self.sub_actions: list[tuple[str, str, dict[str, str]]] = getattr(self.__class__, "sub_actions", []).copy() + self._selected_sub_action: str | None = None """当前选择的子Action名称,用于二步Action的状态管理""" # ============================================================================= @@ -200,7 +197,7 @@ class BaseAction(ABC): """ return self._validate_chat_type() - async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]: + async def wait_for_new_message(self, timeout: int = 1200) -> tuple[bool, str]: """等待新消息或超时 在loop_start_time之后等待新消息,如果没有新消息且没有超时,就一直等待。 @@ -232,7 +229,7 @@ class BaseAction(ABC): # 检查新消息 current_time = time.time() - new_message_count = message_api.count_new_messages( + new_message_count = await message_api.count_new_messages( chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time ) @@ -258,7 +255,7 @@ class BaseAction(ABC): return False, "" except Exception as e: logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}") - return False, f"等待新消息失败: {str(e)}" + return False, f"等待新消息失败: {e!s}" async def send_text(self, content: str, reply_to: str = "", typing: bool = False) -> bool: """发送文本消息 @@ -359,7 +356,7 @@ class BaseAction(ABC): ) async def send_command( - self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True + self, command_name: str, args: dict | None = None, display_message: str = "", storage_message: bool = True ) -> bool: """发送命令消息 @@ -400,7 +397,7 @@ class BaseAction(ABC): logger.error(f"{self.log_prefix} 发送命令时出错: {e}") return False - async def call_action(self, action_name: str, action_data: Optional[dict] = None) -> Tuple[bool, str]: + async def call_action(self, action_name: str, action_data: dict | None = None) -> tuple[bool, str]: """ 在当前Action中调用另一个Action。 @@ -514,7 +511,7 @@ class BaseAction(ABC): sub_actions=getattr(cls, "sub_actions", []).copy(), ) - async def handle_step_one(self) -> Tuple[bool, str]: + async def handle_step_one(self) -> tuple[bool, str]: """处理二步Action的第一步 Returns: @@ -546,7 +543,7 @@ class BaseAction(ABC): # 调用第二步执行 return await self.execute_step_two(selected_action) - async def execute_step_two(self, sub_action_name: str) -> Tuple[bool, str]: + async def execute_step_two(self, sub_action_name: str) -> tuple[bool, str]: """执行二步Action的第二步 Args: @@ -562,7 +559,7 @@ class BaseAction(ABC): return False, f"二步Action必须实现execute_step_two方法来处理操作: {sub_action_name}" @abstractmethod - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """执行Action的抽象方法,子类必须实现 对于二步Action,会自动处理第一步逻辑 @@ -577,7 +574,7 @@ class BaseAction(ABC): # 普通Action由子类实现 pass - async def handle_action(self) -> Tuple[bool, str]: + async def handle_action(self) -> tuple[bool, str]: """兼容旧系统的handle_action接口,委托给execute方法 为了保持向后兼容性,旧系统的代码可能会调用handle_action方法。 diff --git a/src/plugin_system/base/base_chatter.py b/src/plugin_system/base/base_chatter.py index 1dd225252..b8a1288af 100644 --- a/src/plugin_system/base/base_chatter.py +++ b/src/plugin_system/base/base_chatter.py @@ -1,9 +1,11 @@ from abc import ABC, abstractmethod -from typing import List, TYPE_CHECKING +from typing import TYPE_CHECKING + from src.common.data_models.message_manager_data_model import StreamContext -from .component_types import ChatType from src.plugin_system.base.component_types import ChatterInfo, ComponentType +from .component_types import ChatType + if TYPE_CHECKING: from src.chat.planner_actions.action_manager import ChatterActionManager @@ -13,7 +15,7 @@ class BaseChatter(ABC): """Chatter组件的名称""" chatter_description: str = "" """Chatter组件的描述""" - chat_types: List[ChatType] = [ChatType.PRIVATE, ChatType.GROUP] + chat_types: list[ChatType] = [ChatType.PRIVATE, ChatType.GROUP] def __init__(self, stream_id: str, action_manager: "ChatterActionManager"): """ diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 212634d5d..9cb41ed04 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from typing import Dict, Tuple, Optional -from src.common.logger import get_logger -from src.plugin_system.base.component_types import CommandInfo, ComponentType, ChatType + from src.chat.message_receive.message import MessageRecv +from src.common.logger import get_logger from src.plugin_system.apis import send_api +from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType logger = get_logger("base_command") @@ -29,7 +29,7 @@ class BaseCommand(ABC): chat_type_allow: ChatType = ChatType.ALL """允许的聊天类型,默认为所有类型""" - def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): + def __init__(self, message: MessageRecv, plugin_config: dict | None = None): """初始化Command组件 Args: @@ -37,7 +37,7 @@ class BaseCommand(ABC): plugin_config: 插件配置字典 """ self.message = message - self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组 + self.matched_groups: dict[str, str] = {} # 存储正则表达式匹配的命名组 self.plugin_config = plugin_config or {} # 直接存储插件配置字典 self.log_prefix = "[Command]" @@ -55,7 +55,7 @@ class BaseCommand(ABC): f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}" ) - def set_matched_groups(self, groups: Dict[str, str]) -> None: + def set_matched_groups(self, groups: dict[str, str]) -> None: """设置正则表达式匹配的命名组 Args: @@ -93,7 +93,7 @@ class BaseCommand(ABC): return self._validate_chat_type() @abstractmethod - async def execute(self) -> Tuple[bool, Optional[str], bool]: + async def execute(self) -> tuple[bool, str | None, bool]: """执行Command的抽象方法,子类必须实现 Returns: @@ -175,7 +175,7 @@ class BaseCommand(ABC): ) async def send_command( - self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True + self, command_name: str, args: dict | None = None, display_message: str = "", storage_message: bool = True ) -> bool: """发送命令消息 diff --git a/src/plugin_system/base/base_event.py b/src/plugin_system/base/base_event.py index c7dd09a58..f8c45e54d 100644 --- a/src/plugin_system/base/base_event.py +++ b/src/plugin_system/base/base_event.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Dict, Any, Optional +from typing import Any from src.common.logger import get_logger @@ -25,22 +25,22 @@ class HandlerResult: class HandlerResultsCollection: """HandlerResult集合,提供便捷的查询方法""" - def __init__(self, results: List[HandlerResult]): + def __init__(self, results: list[HandlerResult]): self.results = results def all_continue_process(self) -> bool: """检查是否所有handler的continue_process都为True""" return all(result.continue_process for result in self.results) - def get_all_results(self) -> List[HandlerResult]: + def get_all_results(self) -> list[HandlerResult]: """获取所有HandlerResult""" return self.results - def get_failed_handlers(self) -> List[HandlerResult]: + def get_failed_handlers(self) -> list[HandlerResult]: """获取执行失败的handler结果""" return [result for result in self.results if not result.success] - def get_stopped_handlers(self) -> List[HandlerResult]: + def get_stopped_handlers(self) -> list[HandlerResult]: """获取continue_process为False的handler结果""" return [result for result in self.results if not result.continue_process] @@ -57,7 +57,7 @@ class HandlerResultsCollection: else: return {result.handler_name: result.message for result in self.results} - def get_handler_result(self, handler_name: str) -> Optional[HandlerResult]: + def get_handler_result(self, handler_name: str) -> HandlerResult | None: """获取指定handler的结果""" for result in self.results: if result.handler_name == handler_name: @@ -72,7 +72,7 @@ class HandlerResultsCollection: """获取执行失败的handler数量""" return sum(1 for result in self.results if not result.success) - def get_summary(self) -> Dict[str, Any]: + def get_summary(self) -> dict[str, Any]: """获取执行摘要""" return { "total_handlers": len(self.results), @@ -85,13 +85,13 @@ class HandlerResultsCollection: class BaseEvent: - def __init__(self, name: str, allowed_subscribers: List[str] = None, allowed_triggers: List[str] = None): + def __init__(self, name: str, allowed_subscribers: list[str] = None, allowed_triggers: list[str] = None): self.name = name self.enabled = True self.allowed_subscribers = allowed_subscribers # 记录事件处理器名 self.allowed_triggers = allowed_triggers # 记录插件名 - self.subscribers: List["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表 + self.subscribers: list["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表 self.event_handle_lock = asyncio.Lock() diff --git a/src/plugin_system/base/base_events_handler.py b/src/plugin_system/base/base_events_handler.py index 517de92c2..fa73dccc8 100644 --- a/src/plugin_system/base/base_events_handler.py +++ b/src/plugin_system/base/base_events_handler.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod -from typing import Tuple, Optional, List, Union from src.common.logger import get_logger -from .component_types import EventType, EventHandlerInfo, ComponentType + +from .component_types import ComponentType, EventHandlerInfo, EventType logger = get_logger("base_event_handler") @@ -21,7 +21,7 @@ class BaseEventHandler(ABC): """处理器权重,越大权重越高""" intercept_message: bool = False """是否拦截消息,默认为否""" - init_subscribe: List[Union[EventType, str]] = [EventType.UNKNOWN] + init_subscribe: list[EventType | str] = [EventType.UNKNOWN] """初始化时订阅的事件名称""" plugin_name = None @@ -44,7 +44,7 @@ class BaseEventHandler(ABC): self.plugin_config = getattr(self.__class__, "plugin_config", {}) @abstractmethod - async def execute(self, kwargs: dict | None) -> Tuple[bool, bool, Optional[str]]: + async def execute(self, kwargs: dict | None) -> tuple[bool, bool, str | None]: """执行事件处理的抽象方法,子类必须实现 Args: kwargs (dict | None): 事件消息对象,当你注册的事件为ON_START和ON_STOP时message为None diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index 8916fadfd..232365bce 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -1,13 +1,13 @@ from abc import abstractmethod -from typing import List, Type, Tuple, Union -from .plugin_base import PluginBase from src.common.logger import get_logger -from src.plugin_system.base.component_types import ActionInfo, CommandInfo, PlusCommandInfo, EventHandlerInfo, ToolInfo +from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, PlusCommandInfo, ToolInfo + from .base_action import BaseAction from .base_command import BaseCommand from .base_events_handler import BaseEventHandler from .base_tool import BaseTool +from .plugin_base import PluginBase from .plus_command import PlusCommand logger = get_logger("base_plugin") @@ -28,14 +28,12 @@ class BasePlugin(PluginBase): @abstractmethod def get_plugin_components( self, - ) -> List[ - Union[ - Tuple[ActionInfo, Type[BaseAction]], - Tuple[CommandInfo, Type[BaseCommand]], - Tuple[PlusCommandInfo, Type[PlusCommand]], - Tuple[EventHandlerInfo, Type[BaseEventHandler]], - Tuple[ToolInfo, Type[BaseTool]], - ] + ) -> list[ + tuple[ActionInfo, type[BaseAction]] + | tuple[CommandInfo, type[BaseCommand]] + | tuple[PlusCommandInfo, type[PlusCommand]] + | tuple[EventHandlerInfo, type[BaseEventHandler]] + | tuple[ToolInfo, type[BaseTool]] ]: """获取插件包含的组件列表 diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py index 229cadb63..5cd04b485 100644 --- a/src/plugin_system/base/base_tool.py +++ b/src/plugin_system/base/base_tool.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Any, List, Optional, Tuple +from typing import Any + from rich.traceback import install from src.common.logger import get_logger @@ -17,7 +18,7 @@ class BaseTool(ABC): """工具的名称""" description: str = "" """工具的描述""" - parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = [] + parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = [] """工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式 param_name: 参数名称 param_type: 参数类型 @@ -35,7 +36,7 @@ class BaseTool(ABC): """是否为该工具启用缓存""" cache_ttl: int = 3600 """缓存的TTL值(秒),默认为3600秒(1小时)""" - semantic_cache_query_key: Optional[str] = None + semantic_cache_query_key: str | None = None """用于语义缓存的查询参数键名。如果设置,将使用此参数的值进行语义相似度搜索""" # 二步工具调用相关属性 @@ -43,10 +44,10 @@ class BaseTool(ABC): """是否为二步工具。如果为True,工具将分两步调用:第一步展示工具信息,第二步执行具体操作""" step_one_description: str = "" """第一步的描述,用于向LLM展示工具的基本功能""" - sub_tools: List[Tuple[str, str, List[Tuple[str, ToolParamType, str, bool, List[str] | None]]]] = [] + sub_tools: list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] = [] """子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用""" - def __init__(self, plugin_config: Optional[dict] = None): + def __init__(self, plugin_config: dict | None = None): self.plugin_config = plugin_config or {} # 直接存储插件配置字典 @classmethod @@ -101,7 +102,7 @@ class BaseTool(ABC): raise ValueError(f"未找到子工具: {sub_tool_name}") @classmethod - def get_all_sub_tool_definitions(cls) -> List[dict[str, Any]]: + def get_all_sub_tool_definitions(cls) -> list[dict[str, Any]]: """获取所有子工具的定义 Returns: diff --git a/src/plugin_system/base/command_args.py b/src/plugin_system/base/command_args.py index 980eb958f..72d55dd6b 100644 --- a/src/plugin_system/base/command_args.py +++ b/src/plugin_system/base/command_args.py @@ -3,7 +3,6 @@ 提供简单易用的命令参数解析功能 """ -from typing import List, Optional import shlex @@ -20,7 +19,7 @@ class CommandArgs: raw_args: 原始参数字符串 """ self._raw_args = raw_args.strip() - self._parsed_args: Optional[List[str]] = None + self._parsed_args: list[str] | None = None def get_raw(self) -> str: """获取完整的参数字符串 @@ -30,7 +29,7 @@ class CommandArgs: """ return self._raw_args - def get_args(self) -> List[str]: + def get_args(self) -> list[str]: """获取解析后的参数列表 将参数按空格分割,支持引号包围的参数 diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 2b1122b9f..9ae921466 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -1,10 +1,11 @@ -from enum import Enum -from typing import Dict, Any, List, Optional, Tuple from dataclasses import dataclass, field +from enum import Enum +from typing import Any + from maim_message import Seg -from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType from src.llm_models.payload_content.tool_option import ToolCall as ToolCall +from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType # 组件类型枚举 @@ -114,7 +115,7 @@ class ComponentInfo: enabled: bool = True # 是否启用 plugin_name: str = "" # 所属插件名称 is_built_in: bool = False # 是否为内置组件 - metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据 + metadata: dict[str, Any] = field(default_factory=dict) # 额外元数据 def __post_init__(self): if self.metadata is None: @@ -125,18 +126,18 @@ class ComponentInfo: class ActionInfo(ComponentInfo): """动作组件信息""" - action_parameters: Dict[str, str] = field( + action_parameters: dict[str, str] = field( default_factory=dict ) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"} - action_require: List[str] = field(default_factory=list) # 动作需求说明 - associated_types: List[str] = field(default_factory=list) # 关联的消息类型 + action_require: list[str] = field(default_factory=list) # 动作需求说明 + associated_types: list[str] = field(default_factory=list) # 关联的消息类型 # 激活类型相关 focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS activation_type: ActionActivationType = ActionActivationType.ALWAYS random_activation_probability: float = 0.0 llm_judge_prompt: str = "" - activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表 + activation_keywords: list[str] = field(default_factory=list) # 激活关键词列表 keyword_case_sensitive: bool = False # 模式和并行设置 mode_enable: ChatMode = ChatMode.ALL @@ -145,7 +146,7 @@ class ActionInfo(ComponentInfo): # 二步Action相关属性 is_two_step_action: bool = False # 是否为二步Action step_one_description: str = "" # 第一步的描述 - sub_actions: List[Tuple[str, str, Dict[str, str]]] = field(default_factory=list) # 子Action列表 + sub_actions: list[tuple[str, str, dict[str, str]]] = field(default_factory=list) # 子Action列表 def __post_init__(self): super().__post_init__() @@ -178,7 +179,7 @@ class CommandInfo(ComponentInfo): class PlusCommandInfo(ComponentInfo): """增强命令组件信息""" - command_aliases: List[str] = field(default_factory=list) # 命令别名列表 + command_aliases: list[str] = field(default_factory=list) # 命令别名列表 priority: int = 0 # 命令优先级 chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型 intercept_message: bool = False # 是否拦截消息 @@ -194,7 +195,7 @@ class PlusCommandInfo(ComponentInfo): class ToolInfo(ComponentInfo): """工具组件信息""" - tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field( + tool_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = field( default_factory=list ) # 工具参数定义 tool_description: str = "" # 工具描述 @@ -248,18 +249,18 @@ class PluginInfo: author: str = "" # 插件作者 enabled: bool = True # 是否启用 is_built_in: bool = False # 是否为内置插件 - components: List[ComponentInfo] = field(default_factory=list) # 包含的组件列表 - dependencies: List[str] = field(default_factory=list) # 依赖的其他插件 - python_dependencies: List[PythonDependency] = field(default_factory=list) # Python包依赖 + components: list[ComponentInfo] = field(default_factory=list) # 包含的组件列表 + dependencies: list[str] = field(default_factory=list) # 依赖的其他插件 + python_dependencies: list[PythonDependency] = field(default_factory=list) # Python包依赖 config_file: str = "" # 配置文件路径 - metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据 + metadata: dict[str, Any] = field(default_factory=dict) # 额外元数据 # 新增:manifest相关信息 - manifest_data: Dict[str, Any] = field(default_factory=dict) # manifest文件数据 + manifest_data: dict[str, Any] = field(default_factory=dict) # manifest文件数据 license: str = "" # 插件许可证 homepage_url: str = "" # 插件主页 repository_url: str = "" # 插件仓库地址 - keywords: List[str] = field(default_factory=list) # 插件关键词 - categories: List[str] = field(default_factory=list) # 插件分类 + keywords: list[str] = field(default_factory=list) # 插件关键词 + categories: list[str] = field(default_factory=list) # 插件分类 min_host_version: str = "" # 最低主机版本要求 max_host_version: str = "" # 最高主机版本要求 @@ -279,7 +280,7 @@ class PluginInfo: if self.categories is None: self.categories = [] - def get_missing_packages(self) -> List[PythonDependency]: + def get_missing_packages(self) -> list[PythonDependency]: """检查缺失的Python包""" missing = [] for dep in self.python_dependencies: @@ -290,7 +291,7 @@ class PluginInfo: missing.append(dep) return missing - def get_pip_requirements(self) -> List[str]: + def get_pip_requirements(self) -> list[str]: """获取所有pip安装格式的依赖""" return [dep.get_pip_requirement() for dep in self.python_dependencies] @@ -299,16 +300,16 @@ class PluginInfo: class MaiMessages: """MaiM插件消息""" - message_segments: List[Seg] = field(default_factory=list) + message_segments: list[Seg] = field(default_factory=list) """消息段列表,支持多段消息""" - message_base_info: Dict[str, Any] = field(default_factory=dict) + message_base_info: dict[str, Any] = field(default_factory=dict) """消息基本信息,包含平台,用户信息等数据""" plain_text: str = "" """纯文本消息内容""" - raw_message: Optional[str] = None + raw_message: str | None = None """原始消息内容""" is_group_message: bool = False @@ -317,28 +318,28 @@ class MaiMessages: is_private_message: bool = False """是否为私聊消息""" - stream_id: Optional[str] = None + stream_id: str | None = None """流ID,用于标识消息流""" - llm_prompt: Optional[str] = None + llm_prompt: str | None = None """LLM提示词""" - llm_response_content: Optional[str] = None + llm_response_content: str | None = None """LLM响应内容""" - llm_response_reasoning: Optional[str] = None + llm_response_reasoning: str | None = None """LLM响应推理内容""" - llm_response_model: Optional[str] = None + llm_response_model: str | None = None """LLM响应模型名称""" - llm_response_tool_call: Optional[List[ToolCall]] = None + llm_response_tool_call: list[ToolCall] | None = None """LLM使用的工具调用""" - action_usage: Optional[List[str]] = None + action_usage: list[str] | None = None """使用的Action""" - additional_data: Dict[Any, Any] = field(default_factory=dict) + additional_data: dict[Any, Any] = field(default_factory=dict) """附加数据,可以存储额外信息""" def __post_init__(self): diff --git a/src/plugin_system/base/config_types.py b/src/plugin_system/base/config_types.py index 752b33453..9dc9b58eb 100644 --- a/src/plugin_system/base/config_types.py +++ b/src/plugin_system/base/config_types.py @@ -2,8 +2,8 @@ 插件系统配置类型定义 """ -from typing import Any, Optional, List from dataclasses import dataclass, field +from typing import Any @dataclass @@ -13,6 +13,6 @@ class ConfigField: type: type # 字段类型 default: Any # 默认值 description: str # 字段描述 - example: Optional[str] = None # 示例值 + example: str | None = None # 示例值 required: bool = False # 是否必需 - choices: Optional[List[Any]] = field(default_factory=list) # 可选值列表 + choices: list[Any] | None = field(default_factory=list) # 可选值列表 diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index a61b8e04c..8cc3312db 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -1,11 +1,12 @@ -from abc import ABC, abstractmethod -from typing import Dict, List, Any, Union -import os -import toml -import orjson -import shutil import datetime +import os +import shutil +from abc import ABC, abstractmethod from pathlib import Path +from typing import Any + +import orjson +import toml from src.common.logger import get_logger from src.config.config import CONFIG_DIR @@ -38,12 +39,12 @@ class PluginBase(ABC): @property @abstractmethod - def dependencies(self) -> List[str]: + def dependencies(self) -> list[str]: return [] # 依赖的其他插件 @property @abstractmethod - def python_dependencies(self) -> List[Union[str, PythonDependency]]: + def python_dependencies(self) -> list[str | PythonDependency]: return [] # Python包依赖,支持字符串列表或PythonDependency对象列表 @property @@ -53,15 +54,15 @@ class PluginBase(ABC): # manifest文件相关 manifest_file_name: str = "_manifest.json" # manifest文件名 - manifest_data: Dict[str, Any] = {} # manifest数据 + manifest_data: dict[str, Any] = {} # manifest数据 # 配置定义 @property @abstractmethod - def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]: + def config_schema(self) -> dict[str, dict[str, ConfigField] | str]: return {} - config_section_descriptions: Dict[str, str] = {} + config_section_descriptions: dict[str, str] = {} def __init__(self, plugin_dir: str): """初始化插件 @@ -69,7 +70,7 @@ class PluginBase(ABC): Args: plugin_dir: 插件目录路径,由插件管理器传递 """ - self.config: Dict[str, Any] = {} # 插件配置 + self.config: dict[str, Any] = {} # 插件配置 self.plugin_dir = plugin_dir # 插件目录路径 self.log_prefix = f"[Plugin:{self.plugin_name}]" self._is_enabled = self.enable_plugin # 从插件定义中获取默认启用状态 @@ -144,7 +145,7 @@ class PluginBase(ABC): raise FileNotFoundError(error_msg) try: - with open(manifest_path, "r", encoding="utf-8") as f: + with open(manifest_path, encoding="utf-8") as f: self.manifest_data = orjson.loads(f.read()) logger.debug(f"{self.log_prefix} 成功加载manifest文件: {manifest_path}") @@ -155,8 +156,8 @@ class PluginBase(ABC): except orjson.JSONDecodeError as e: error_msg = f"{self.log_prefix} manifest文件格式错误: {e}" logger.error(error_msg) - raise ValueError(error_msg) # noqa - except IOError as e: + raise ValueError(error_msg) + except OSError as e: error_msg = f"{self.log_prefix} 读取manifest文件失败: {e}" logger.error(error_msg) raise IOError(error_msg) # noqa @@ -266,7 +267,7 @@ class PluginBase(ABC): with open(config_file_path, "w", encoding="utf-8") as f: f.write(toml_str) logger.info(f"{self.log_prefix} 已生成默认配置文件: {config_file_path}") - except IOError as e: + except OSError as e: logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True) def _backup_config_file(self, config_file_path: str) -> str: @@ -288,13 +289,13 @@ class PluginBase(ABC): return "" def _synchronize_config( - self, schema_config: Dict[str, Any], user_config: Dict[str, Any] - ) -> tuple[Dict[str, Any], bool]: + self, schema_config: dict[str, Any], user_config: dict[str, Any] + ) -> tuple[dict[str, Any], bool]: """递归地将用户配置与 schema 同步,返回同步后的配置和是否发生变化的标志""" changed = False # 内部递归函数 - def _sync_dicts(schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, Any]: + def _sync_dicts(schema_dict: dict[str, Any], user_dict: dict[str, Any], parent_key: str = "") -> dict[str, Any]: nonlocal changed synced_dict = schema_dict.copy() @@ -326,7 +327,7 @@ class PluginBase(ABC): final_config = _sync_dicts(schema_config, user_config) return final_config, changed - def _generate_config_from_schema(self) -> Dict[str, Any]: + def _generate_config_from_schema(self) -> dict[str, Any]: # sourcery skip: dict-comprehension """根据schema生成配置数据结构(不写入文件)""" if not self.config_schema: @@ -348,7 +349,7 @@ class PluginBase(ABC): return config_data - def _save_config_to_file(self, config_data: Dict[str, Any], config_file_path: str): + def _save_config_to_file(self, config_data: dict[str, Any], config_file_path: str): """将配置数据保存为TOML文件(包含注释)""" if not self.config_schema: logger.debug(f"{self.log_prefix} 插件未定义config_schema,不生成配置文件") @@ -410,7 +411,7 @@ class PluginBase(ABC): with open(config_file_path, "w", encoding="utf-8") as f: f.write(toml_str) logger.info(f"{self.log_prefix} 配置文件已保存: {config_file_path}") - except IOError as e: + except OSError as e: logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True) def _load_plugin_config(self): # sourcery skip: extract-method @@ -456,7 +457,7 @@ class PluginBase(ABC): return try: - with open(user_config_path, "r", encoding="utf-8") as f: + with open(user_config_path, encoding="utf-8") as f: user_config = toml.load(f) or {} except Exception as e: logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True) @@ -520,7 +521,7 @@ class PluginBase(ABC): return current - def _normalize_python_dependencies(self, dependencies: Any) -> List[PythonDependency]: + def _normalize_python_dependencies(self, dependencies: Any) -> list[PythonDependency]: """将依赖列表标准化为PythonDependency对象""" from packaging.requirements import Requirement @@ -549,7 +550,7 @@ class PluginBase(ABC): return normalized - def _check_python_dependencies(self, dependencies: List[PythonDependency]) -> bool: + def _check_python_dependencies(self, dependencies: list[PythonDependency]) -> bool: """检查Python依赖并尝试自动安装""" if not dependencies: logger.info(f"{self.log_prefix} 无Python依赖需要检查") diff --git a/src/plugin_system/base/plus_command.py b/src/plugin_system/base/plus_command.py index a64866806..1319560b6 100644 --- a/src/plugin_system/base/plus_command.py +++ b/src/plugin_system/base/plus_command.py @@ -3,17 +3,16 @@ 提供更简单易用的命令处理方式,无需手写正则表达式 """ -from abc import ABC, abstractmethod -from typing import Tuple, Optional, List import re +from abc import ABC, abstractmethod -from src.common.logger import get_logger -from src.plugin_system.base.component_types import PlusCommandInfo, ComponentType, ChatType from src.chat.message_receive.message import MessageRecv -from src.plugin_system.apis import send_api -from src.plugin_system.base.command_args import CommandArgs -from src.plugin_system.base.base_command import BaseCommand +from src.common.logger import get_logger from src.config.config import global_config +from src.plugin_system.apis import send_api +from src.plugin_system.base.base_command import BaseCommand +from src.plugin_system.base.command_args import CommandArgs +from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo logger = get_logger("plus_command") @@ -39,7 +38,7 @@ class PlusCommand(ABC): command_description: str = "" """命令描述""" - command_aliases: List[str] = [] + command_aliases: list[str] = [] """命令别名列表,如 ['say', 'repeat']""" priority: int = 0 @@ -51,7 +50,7 @@ class PlusCommand(ABC): intercept_message: bool = False """是否拦截消息,不进行后续处理""" - def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): + def __init__(self, message: MessageRecv, plugin_config: dict | None = None): """初始化命令组件 Args: @@ -172,7 +171,7 @@ class PlusCommand(ABC): return False @abstractmethod - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: """执行命令的抽象方法,子类必须实现 Args: @@ -341,7 +340,7 @@ class PlusCommandAdapter(BaseCommand): 将PlusCommand适配到现有的插件系统,继承BaseCommand """ - def __init__(self, plus_command_class, message: MessageRecv, plugin_config: Optional[dict] = None): + def __init__(self, plus_command_class, message: MessageRecv, plugin_config: dict | None = None): """初始化适配器 Args: @@ -363,7 +362,7 @@ class PlusCommandAdapter(BaseCommand): # 创建PlusCommand实例 self.plus_command = plus_command_class(message, plugin_config) - async def execute(self) -> Tuple[bool, Optional[str], bool]: + async def execute(self) -> tuple[bool, str | None, bool]: """执行命令 Returns: @@ -382,7 +381,7 @@ class PlusCommandAdapter(BaseCommand): return await self.plus_command.execute(self.plus_command.args) except Exception as e: logger.error(f"执行命令时出错: {e}", exc_info=True) - return False, f"命令执行出错: {str(e)}", self.intercept_message + return False, f"命令执行出错: {e!s}", self.intercept_message def create_plus_command_adapter(plus_command_class): @@ -401,13 +400,13 @@ def create_plus_command_adapter(plus_command_class): command_pattern = plus_command_class._generate_command_pattern() chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL) - def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): + def __init__(self, message: MessageRecv, plugin_config: dict | None = None): super().__init__(message, plugin_config) self.plus_command = plus_command_class(message, plugin_config) self.priority = getattr(plus_command_class, "priority", 0) self.intercept_message = getattr(plus_command_class, "intercept_message", False) - async def execute(self) -> Tuple[bool, Optional[str], bool]: + async def execute(self) -> tuple[bool, str | None, bool]: """执行命令""" # 从BaseCommand的正则匹配结果中提取参数 args_text = "" @@ -429,7 +428,7 @@ def create_plus_command_adapter(plus_command_class): return await self.plus_command.execute(command_args) except Exception as e: logger.error(f"执行命令时出错: {e}", exc_info=True) - return False, f"命令执行出错: {str(e)}", self.intercept_message + return False, f"命令执行出错: {e!s}", self.intercept_message return AdapterClass diff --git a/src/plugin_system/core/__init__.py b/src/plugin_system/core/__init__.py index 46aa5a96c..4e43fba11 100644 --- a/src/plugin_system/core/__init__.py +++ b/src/plugin_system/core/__init__.py @@ -4,14 +4,14 @@ 提供插件的加载、注册和管理功能 """ -from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager +from src.plugin_system.core.plugin_manager import plugin_manager __all__ = [ - "plugin_manager", "component_registry", "event_manager", "global_announcement_manager", + "plugin_manager", ] diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 9c82553f8..878b6c465 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -1,27 +1,26 @@ -from pathlib import Path import re - -from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type +from pathlib import Path +from re import Pattern +from typing import Any, Optional, Union from src.common.logger import get_logger -from src.plugin_system.base.component_types import ( - ComponentInfo, - ActionInfo, - ToolInfo, - CommandInfo, - PlusCommandInfo, - EventHandlerInfo, - ChatterInfo, - PluginInfo, - ComponentType, -) - -from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_action import BaseAction -from src.plugin_system.base.base_tool import BaseTool -from src.plugin_system.base.base_events_handler import BaseEventHandler -from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.base.base_chatter import BaseChatter +from src.plugin_system.base.base_command import BaseCommand +from src.plugin_system.base.base_events_handler import BaseEventHandler +from src.plugin_system.base.base_tool import BaseTool +from src.plugin_system.base.component_types import ( + ActionInfo, + ChatterInfo, + CommandInfo, + ComponentInfo, + ComponentType, + EventHandlerInfo, + PluginInfo, + PlusCommandInfo, + ToolInfo, +) +from src.plugin_system.base.plus_command import PlusCommand logger = get_logger("component_registry") @@ -34,46 +33,46 @@ class ComponentRegistry: def __init__(self): # 命名空间式组件名构成法 f"{component_type}.{component_name}" - self._components: Dict[str, "ComponentInfo"] = {} + self._components: dict[str, "ComponentInfo"] = {} """组件注册表 命名空间式组件名 -> 组件信息""" - self._components_by_type: Dict["ComponentType", Dict[str, "ComponentInfo"]] = { + self._components_by_type: dict["ComponentType", dict[str, "ComponentInfo"]] = { types: {} for types in ComponentType } """类型 -> 组件原名称 -> 组件信息""" - self._components_classes: Dict[ - str, Type[Union["BaseCommand", "BaseAction", "BaseTool", "BaseEventHandler", "PlusCommand", "BaseChatter"]] + self._components_classes: dict[ + str, type["BaseCommand" | "BaseAction" | "BaseTool" | "BaseEventHandler" | "PlusCommand" | "BaseChatter"] ] = {} """命名空间式组件名 -> 组件类""" # 插件注册表 - self._plugins: Dict[str, "PluginInfo"] = {} + self._plugins: dict[str, "PluginInfo"] = {} """插件名 -> 插件信息""" # Action特定注册表 - self._action_registry: Dict[str, Type["BaseAction"]] = {} + self._action_registry: dict[str, type["BaseAction"]] = {} """Action注册表 action名 -> action类""" - self._default_actions: Dict[str, "ActionInfo"] = {} + self._default_actions: dict[str, "ActionInfo"] = {} """默认动作集,即启用的Action集,用于重置ActionManager状态""" # Command特定注册表 - self._command_registry: Dict[str, Type["BaseCommand"]] = {} + self._command_registry: dict[str, type["BaseCommand"]] = {} """Command类注册表 command名 -> command类""" - self._command_patterns: Dict[Pattern, str] = {} + self._command_patterns: dict[Pattern, str] = {} """编译后的正则 -> command名""" # 工具特定注册表 - self._tool_registry: Dict[str, Type["BaseTool"]] = {} # 工具名 -> 工具类 - self._llm_available_tools: Dict[str, Type["BaseTool"]] = {} # llm可用的工具名 -> 工具类 + self._tool_registry: dict[str, type["BaseTool"]] = {} # 工具名 -> 工具类 + self._llm_available_tools: dict[str, type["BaseTool"]] = {} # llm可用的工具名 -> 工具类 # EventHandler特定注册表 - self._event_handler_registry: Dict[str, Type["BaseEventHandler"]] = {} + self._event_handler_registry: dict[str, type["BaseEventHandler"]] = {} """event_handler名 -> event_handler类""" - self._enabled_event_handlers: Dict[str, Type["BaseEventHandler"]] = {} + self._enabled_event_handlers: dict[str, type["BaseEventHandler"]] = {} """启用的事件处理器 event_handler名 -> event_handler类""" - self._chatter_registry: Dict[str, Type["BaseChatter"]] = {} + self._chatter_registry: dict[str, type["BaseChatter"]] = {} """chatter名 -> chatter类""" - self._enabled_chatter_registry: Dict[str, Type["BaseChatter"]] = {} + self._enabled_chatter_registry: dict[str, type["BaseChatter"]] = {} """启用的chatter名 -> chatter类""" logger.info("组件注册中心初始化完成") @@ -101,7 +100,7 @@ class ComponentRegistry: def register_component( self, component_info: ComponentInfo, - component_class: Type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]], + component_class: type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]], ) -> bool: """注册组件 @@ -174,7 +173,7 @@ class ComponentRegistry: ) return True - def _register_action_component(self, action_info: "ActionInfo", action_class: Type["BaseAction"]) -> bool: + def _register_action_component(self, action_info: "ActionInfo", action_class: type["BaseAction"]) -> bool: """注册Action组件到Action特定注册表""" if not (action_name := action_info.name): logger.error(f"Action组件 {action_class.__name__} 必须指定名称") @@ -194,7 +193,7 @@ class ComponentRegistry: return True - def _register_command_component(self, command_info: "CommandInfo", command_class: Type["BaseCommand"]) -> bool: + def _register_command_component(self, command_info: "CommandInfo", command_class: type["BaseCommand"]) -> bool: """注册Command组件到Command特定注册表""" if not (command_name := command_info.name): logger.error(f"Command组件 {command_class.__name__} 必须指定名称") @@ -221,7 +220,7 @@ class ComponentRegistry: return True def _register_plus_command_component( - self, plus_command_info: "PlusCommandInfo", plus_command_class: Type["PlusCommand"] + self, plus_command_info: "PlusCommandInfo", plus_command_class: type["PlusCommand"] ) -> bool: """注册PlusCommand组件到特定注册表""" plus_command_name = plus_command_info.name @@ -235,7 +234,7 @@ class ComponentRegistry: # 创建专门的PlusCommand注册表(如果还没有) if not hasattr(self, "_plus_command_registry"): - self._plus_command_registry: Dict[str, Type["PlusCommand"]] = {} + self._plus_command_registry: dict[str, type["PlusCommand"]] = {} plus_command_class.plugin_name = plus_command_info.plugin_name # 设置插件配置 @@ -245,7 +244,7 @@ class ComponentRegistry: logger.debug(f"已注册PlusCommand组件: {plus_command_name}") return True - def _register_tool_component(self, tool_info: "ToolInfo", tool_class: Type["BaseTool"]) -> bool: + def _register_tool_component(self, tool_info: "ToolInfo", tool_class: type["BaseTool"]) -> bool: """注册Tool组件到Tool特定注册表""" tool_name = tool_info.name @@ -261,7 +260,7 @@ class ComponentRegistry: return True def _register_event_handler_component( - self, handler_info: "EventHandlerInfo", handler_class: Type["BaseEventHandler"] + self, handler_info: "EventHandlerInfo", handler_class: type["BaseEventHandler"] ) -> bool: if not (handler_name := handler_info.name): logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称") @@ -287,7 +286,7 @@ class ComponentRegistry: handler_class, self.get_plugin_config(handler_info.plugin_name) or {} ) - def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: Type["BaseChatter"]) -> bool: + def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: type["BaseChatter"]) -> bool: """注册Chatter组件到Chatter特定注册表""" chatter_name = chatter_info.name @@ -532,7 +531,7 @@ class ComponentRegistry: self, component_name: str, component_type: Optional["ComponentType"] = None, - ) -> Optional[Union[Type["BaseCommand"], Type["BaseAction"], Type["BaseEventHandler"], Type["BaseTool"]]]: + ) -> type["BaseCommand"] | type["BaseAction"] | type["BaseEventHandler"] | type["BaseTool"] | None: """获取组件类,支持自动命名空间解析 Args: @@ -574,18 +573,18 @@ class ComponentRegistry: # 4. 都没找到 return None - def get_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]: + def get_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]: """获取指定类型的所有组件""" return self._components_by_type.get(component_type, {}).copy() - def get_enabled_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]: + def get_enabled_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]: """获取指定类型的所有启用组件""" components = self.get_components_by_type(component_type) return {name: info for name, info in components.items() if info.enabled} # === Action特定查询方法 === - def get_action_registry(self) -> Dict[str, Type["BaseAction"]]: + def get_action_registry(self) -> dict[str, type["BaseAction"]]: """获取Action注册表""" return self._action_registry.copy() @@ -594,13 +593,13 @@ class ComponentRegistry: info = self.get_component_info(action_name, ComponentType.ACTION) return info if isinstance(info, ActionInfo) else None - def get_default_actions(self) -> Dict[str, ActionInfo]: + def get_default_actions(self) -> dict[str, ActionInfo]: """获取默认动作集""" return self._default_actions.copy() # === Command特定查询方法 === - def get_command_registry(self) -> Dict[str, Type["BaseCommand"]]: + def get_command_registry(self) -> dict[str, type["BaseCommand"]]: """获取Command注册表""" return self._command_registry.copy() @@ -609,11 +608,11 @@ class ComponentRegistry: info = self.get_component_info(command_name, ComponentType.COMMAND) return info if isinstance(info, CommandInfo) else None - def get_command_patterns(self) -> Dict[Pattern, str]: + def get_command_patterns(self) -> dict[Pattern, str]: """获取Command模式注册表""" return self._command_patterns.copy() - def find_command_by_text(self, text: str) -> Optional[Tuple[Type["BaseCommand"], dict, "CommandInfo"]]: + def find_command_by_text(self, text: str) -> tuple[type["BaseCommand"], dict, "CommandInfo"] | None: # sourcery skip: use-named-expression, use-next """根据文本查找匹配的命令 @@ -640,11 +639,11 @@ class ComponentRegistry: return None # === Tool 特定查询方法 === - def get_tool_registry(self) -> Dict[str, Type["BaseTool"]]: + def get_tool_registry(self) -> dict[str, type["BaseTool"]]: """获取Tool注册表""" return self._tool_registry.copy() - def get_llm_available_tools(self) -> Dict[str, Type["BaseTool"]]: + def get_llm_available_tools(self) -> dict[str, type["BaseTool"]]: """获取LLM可用的Tool列表""" return self._llm_available_tools.copy() @@ -661,10 +660,10 @@ class ComponentRegistry: return info if isinstance(info, ToolInfo) else None # === PlusCommand 特定查询方法 === - def get_plus_command_registry(self) -> Dict[str, Type["PlusCommand"]]: + def get_plus_command_registry(self) -> dict[str, type["PlusCommand"]]: """获取PlusCommand注册表""" if not hasattr(self, "_plus_command_registry"): - self._plus_command_registry: Dict[str, Type[PlusCommand]] = {} + self._plus_command_registry: dict[str, type[PlusCommand]] = {} return self._plus_command_registry.copy() def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]: @@ -681,7 +680,7 @@ class ComponentRegistry: # === EventHandler 特定查询方法 === - def get_event_handler_registry(self) -> Dict[str, Type["BaseEventHandler"]]: + def get_event_handler_registry(self) -> dict[str, type["BaseEventHandler"]]: """获取事件处理器注册表""" return self._event_handler_registry.copy() @@ -690,21 +689,21 @@ class ComponentRegistry: info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER) return info if isinstance(info, EventHandlerInfo) else None - def get_enabled_event_handlers(self) -> Dict[str, Type["BaseEventHandler"]]: + def get_enabled_event_handlers(self) -> dict[str, type["BaseEventHandler"]]: """获取启用的事件处理器""" return self._enabled_event_handlers.copy() # === Chatter 特定查询方法 === - def get_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]: + def get_chatter_registry(self) -> dict[str, type["BaseChatter"]]: """获取Chatter注册表""" if not hasattr(self, "_chatter_registry"): - self._chatter_registry: Dict[str, Type[BaseChatter]] = {} + self._chatter_registry: dict[str, type[BaseChatter]] = {} return self._chatter_registry.copy() - def get_enabled_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]: + def get_enabled_chatter_registry(self) -> dict[str, type["BaseChatter"]]: """获取启用的Chatter注册表""" if not hasattr(self, "_enabled_chatter_registry"): - self._enabled_chatter_registry: Dict[str, Type[BaseChatter]] = {} + self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {} return self._enabled_chatter_registry.copy() def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]: @@ -718,7 +717,7 @@ class ComponentRegistry: """获取插件信息""" return self._plugins.get(plugin_name) - def get_all_plugins(self) -> Dict[str, "PluginInfo"]: + def get_all_plugins(self) -> dict[str, "PluginInfo"]: """获取所有插件""" return self._plugins.copy() @@ -726,7 +725,7 @@ class ComponentRegistry: # """获取所有启用的插件""" # return {name: info for name, info in self._plugins.items() if info.enabled} - def get_plugin_components(self, plugin_name: str) -> List["ComponentInfo"]: + def get_plugin_components(self, plugin_name: str) -> list["ComponentInfo"]: """获取插件的所有组件""" plugin_info = self.get_plugin_info(plugin_name) return plugin_info.components if plugin_info else [] @@ -753,7 +752,7 @@ class ComponentRegistry: config_path = Path("config") / "plugins" / plugin_name / "config.toml" if config_path.exists(): - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: config_data = toml.load(f) logger.debug(f"从配置文件读取插件 {plugin_name} 的配置") return config_data @@ -762,7 +761,7 @@ class ComponentRegistry: return {} - def get_registry_stats(self) -> Dict[str, Any]: + def get_registry_stats(self) -> dict[str, Any]: """获取注册中心统计信息""" action_components: int = 0 command_components: int = 0 diff --git a/src/plugin_system/core/event_manager.py b/src/plugin_system/core/event_manager.py index dac75b88f..8a7c7d66c 100644 --- a/src/plugin_system/core/event_manager.py +++ b/src/plugin_system/core/event_manager.py @@ -3,8 +3,8 @@ 提供统一的事件注册、管理和触发接口 """ -from typing import Dict, Type, List, Optional, Any, Union from threading import Lock +from typing import Any, Optional from src.common.logger import get_logger from src.plugin_system import BaseEventHandler @@ -37,17 +37,17 @@ class EventManager: if self._initialized: return - self._events: Dict[str, BaseEvent] = {} - self._event_handlers: Dict[str, Type[BaseEventHandler]] = {} - self._pending_subscriptions: Dict[str, List[str]] = {} # 缓存失败的订阅 + self._events: dict[str, BaseEvent] = {} + self._event_handlers: dict[str, type[BaseEventHandler]] = {} + self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅 self._initialized = True logger.info("EventManager 单例初始化完成") def register_event( self, - event_name: Union[EventType, str], - allowed_subscribers: List[str] = None, - allowed_triggers: List[str] = None, + event_name: EventType | str, + allowed_subscribers: list[str] = None, + allowed_triggers: list[str] = None, ) -> bool: """注册一个新的事件 @@ -75,7 +75,7 @@ class EventManager: return True - def get_event(self, event_name: Union[EventType, str]) -> Optional[BaseEvent]: + def get_event(self, event_name: EventType | str) -> BaseEvent | None: """获取指定事件实例 Args: @@ -86,7 +86,7 @@ class EventManager: """ return self._events.get(event_name) - def get_all_events(self) -> Dict[str, BaseEvent]: + def get_all_events(self) -> dict[str, BaseEvent]: """获取所有已注册的事件 Returns: @@ -94,7 +94,7 @@ class EventManager: """ return self._events.copy() - def get_enabled_events(self) -> Dict[str, BaseEvent]: + def get_enabled_events(self) -> dict[str, BaseEvent]: """获取所有已启用的事件 Returns: @@ -102,7 +102,7 @@ class EventManager: """ return {name: event for name, event in self._events.items() if event.enabled} - def get_disabled_events(self) -> Dict[str, BaseEvent]: + def get_disabled_events(self) -> dict[str, BaseEvent]: """获取所有已禁用的事件 Returns: @@ -110,7 +110,7 @@ class EventManager: """ return {name: event for name, event in self._events.items() if not event.enabled} - def enable_event(self, event_name: Union[EventType, str]) -> bool: + def enable_event(self, event_name: EventType | str) -> bool: """启用指定事件 Args: @@ -128,7 +128,7 @@ class EventManager: logger.info(f"事件 {event_name} 已启用") return True - def disable_event(self, event_name: Union[EventType, str]) -> bool: + def disable_event(self, event_name: EventType | str) -> bool: """禁用指定事件 Args: @@ -146,9 +146,7 @@ class EventManager: logger.info(f"事件 {event_name} 已禁用") return True - def register_event_handler( - self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None - ) -> bool: + def register_event_handler(self, handler_class: type[BaseEventHandler], plugin_config: dict | None = None) -> bool: """注册事件处理器 Args: @@ -190,7 +188,7 @@ class EventManager: logger.info(f"事件处理器 {handler_name} 注册成功") return True - def get_event_handler(self, handler_name: str) -> Optional[Type[BaseEventHandler]]: + def get_event_handler(self, handler_name: str) -> type[BaseEventHandler] | None: """获取指定事件处理器实例 Args: @@ -209,7 +207,7 @@ class EventManager: """ return self._event_handlers.copy() - def subscribe_handler_to_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool: + def subscribe_handler_to_event(self, handler_name: str, event_name: EventType | str) -> bool: """订阅事件处理器到指定事件 Args: @@ -246,7 +244,7 @@ class EventManager: logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成") return True - def unsubscribe_handler_from_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool: + def unsubscribe_handler_from_event(self, handler_name: str, event_name: EventType | str) -> bool: """从指定事件取消订阅事件处理器 Args: @@ -276,7 +274,7 @@ class EventManager: return removed - def get_event_subscribers(self, event_name: Union[EventType, str]) -> Dict[str, BaseEventHandler]: + def get_event_subscribers(self, event_name: EventType | str) -> dict[str, BaseEventHandler]: """获取订阅指定事件的所有事件处理器 Args: @@ -292,8 +290,8 @@ class EventManager: return {handler.handler_name: handler for handler in event.subscribers} async def trigger_event( - self, event_name: Union[EventType, str], permission_group: Optional[str] = "", **kwargs - ) -> Optional[HandlerResultsCollection]: + self, event_name: EventType | str, permission_group: str | None = "", **kwargs + ) -> HandlerResultsCollection | None: """触发指定事件 Args: @@ -345,7 +343,7 @@ class EventManager: self._event_handlers.clear() logger.info("所有事件和处理器已清除") - def get_event_summary(self) -> Dict[str, Any]: + def get_event_summary(self) -> dict[str, Any]: """获取事件系统摘要 Returns: @@ -364,7 +362,7 @@ class EventManager: "pending_subscriptions": len(self._pending_subscriptions), } - def _process_pending_subscriptions(self, event_name: Union[EventType, str]) -> None: + def _process_pending_subscriptions(self, event_name: EventType | str) -> None: """处理指定事件的缓存订阅 Args: diff --git a/src/plugin_system/core/global_announcement_manager.py b/src/plugin_system/core/global_announcement_manager.py index 05abf0b79..1dca4a53a 100644 --- a/src/plugin_system/core/global_announcement_manager.py +++ b/src/plugin_system/core/global_announcement_manager.py @@ -1,5 +1,3 @@ -from typing import List, Dict - from src.common.logger import get_logger logger = get_logger("global_announcement_manager") @@ -8,13 +6,13 @@ logger = get_logger("global_announcement_manager") class GlobalAnnouncementManager: def __init__(self) -> None: # 用户禁用的动作,chat_id -> [action_name] - self._user_disabled_actions: Dict[str, List[str]] = {} + self._user_disabled_actions: dict[str, list[str]] = {} # 用户禁用的命令,chat_id -> [command_name] - self._user_disabled_commands: Dict[str, List[str]] = {} + self._user_disabled_commands: dict[str, list[str]] = {} # 用户禁用的事件处理器,chat_id -> [handler_name] - self._user_disabled_event_handlers: Dict[str, List[str]] = {} + self._user_disabled_event_handlers: dict[str, list[str]] = {} # 用户禁用的工具,chat_id -> [tool_name] - self._user_disabled_tools: Dict[str, List[str]] = {} + self._user_disabled_tools: dict[str, list[str]] = {} def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool: """禁用特定聊天的某个动作""" @@ -100,19 +98,19 @@ class GlobalAnnouncementManager: return False return False - def get_disabled_chat_actions(self, chat_id: str) -> List[str]: + def get_disabled_chat_actions(self, chat_id: str) -> list[str]: """获取特定聊天禁用的所有动作""" return self._user_disabled_actions.get(chat_id, []).copy() - def get_disabled_chat_commands(self, chat_id: str) -> List[str]: + def get_disabled_chat_commands(self, chat_id: str) -> list[str]: """获取特定聊天禁用的所有命令""" return self._user_disabled_commands.get(chat_id, []).copy() - def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]: + def get_disabled_chat_event_handlers(self, chat_id: str) -> list[str]: """获取特定聊天禁用的所有事件处理器""" return self._user_disabled_event_handlers.get(chat_id, []).copy() - def get_disabled_chat_tools(self, chat_id: str) -> List[str]: + def get_disabled_chat_tools(self, chat_id: str) -> list[str]: """获取特定聊天禁用的所有工具""" return self._user_disabled_tools.get(chat_id, []).copy() diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index 0bb22afdf..99f00340c 100644 --- a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -4,16 +4,16 @@ 这个模块提供了权限系统的核心实现,包括权限检查、权限节点管理、用户权限管理等功能。 """ -from typing import List, Set, Tuple -from sqlalchemy.ext.asyncio import async_sessionmaker -from sqlalchemy.exc import IntegrityError, SQLAlchemyError from datetime import datetime -from sqlalchemy import select, delete +from sqlalchemy import delete, select +from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.ext.asyncio import async_sessionmaker + +from src.common.database.sqlalchemy_models import PermissionNodes, UserPermissions, get_engine from src.common.logger import get_logger -from src.common.database.sqlalchemy_models import get_engine, PermissionNodes, UserPermissions -from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo from src.config.config import global_config +from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo logger = get_logger(__name__) @@ -24,7 +24,7 @@ class PermissionManager(IPermissionManager): def __init__(self): self.engine = None self.SessionLocal = None - self._master_users: Set[Tuple[str, str]] = set() + self._master_users: set[tuple[str, str]] = set() self._load_master_users() async def initialize(self): @@ -276,7 +276,7 @@ class PermissionManager(IPermissionManager): logger.error(f"撤销权限时发生未知错误: {e}") return False - async def get_user_permissions(self, user: UserInfo) -> List[str]: + async def get_user_permissions(self, user: UserInfo) -> list[str]: """ 获取用户拥有的所有权限节点 @@ -328,7 +328,7 @@ class PermissionManager(IPermissionManager): logger.error(f"获取用户权限时发生未知错误: {e}") return [] - async def get_all_permission_nodes(self) -> List[PermissionNode]: + async def get_all_permission_nodes(self) -> list[PermissionNode]: """ 获取所有已注册的权限节点 @@ -356,7 +356,7 @@ class PermissionManager(IPermissionManager): logger.error(f"获取所有权限节点时发生未知错误: {e}") return [] - async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: + async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]: """ 获取指定插件的所有权限节点 @@ -431,7 +431,7 @@ class PermissionManager(IPermissionManager): logger.error(f"删除插件权限时发生未知错误: {e}") return False - async def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]: + async def get_users_with_permission(self, permission_node: str) -> list[tuple[str, str]]: """ 获取拥有指定权限的所有用户 diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 2950101a9..046c05b4f 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -1,19 +1,17 @@ import asyncio +import importlib import os import traceback -import importlib - -from typing import Dict, List, Optional, Tuple, Type, Any -from importlib.util import spec_from_file_location, module_from_spec +from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path - +from typing import Any, Optional from src.common.logger import get_logger -from src.plugin_system.base.plugin_base import PluginBase from src.plugin_system.base.component_types import ComponentType +from src.plugin_system.base.plugin_base import PluginBase from src.plugin_system.utils.manifest_utils import VersionComparator -from .component_registry import component_registry +from .component_registry import component_registry logger = get_logger("plugin_manager") @@ -26,12 +24,12 @@ class PluginManager: """ def __init__(self): - self.plugin_directories: List[str] = [] # 插件根目录列表 - self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类 - self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径 + self.plugin_directories: list[str] = [] # 插件根目录列表 + self.plugin_classes: dict[str, type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类 + self.plugin_paths: dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径 - self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例 - self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息 + self.loaded_plugins: dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例 + self.failed_plugins: dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息 # 确保插件目录存在 self._ensure_plugin_directories() @@ -54,7 +52,7 @@ class PluginManager: # === 插件加载管理 === - def load_all_plugins(self) -> Tuple[int, int]: + def load_all_plugins(self) -> tuple[int, int]: """加载所有插件 Returns: @@ -87,7 +85,7 @@ class PluginManager: return total_registered, total_failed_registration - def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]: + def load_registered_plugin_classes(self, plugin_name: str) -> tuple[bool, int]: # sourcery skip: extract-duplicate-method, extract-method """ 加载已经注册的插件类 @@ -142,7 +140,7 @@ class PluginManager: except FileNotFoundError as e: # manifest文件缺失 - error_msg = f"缺少manifest文件: {str(e)}" + error_msg = f"缺少manifest文件: {e!s}" self.failed_plugins[plugin_name] = error_msg logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") return False, 1 @@ -150,14 +148,14 @@ class PluginManager: except ValueError as e: # manifest文件格式错误或验证失败 traceback.print_exc() - error_msg = f"manifest验证失败: {str(e)}" + error_msg = f"manifest验证失败: {e!s}" self.failed_plugins[plugin_name] = error_msg logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") return False, 1 except Exception as e: # 其他错误 - error_msg = f"未知错误: {str(e)}" + error_msg = f"未知错误: {e!s}" self.failed_plugins[plugin_name] = error_msg logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") logger.debug("详细错误信息: ", exc_info=True) @@ -192,7 +190,7 @@ class PluginManager: logger.debug(f"插件 {plugin_name} 重载成功") return True - def rescan_plugin_directory(self) -> Tuple[int, int]: + def rescan_plugin_directory(self) -> tuple[int, int]: """ 重新扫描插件根目录 """ @@ -220,7 +218,7 @@ class PluginManager: return self.loaded_plugins.get(plugin_name) # === 查询方法 === - def list_loaded_plugins(self) -> List[str]: + def list_loaded_plugins(self) -> list[str]: """ 列出所有当前加载的插件。 @@ -229,7 +227,7 @@ class PluginManager: """ return list(self.loaded_plugins.keys()) - def list_registered_plugins(self) -> List[str]: + def list_registered_plugins(self) -> list[str]: """ 列出所有已注册的插件类。 @@ -238,7 +236,7 @@ class PluginManager: """ return list(self.plugin_classes.keys()) - def get_plugin_path(self, plugin_name: str) -> Optional[str]: + def get_plugin_path(self, plugin_name: str) -> str | None: """ 获取指定插件的路径。 @@ -329,7 +327,7 @@ class PluginManager: # == 兼容性检查 == @staticmethod - def _check_plugin_version_compatibility(plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: + def _check_plugin_version_compatibility(plugin_name: str, manifest_data: dict[str, Any]) -> tuple[bool, str]: """检查插件版本兼容性 Args: @@ -569,7 +567,7 @@ class PluginManager: return True except Exception as e: - logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}", exc_info=True) + logger.error(f"❌ 插件卸载失败: {plugin_name} - {e!s}", exc_info=True) return False def reload_plugin(self, plugin_name: str) -> bool: @@ -606,7 +604,7 @@ class PluginManager: return False except Exception as e: - logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}", exc_info=True) + logger.error(f"❌ 插件重载失败: {plugin_name} - {e!s}", exc_info=True) return False def force_reload_plugin(self, plugin_name: str) -> bool: diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index e666e32d4..17fe46ddf 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -1,16 +1,17 @@ +import inspect import time -from typing import List, Dict, Tuple, Optional, Any +from typing import Any + +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.common.cache_manager import tool_cache +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.payload_content import ToolCall +from src.llm_models.utils_model import LLMRequest from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.core.global_announcement_manager import global_announcement_manager -from src.llm_models.utils_model import LLMRequest -from src.llm_models.payload_content import ToolCall -from src.config.config import global_config, model_config -from src.chat.utils.prompt import Prompt, global_prompt_manager -import inspect -from src.chat.message_receive.chat_stream import get_chat_manager -from src.common.logger import get_logger -from src.common.cache_manager import tool_cache logger = get_logger("tool_use") @@ -56,14 +57,14 @@ class ToolExecutor: self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") # 二步工具调用状态管理 - self._pending_step_two_tools: Dict[str, Dict[str, Any]] = {} + self._pending_step_two_tools: dict[str, dict[str, Any]] = {} """待处理的第二步工具调用,格式为 {tool_name: step_two_definition}""" logger.info(f"{self.log_prefix}工具执行器初始化完成") async def execute_from_chat_message( self, target_message: str, chat_history: str, sender: str, return_details: bool = False - ) -> Tuple[List[Dict[str, Any]], List[str], str]: + ) -> tuple[list[dict[str, Any]], list[str], str]: """从聊天消息执行工具 Args: @@ -113,7 +114,7 @@ class ToolExecutor: else: return tool_results, [], "" - def _get_tool_definitions(self) -> List[Dict[str, Any]]: + def _get_tool_definitions(self) -> list[dict[str, Any]]: all_tools = get_llm_available_tool_definitions() user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) @@ -129,7 +130,7 @@ class ToolExecutor: return tool_definitions - async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]: + async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]: """执行工具调用 Args: @@ -138,7 +139,7 @@ class ToolExecutor: Returns: Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表) """ - tool_results: List[Dict[str, Any]] = [] + tool_results: list[dict[str, Any]] = [] used_tools = [] if not tool_calls: @@ -192,7 +193,7 @@ class ToolExecutor: error_info = { "type": "tool_error", "id": f"tool_error_{time.time()}", - "content": f"工具{tool_name}执行失败: {str(e)}", + "content": f"工具{tool_name}执行失败: {e!s}", "tool_name": tool_name, "timestamp": time.time(), } @@ -201,8 +202,8 @@ class ToolExecutor: return tool_results, used_tools async def execute_tool_call( - self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None - ) -> Optional[Dict[str, Any]]: + self, tool_call: ToolCall, tool_instance: BaseTool | None = None + ) -> dict[str, Any] | None: """执行单个工具调用,并处理缓存""" function_args = tool_call.args or {} @@ -256,8 +257,8 @@ class ToolExecutor: return result async def _original_execute_tool_call( - self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None - ) -> Optional[Dict[str, Any]]: + self, tool_call: ToolCall, tool_instance: BaseTool | None = None + ) -> dict[str, Any] | None: """执行单个工具调用的原始逻辑""" try: function_name = tool_call.func_name @@ -323,10 +324,10 @@ class ToolExecutor: logger.warning(f"{self.log_prefix}工具 {function_name} 返回空结果") return None except Exception as e: - logger.error(f"执行工具调用时发生错误: {str(e)}") + logger.error(f"执行工具调用时发生错误: {e!s}") raise e - async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]: + async def execute_specific_tool_simple(self, tool_name: str, tool_args: dict) -> dict | None: """直接执行指定工具 Args: diff --git a/src/plugin_system/utils/dependency_alias.py b/src/plugin_system/utils/dependency_alias.py index 7a2aa1d80..a7e478d76 100644 --- a/src/plugin_system/utils/dependency_alias.py +++ b/src/plugin_system/utils/dependency_alias.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 本模块包含一个从Python包的“安装名”到其“导入名”的映射。 diff --git a/src/plugin_system/utils/dependency_config.py b/src/plugin_system/utils/dependency_config.py index b14f88b46..081d0216c 100644 --- a/src/plugin_system/utils/dependency_config.py +++ b/src/plugin_system/utils/dependency_config.py @@ -1,4 +1,3 @@ -from typing import Optional from src.common.logger import get_logger logger = get_logger("dependency_config") @@ -66,7 +65,7 @@ class DependencyConfig: # 全局配置实例 -_global_dependency_config: Optional[DependencyConfig] = None +_global_dependency_config: DependencyConfig | None = None def get_dependency_config() -> DependencyConfig: diff --git a/src/plugin_system/utils/dependency_manager.py b/src/plugin_system/utils/dependency_manager.py index 980f538cc..4d5e48a9d 100644 --- a/src/plugin_system/utils/dependency_manager.py +++ b/src/plugin_system/utils/dependency_manager.py @@ -1,8 +1,9 @@ -import subprocess -import sys import importlib import importlib.util -from typing import List, Tuple, Optional, Any +import subprocess +import sys +from typing import Any + from packaging import version from packaging.requirements import Requirement @@ -19,7 +20,7 @@ class DependencyManager: 负责检查和自动安装插件的Python包依赖 """ - def __init__(self, auto_install: bool = True, use_mirror: bool = False, mirror_url: Optional[str] = None): + def __init__(self, auto_install: bool = True, use_mirror: bool = False, mirror_url: str | None = None): """初始化依赖管理器 Args: @@ -46,7 +47,7 @@ class DependencyManager: self.mirror_url = mirror_url or "" self.install_timeout = 300 - def check_dependencies(self, dependencies: Any, plugin_name: str = "") -> Tuple[bool, List[str], List[str]]: + def check_dependencies(self, dependencies: Any, plugin_name: str = "") -> tuple[bool, list[str], list[str]]: """检查依赖包是否满足要求 Args: @@ -69,7 +70,7 @@ class DependencyManager: logger.info(f"{log_prefix}缺少依赖包: {dep.get_pip_requirement()}") missing_packages.append(dep.get_pip_requirement()) except Exception as e: - error_msg = f"检查依赖 {dep.package_name} 时发生错误: {str(e)}" + error_msg = f"检查依赖 {dep.package_name} 时发生错误: {e!s}" error_messages.append(error_msg) logger.error(f"{log_prefix}{error_msg}") @@ -84,7 +85,7 @@ class DependencyManager: return all_satisfied, missing_packages, error_messages - def install_dependencies(self, packages: List[str], plugin_name: str = "") -> Tuple[bool, List[str]]: + def install_dependencies(self, packages: list[str], plugin_name: str = "") -> tuple[bool, list[str]]: """自动安装缺失的依赖包 Args: @@ -115,7 +116,7 @@ class DependencyManager: logger.error(f"{log_prefix}❌ 安装失败: {package}") except Exception as e: failed_packages.append(package) - logger.error(f"{log_prefix}❌ 安装 {package} 时发生异常: {str(e)}") + logger.error(f"{log_prefix}❌ 安装 {package} 时发生异常: {e!s}") success = len(failed_packages) == 0 if success: @@ -125,7 +126,7 @@ class DependencyManager: return success, failed_packages - def check_and_install_dependencies(self, dependencies: Any, plugin_name: str = "") -> Tuple[bool, List[str]]: + def check_and_install_dependencies(self, dependencies: Any, plugin_name: str = "") -> tuple[bool, list[str]]: """检查并自动安装依赖(组合操作) Args: @@ -163,7 +164,7 @@ class DependencyManager: return False, all_errors @staticmethod - def _normalize_dependencies(dependencies: Any) -> List[PythonDependency]: + def _normalize_dependencies(dependencies: Any) -> list[PythonDependency]: """将依赖列表标准化为PythonDependency对象""" normalized = [] @@ -277,7 +278,7 @@ class DependencyManager: # 全局依赖管理器实例 -_global_dependency_manager: Optional[DependencyManager] = None +_global_dependency_manager: DependencyManager | None = None def get_dependency_manager() -> DependencyManager: @@ -288,7 +289,7 @@ def get_dependency_manager() -> DependencyManager: return _global_dependency_manager -def configure_dependency_manager(auto_install: bool = True, use_mirror: bool = False, mirror_url: Optional[str] = None): +def configure_dependency_manager(auto_install: bool = True, use_mirror: bool = False, mirror_url: str | None = None): """配置全局依赖管理器""" global _global_dependency_manager _global_dependency_manager = DependencyManager( diff --git a/src/plugin_system/utils/manifest_utils.py b/src/plugin_system/utils/manifest_utils.py index b714aefd7..21025127f 100644 --- a/src/plugin_system/utils/manifest_utils.py +++ b/src/plugin_system/utils/manifest_utils.py @@ -5,7 +5,8 @@ """ import re -from typing import Dict, Any, Tuple +from typing import Any + from src.common.logger import get_logger from src.config.config import MMC_VERSION @@ -70,7 +71,7 @@ class VersionComparator: return normalized @staticmethod - def parse_version(version: str) -> Tuple[int, int, int]: + def parse_version(version: str) -> tuple[int, int, int]: """解析版本号为元组 Args: @@ -109,7 +110,7 @@ class VersionComparator: return 0 @staticmethod - def check_forward_compatibility(current_version: str, max_version: str) -> Tuple[bool, str]: + def check_forward_compatibility(current_version: str, max_version: str) -> tuple[bool, str]: """检查向前兼容性(仅使用兼容性映射表) Args: @@ -131,7 +132,7 @@ class VersionComparator: return False, "" @staticmethod - def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]: + def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> tuple[bool, str]: """检查版本是否在指定范围内,支持兼容性检查 Args: @@ -195,7 +196,7 @@ class VersionComparator: logger.info(f"添加兼容性映射:{base_normalized} -> {compatible_versions}") @staticmethod - def get_compatibility_info() -> Dict[str, list]: + def get_compatibility_info() -> dict[str, list]: """获取当前的兼容性映射表 Returns: @@ -232,7 +233,7 @@ class ManifestValidator: self.validation_errors = [] self.validation_warnings = [] - def validate_manifest(self, manifest_data: Dict[str, Any]) -> bool: + def validate_manifest(self, manifest_data: dict[str, Any]) -> bool: """验证manifest数据 Args: @@ -266,7 +267,7 @@ class ManifestValidator: if "name" not in author or not author["name"]: self.validation_errors.append("作者信息缺少name字段或为空") # url字段是可选的 - if "url" in author and author["url"]: + if author.get("url"): url = author["url"] if not (url.startswith("http://") or url.startswith("https://")): self.validation_warnings.append("作者URL建议使用完整的URL格式") @@ -305,7 +306,7 @@ class ManifestValidator: # 检查URL格式(可选字段) for url_field in ["homepage_url", "repository_url"]: - if url_field in manifest_data and manifest_data[url_field]: + if manifest_data.get(url_field): url: str = manifest_data[url_field] if not (url.startswith("http://") or url.startswith("https://")): self.validation_warnings.append(f"{url_field}建议使用完整的URL格式") diff --git a/src/plugin_system/utils/permission_decorators.py b/src/plugin_system/utils/permission_decorators.py index 278ab2068..7629e608c 100644 --- a/src/plugin_system/utils/permission_decorators.py +++ b/src/plugin_system/utils/permission_decorators.py @@ -4,19 +4,19 @@ 提供方便的权限检查装饰器,用于插件命令和其他需要权限验证的地方。 """ +from collections.abc import Callable from functools import wraps -from typing import Callable, Optional from inspect import iscoroutinefunction +from src.chat.message_receive.chat_stream import ChatStream +from src.plugin_system.apis.logging_api import get_logger from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.apis.send_api import text_to_stream -from src.plugin_system.apis.logging_api import get_logger -from src.chat.message_receive.chat_stream import ChatStream logger = get_logger(__name__) -def require_permission(permission_node: str, deny_message: Optional[str] = None): +def require_permission(permission_node: str, deny_message: str | None = None): """ 权限检查装饰器 @@ -90,7 +90,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None) return decorator -def require_master(deny_message: Optional[str] = None): +def require_master(deny_message: str | None = None): """ Master权限检查装饰器 @@ -186,9 +186,7 @@ class PermissionChecker: return permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id) @staticmethod - async def ensure_permission( - chat_stream: ChatStream, permission_node: str, deny_message: Optional[str] = None - ) -> bool: + async def ensure_permission(chat_stream: ChatStream, permission_node: str, deny_message: str | None = None) -> bool: """ 确保用户拥有指定权限,如果没有权限会发送消息并返回False @@ -209,7 +207,7 @@ class PermissionChecker: return has_permission @staticmethod - async def ensure_master(chat_stream: ChatStream, deny_message: Optional[str] = None) -> bool: + async def ensure_master(chat_stream: ChatStream, deny_message: str | None = None) -> bool: """ 确保用户为Master用户,如果不是会发送消息并返回False diff --git a/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py b/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py index 3a652e2f4..25cdb1fa0 100644 --- a/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py +++ b/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py @@ -7,15 +7,15 @@ import asyncio import time import traceback from datetime import datetime -from typing import Dict, Any +from typing import Any +from src.chat.express.expression_learner import expression_learner_manager +from src.chat.planner_actions.action_manager import ChatterActionManager +from src.common.data_models.message_manager_data_model import StreamContext +from src.common.logger import get_logger from src.plugin_system.base.base_chatter import BaseChatter from src.plugin_system.base.component_types import ChatType -from src.common.data_models.message_manager_data_model import StreamContext from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner -from src.chat.planner_actions.action_manager import ChatterActionManager -from src.common.logger import get_logger -from src.chat.express.expression_learner import expression_learner_manager logger = get_logger("affinity_chatter") @@ -113,7 +113,7 @@ class AffinityChatter(BaseChatter): "executed_count": 0, } - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: """ 获取处理器统计信息 @@ -122,7 +122,7 @@ class AffinityChatter(BaseChatter): """ return self.stats.copy() - def get_planner_stats(self) -> Dict[str, Any]: + def get_planner_stats(self) -> dict[str, Any]: """ 获取规划器统计信息 @@ -131,7 +131,7 @@ class AffinityChatter(BaseChatter): """ return self.planner.get_planner_stats() - def get_interest_scoring_stats(self) -> Dict[str, Any]: + def get_interest_scoring_stats(self) -> dict[str, Any]: """ 获取兴趣度评分统计信息 @@ -140,7 +140,7 @@ class AffinityChatter(BaseChatter): """ return self.planner.get_interest_scoring_stats() - def get_relationship_stats(self) -> Dict[str, Any]: + def get_relationship_stats(self) -> dict[str, Any]: """ 获取用户关系统计信息 @@ -158,7 +158,7 @@ class AffinityChatter(BaseChatter): """ return self.planner.get_current_mood_state() - def get_mood_stats(self) -> Dict[str, Any]: + def get_mood_stats(self) -> dict[str, Any]: """ 获取情绪状态统计信息 diff --git a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py index 1bb60146b..6892b0916 100644 --- a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py +++ b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py @@ -5,11 +5,11 @@ """ import traceback -from typing import Dict, List, Any +from typing import Any +from src.chat.interest_system import bot_interest_manager from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.info_data_model import InterestScore -from src.chat.interest_system import bot_interest_manager from src.common.logger import get_logger from src.config.config import global_config @@ -47,11 +47,11 @@ class ChatterInterestScoringSystem: ) # 每次不回复增加的概率 # 用户关系数据 - self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score + self.user_relationships: dict[str, float] = {} # user_id -> relationship_score async def calculate_interest_scores( - self, messages: List[DatabaseMessages], bot_nickname: str - ) -> List[InterestScore]: + self, messages: list[DatabaseMessages], bot_nickname: str + ) -> list[InterestScore]: """计算消息的兴趣度评分""" user_messages = [msg for msg in messages if str(msg.user_info.user_id) != str(global_config.bot.qq_account)] if not user_messages: @@ -97,7 +97,7 @@ class ChatterInterestScoringSystem: details=details, ) - async def _calculate_interest_match_score(self, content: str, keywords: List[str] = None) -> float: + async def _calculate_interest_match_score(self, content: str, keywords: list[str] = None) -> float: """计算兴趣匹配度 - 使用智能embedding匹配""" if not content: return 0.0 @@ -109,7 +109,7 @@ class ChatterInterestScoringSystem: # 智能匹配未初始化,返回默认分数 return 0.3 - async def _calculate_smart_interest_match(self, content: str, keywords: List[str] = None) -> float: + async def _calculate_smart_interest_match(self, content: str, keywords: list[str] = None) -> float: """使用embedding计算智能兴趣匹配""" try: # 如果没有传入关键词,则提取 @@ -134,7 +134,7 @@ class ChatterInterestScoringSystem: logger.error(f"智能兴趣匹配计算失败: {e}") return 0.0 - def _extract_keywords_from_database(self, message: DatabaseMessages) -> List[str]: + def _extract_keywords_from_database(self, message: DatabaseMessages) -> list[str]: """从数据库消息中提取关键词""" keywords = [] @@ -166,7 +166,7 @@ class ChatterInterestScoringSystem: return keywords[:15] # 返回前15个关键词 - def _extract_keywords_from_content(self, content: str) -> List[str]: + def _extract_keywords_from_content(self, content: str) -> list[str]: """从内容中提取关键词(降级方案)""" import re @@ -287,7 +287,7 @@ class ChatterInterestScoringSystem: """获取用户关系分""" return self.user_relationships.get(user_id, 0.3) - def get_scoring_stats(self) -> Dict: + def get_scoring_stats(self) -> dict: """获取评分系统统计""" return { "no_reply_count": self.no_reply_count, @@ -318,7 +318,7 @@ class ChatterInterestScoringSystem: logger.error(f"初始化智能兴趣系统失败: {e}") traceback.print_exc() - def get_matching_config(self) -> Dict[str, Any]: + def get_matching_config(self) -> dict[str, Any]: """获取匹配配置信息""" return { "use_smart_matching": self.use_smart_matching, diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py index 8d322c880..b68591100 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py @@ -5,12 +5,11 @@ PlanExecutor: 接收 Plan 对象并执行其中的所有动作。 import asyncio import time -from typing import Dict, List -from src.config.config import global_config from src.chat.planner_actions.action_manager import ChatterActionManager -from src.common.data_models.info_data_model import Plan, ActionPlannerInfo +from src.common.data_models.info_data_model import ActionPlannerInfo, Plan from src.common.logger import get_logger +from src.config.config import global_config logger = get_logger("plan_executor") @@ -52,7 +51,7 @@ class ChatterPlanExecutor: """设置关系追踪器""" self.relationship_tracker = relationship_tracker - async def execute(self, plan: Plan) -> Dict[str, any]: + async def execute(self, plan: Plan) -> dict[str, any]: """ 遍历并执行Plan对象中`decided_actions`列表里的所有动作。 @@ -110,7 +109,7 @@ class ChatterPlanExecutor: "results": execution_results, } - async def _execute_reply_actions(self, reply_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]: + async def _execute_reply_actions(self, reply_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, any]: """串行执行所有回复动作,增加去重逻辑,避免对同一消息多次回复""" results = [] @@ -162,7 +161,7 @@ class ChatterPlanExecutor: async def _execute_single_reply_action( self, action_info: ActionPlannerInfo, plan: Plan, clear_unread: bool = True - ) -> Dict[str, any]: + ) -> dict[str, any]: """执行单个回复动作""" start_time = time.time() success = False @@ -240,7 +239,7 @@ class ChatterPlanExecutor: else reply_content, } - async def _execute_other_actions(self, other_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]: + async def _execute_other_actions(self, other_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, any]: """执行其他动作""" results = [] @@ -269,7 +268,7 @@ class ChatterPlanExecutor: return {"results": results} - async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]: + async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> dict[str, any]: """执行单个其他动作""" start_time = time.time() success = False @@ -378,7 +377,7 @@ class ChatterPlanExecutor: logger.debug(f"action_message类型: {type(action_info.action_message)}") logger.debug(f"action_message内容: {action_info.action_message}") - def get_execution_stats(self) -> Dict[str, any]: + def get_execution_stats(self) -> dict[str, any]: """获取执行统计信息""" stats = self.execution_stats.copy() @@ -409,7 +408,7 @@ class ChatterPlanExecutor: "execution_times": [], } - def get_recent_performance(self, limit: int = 10) -> List[Dict[str, any]]: + def get_recent_performance(self, limit: int = 10) -> list[dict[str, any]]: """获取最近的执行性能""" recent_times = self.execution_stats["execution_times"][-limit:] if not recent_times: diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py index 1bc153fad..92b299219 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py @@ -2,13 +2,13 @@ PlanFilter: 接收 Plan 对象,根据不同模式的逻辑进行筛选,决定最终要执行的动作。 """ -import orjson +import re import time import traceback -import re from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any +import orjson from json_repair import repair_json # 旧的Hippocampus系统已被移除,现在使用增强记忆系统 @@ -39,7 +39,7 @@ class ChatterPlanFilter: 根据 Plan 中的模式和信息,筛选并决定最终的动作。 """ - def __init__(self, chat_id: str, available_actions: List[str]): + def __init__(self, chat_id: str, available_actions: list[str]): """ 初始化动作计划筛选器。 @@ -316,8 +316,8 @@ class ChatterPlanFilter: """构建已读/未读历史消息块""" try: # 从message_manager获取真实的已读/未读消息 - from src.chat.utils.utils import assign_message_ids from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat + from src.chat.utils.utils import assign_message_ids # 获取聊天流的上下文 from src.plugin_system.apis.chat_api import get_chat_manager @@ -392,14 +392,15 @@ class ChatterPlanFilter: logger.error(f"构建已读/未读历史消息块时出错: {e}") return "构建已读历史消息时出错", "构建未读历史消息时出错", [] - async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]: + async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]: """为消息获取兴趣度评分""" interest_scores = {} try: - from .interest_scoring import chatter_interest_scoring_system from src.common.data_models.database_data_model import DatabaseMessages + from .interest_scoring import chatter_interest_scoring_system + # 使用插件内部的兴趣度评分系统计算评分 for msg_dict in messages: try: @@ -450,7 +451,7 @@ class ChatterPlanFilter: async def _parse_single_action( self, action_json: dict, message_id_list: list, plan: Plan - ) -> List[ActionPlannerInfo]: + ) -> list[ActionPlannerInfo]: parsed_actions = [] try: # 从新的actions结构中获取动作信息 @@ -599,7 +600,7 @@ class ChatterPlanFilter: ) return parsed_actions - def _filter_no_actions(self, action_list: List[ActionPlannerInfo]) -> List[ActionPlannerInfo]: + def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]: non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]] if non_no_actions: return non_no_actions @@ -652,7 +653,7 @@ class ChatterPlanFilter: logger.error(f"获取长期记忆时出错: {e}") return "回忆时出现了一些问题。" - async def _build_action_options(self, current_available_actions: Dict[str, ActionInfo]) -> str: + async def _build_action_options(self, current_available_actions: dict[str, ActionInfo]) -> str: action_options_block = "" for action_name, action_info in current_available_actions.items(): # 构建参数的JSON示例 @@ -723,7 +724,7 @@ class ChatterPlanFilter: ) return action_options_block - def _find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]: + def _find_message_by_id(self, message_id: str, message_id_list: list) -> dict[str, Any] | None: """ 增强的消息查找函数,支持多种格式和模糊匹配 兼容大模型可能返回的各种格式变体 @@ -828,12 +829,12 @@ class ChatterPlanFilter: logger.warning(f"未找到任何匹配的消息: {original_id} (候选: {candidate_ids})") return None - def _get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]: + def _get_latest_message(self, message_id_list: list) -> dict[str, Any] | None: if not message_id_list: return None return message_id_list[-1].get("message") - def _find_poke_notice(self, message_id_list: list) -> Optional[Dict[str, Any]]: + def _find_poke_notice(self, message_id_list: list) -> dict[str, Any] | None: """在消息列表中寻找戳一戳的通知消息""" for item in reversed(message_id_list): message = item.get("message") diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_generator.py b/src/plugins/built_in/affinity_flow_chatter/plan_generator.py index 86539ac01..d946934d5 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_generator.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_generator.py @@ -3,7 +3,6 @@ PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个 """ import time -from typing import Dict from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat from src.chat.utils.utils import get_chat_type_and_target_info @@ -85,7 +84,7 @@ class ChatterPlanGenerator: chat_history=[], ) - async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> Dict[str, ActionInfo]: + async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> dict[str, ActionInfo]: """ 获取当前可用的动作列表。 @@ -152,7 +151,7 @@ class ChatterPlanGenerator: # 如果获取失败,返回空列表 return [] - def get_generator_stats(self) -> Dict: + def get_generator_stats(self) -> dict: """ 获取生成器统计信息。 diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py index b6be512b9..e2321aab1 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner.py @@ -4,22 +4,20 @@ """ from dataclasses import asdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple - -from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor -from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter -from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator -from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system -from src.mood.mood_manager import mood_manager - +from typing import TYPE_CHECKING, Any from src.common.logger import get_logger from src.config.config import global_config +from src.mood.mood_manager import mood_manager +from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system +from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor +from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter +from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator if TYPE_CHECKING: - from src.common.data_models.message_manager_data_model import StreamContext - from src.common.data_models.info_data_model import Plan from src.chat.planner_actions.action_manager import ChatterActionManager + from src.common.data_models.info_data_model import Plan + from src.common.data_models.message_manager_data_model import StreamContext # 导入提示词模块以确保其被初始化 from src.plugins.built_in.affinity_flow_chatter import planner_prompts # noqa @@ -62,7 +60,7 @@ class ChatterActionPlanner: "other_actions_executed": 0, } - async def plan(self, context: "StreamContext" = None) -> Tuple[List[Dict], Optional[Dict]]: + async def plan(self, context: "StreamContext" = None) -> tuple[list[dict], dict | None]: """ 执行完整的增强版规划流程。 @@ -84,7 +82,7 @@ class ChatterActionPlanner: self.planner_stats["failed_plans"] += 1 return [], None - async def _enhanced_plan_flow(self, context: "StreamContext") -> Tuple[List[Dict], Optional[Dict]]: + async def _enhanced_plan_flow(self, context: "StreamContext") -> tuple[list[dict], dict | None]: """执行增强版规划流程""" try: # 在规划前,先进行动作修改 @@ -104,7 +102,7 @@ class ChatterActionPlanner: score = 0.0 should_reply = False reply_not_available = False - interest_updates: List[Dict[str, Any]] = [] + interest_updates: list[dict[str, Any]] = [] if unread_messages: # 为每条消息计算兴趣度,并延迟提交数据库更新 @@ -193,7 +191,7 @@ class ChatterActionPlanner: self.planner_stats["failed_plans"] += 1 return [], None - async def _commit_interest_updates(self, updates: List[Dict[str, Any]]) -> None: + async def _commit_interest_updates(self, updates: list[dict[str, Any]]) -> None: """统一更新消息兴趣度,减少数据库写入次数""" if not updates: return @@ -220,7 +218,7 @@ class ChatterActionPlanner: except Exception as e: logger.warning(f"批量更新数据库兴趣度失败: {e}") - def _update_stats_from_execution_result(self, execution_result: Dict[str, any]): + def _update_stats_from_execution_result(self, execution_result: dict[str, any]): """根据执行结果更新规划器统计""" if not execution_result: return @@ -244,7 +242,7 @@ class ChatterActionPlanner: self.planner_stats["replies_generated"] += reply_count self.planner_stats["other_actions_executed"] += other_count - def _build_return_result(self, plan: "Plan") -> Tuple[List[Dict], Optional[Dict]]: + def _build_return_result(self, plan: "Plan") -> tuple[list[dict], dict | None]: """构建返回结果""" final_actions = plan.decided_actions or [] final_target_message = next((act.action_message for act in final_actions if act.action_message), None) @@ -261,7 +259,7 @@ class ChatterActionPlanner: return final_actions_dict, final_target_message_dict - def get_planner_stats(self) -> Dict[str, any]: + def get_planner_stats(self) -> dict[str, any]: """获取规划器统计""" return self.planner_stats.copy() @@ -270,7 +268,7 @@ class ChatterActionPlanner: chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id) return chat_mood.mood_state - def get_mood_stats(self) -> Dict[str, any]: + def get_mood_stats(self) -> dict[str, any]: """获取情绪状态统计""" chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id) return { diff --git a/src/plugins/built_in/affinity_flow_chatter/plugin.py b/src/plugins/built_in/affinity_flow_chatter/plugin.py index 7c86d13fe..32d869e67 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plugin.py +++ b/src/plugins/built_in/affinity_flow_chatter/plugin.py @@ -2,12 +2,10 @@ 亲和力聊天处理器插件 """ -from typing import List, Tuple, Type - +from src.common.logger import get_logger from src.plugin_system.apis.plugin_register_api import register_plugin from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system.base.component_types import ComponentInfo -from src.common.logger import get_logger logger = get_logger("affinity_chatter_plugin") @@ -29,7 +27,7 @@ class AffinityChatterPlugin(BasePlugin): # 简单的 config_schema 占位(如果将来需要配置可扩展) config_schema = {} - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: """返回插件包含的组件列表(ChatterInfo, AffinityChatter) 这里采用延迟导入 AffinityChatter 来避免循环依赖和启动顺序问题。 diff --git a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py index 2320670a0..e3dcb9791 100644 --- a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py +++ b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py @@ -5,15 +5,15 @@ """ import time -from typing import Dict, List, Optional -from src.common.logger import get_logger -from src.config.config import model_config, global_config -from src.llm_models.utils_model import LLMRequest -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import UserRelationships, Messages -from sqlalchemy import select, desc +from sqlalchemy import desc, select + from src.common.data_models.database_data_model import DatabaseMessages +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import Messages, UserRelationships +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("chatter_relationship_tracker") @@ -22,15 +22,15 @@ class ChatterRelationshipTracker: """用户关系追踪器""" def __init__(self, interest_scoring_system=None): - self.tracking_users: Dict[str, Dict] = {} # user_id -> interaction_data + self.tracking_users: dict[str, dict] = {} # user_id -> interaction_data self.max_tracking_users = 3 self.update_interval_minutes = 30 self.last_update_time = time.time() - self.relationship_history: List[Dict] = [] + self.relationship_history: list[dict] = [] self.interest_scoring_system = interest_scoring_system # 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float}) - self.user_relationship_cache: Dict[str, Dict] = {} + self.user_relationship_cache: dict[str, dict] = {} self.cache_expiry_hours = 1 # 缓存过期时间(小时) # 关系更新LLM @@ -91,7 +91,7 @@ class ChatterRelationshipTracker: logger.debug(f"添加用户交互追踪: {user_id}") - async def check_and_update_relationships(self) -> List[Dict]: + async def check_and_update_relationships(self) -> list[dict]: """检查并更新用户关系""" current_time = time.time() if current_time - self.last_update_time < self.update_interval_minutes * 60: @@ -108,7 +108,7 @@ class ChatterRelationshipTracker: self.last_update_time = current_time return updates - async def _update_user_relationship(self, interaction: Dict) -> Optional[Dict]: + async def _update_user_relationship(self, interaction: dict) -> dict | None: """更新单个用户的关系""" try: # 获取bot人设信息 @@ -201,11 +201,11 @@ class ChatterRelationshipTracker: return None - def get_tracking_users(self) -> Dict[str, Dict]: + def get_tracking_users(self) -> dict[str, dict]: """获取正在追踪的用户""" return self.tracking_users.copy() - def get_user_interaction(self, user_id: str) -> Optional[Dict]: + def get_user_interaction(self, user_id: str) -> dict | None: """获取特定用户的交互记录""" return self.tracking_users.get(user_id) @@ -220,11 +220,11 @@ class ChatterRelationshipTracker: self.tracking_users.clear() logger.info("清空所有用户追踪") - def get_relationship_history(self) -> List[Dict]: + def get_relationship_history(self) -> list[dict]: """获取关系历史记录""" return self.relationship_history.copy() - def add_to_history(self, relationship_update: Dict): + def add_to_history(self, relationship_update: dict): """添加到关系历史""" self.relationship_history.append({**relationship_update, "update_time": time.time()}) @@ -232,7 +232,7 @@ class ChatterRelationshipTracker: if len(self.relationship_history) > 100: self.relationship_history = self.relationship_history[-100:] - def get_tracker_stats(self) -> Dict: + def get_tracker_stats(self) -> dict: """获取追踪器统计""" return { "tracking_users": len(self.tracking_users), @@ -268,7 +268,7 @@ class ChatterRelationshipTracker: self.add_to_history(update_info) logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}") - def get_user_summary(self, user_id: str) -> Dict: + def get_user_summary(self, user_id: str) -> dict: """获取用户交互总结""" if user_id not in self.tracking_users: return {} @@ -313,7 +313,7 @@ class ChatterRelationshipTracker: # 数据库中也没有,返回默认值 return global_config.affinity_flow.base_relationship_score - async def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]: + async def _get_user_relationship_from_db(self, user_id: str) -> dict | None: """从数据库获取用户关系数据""" try: async with get_db_session() as session: @@ -431,7 +431,7 @@ class ChatterRelationshipTracker: return 0 - async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]: + async def _get_last_bot_reply_to_user(self, user_id: str) -> DatabaseMessages | None: """获取上次bot回复该用户的消息""" try: async with get_db_session() as session: @@ -455,7 +455,7 @@ class ChatterRelationshipTracker: return None - async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]: + async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> list[DatabaseMessages]: """获取用户在bot回复后的反应消息""" try: async with get_db_session() as session: @@ -511,7 +511,7 @@ class ChatterRelationshipTracker: user_id: str, user_name: str, last_bot_reply: DatabaseMessages, - user_reactions: List[DatabaseMessages], + user_reactions: list[DatabaseMessages], current_text: str, current_score: float, current_reply: str, diff --git a/src/plugins/built_in/core_actions/anti_injector_manager.py b/src/plugins/built_in/core_actions/anti_injector_manager.py index 3291ba8cf..3b207ab63 100644 --- a/src/plugins/built_in/core_actions/anti_injector_manager.py +++ b/src/plugins/built_in/core_actions/anti_injector_manager.py @@ -8,9 +8,9 @@ - 测试功能 """ -from src.plugin_system.base import BaseCommand from src.chat.antipromptinjector import get_anti_injector from src.common.logger import get_logger +from src.plugin_system.base import BaseCommand logger = get_logger("anti_injector.commands") @@ -56,5 +56,5 @@ class AntiInjectorStatusCommand(BaseCommand): except Exception as e: logger.error(f"获取反注入系统状态失败: {e}") - await self.send_text(f"获取状态失败: {str(e)}") - return False, f"获取状态失败: {str(e)}", True + await self.send_text(f"获取状态失败: {e!s}") + return False, f"获取状态失败: {e!s}", True diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index a477fdf0a..0dab1f88c 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -1,19 +1,18 @@ import random -from typing import Tuple -# 导入新插件系统 -from src.plugin_system import BaseAction, ActionActivationType, ChatMode +from src.chat.emoji_system.emoji_history import add_emoji_to_history, get_recent_emojis +from src.chat.emoji_system.emoji_manager import MaiEmoji, get_emoji_manager +from src.chat.utils.utils_image import image_path_to_base64 # 导入依赖的系统组件 from src.common.logger import get_logger +from src.config.config import global_config + +# 导入新插件系统 +from src.plugin_system import ActionActivationType, BaseAction, ChatMode # 导入API模块 - 标准Python包方式 from src.plugin_system.apis import llm_api, message_api -from src.chat.emoji_system.emoji_manager import get_emoji_manager, MaiEmoji -from src.chat.utils.utils_image import image_path_to_base64 -from src.config.config import global_config -from src.chat.emoji_system.emoji_history import get_recent_emojis, add_emoji_to_history - logger = get_logger("emoji") @@ -59,7 +58,7 @@ class EmojiAction(BaseAction): # 关联类型 associated_types = ["emoji"] - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """执行表情动作""" logger.info(f"{self.log_prefix} 决定发送表情") @@ -286,4 +285,4 @@ class EmojiAction(BaseAction): except Exception as e: logger.error(f"{self.log_prefix} 表情动作执行失败: {e}", exc_info=True) - return False, f"表情发送失败: {str(e)}" + return False, f"表情发送失败: {e!s}" diff --git a/src/plugins/built_in/core_actions/plugin.py b/src/plugins/built_in/core_actions/plugin.py index 473005a22..91a7e8d5e 100644 --- a/src/plugins/built_in/core_actions/plugin.py +++ b/src/plugins/built_in/core_actions/plugin.py @@ -5,19 +5,16 @@ 这是系统的内置插件,提供基础的聊天交互功能 """ -from typing import List, Tuple, Type - -# 导入新插件系统 -from src.plugin_system import BasePlugin, register_plugin, ComponentInfo -from src.plugin_system.base.config_types import ConfigField - - # 导入依赖的系统组件 from src.common.logger import get_logger +# 导入新插件系统 +from src.plugin_system import BasePlugin, ComponentInfo, register_plugin +from src.plugin_system.base.config_types import ConfigField +from src.plugins.built_in.core_actions.anti_injector_manager import AntiInjectorStatusCommand + # 导入API模块 - 标准Python包方式 from src.plugins.built_in.core_actions.emoji import EmojiAction -from src.plugins.built_in.core_actions.anti_injector_manager import AntiInjectorStatusCommand logger = get_logger("core_actions") @@ -62,7 +59,7 @@ class CoreActionsPlugin(BasePlugin): }, } - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: """返回插件包含的组件列表""" # --- 根据配置注册组件 --- diff --git a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py index 194a2c5ef..38da7e013 100644 --- a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py +++ b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py @@ -1,8 +1,8 @@ -from typing import Dict, Any +from typing import Any +from src.chat.knowledge.knowledge_lib import qa_manager from src.common.logger import get_logger from src.config.config import global_config -from src.chat.knowledge.knowledge_lib import qa_manager from src.plugin_system import BaseTool, ToolParamType logger = get_logger("lpmm_get_knowledge_tool") @@ -19,7 +19,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool): ] available_for_llm = global_config.lpmm_knowledge.enable - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行知识库搜索 Args: @@ -56,7 +56,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool): return {"type": "lpmm_knowledge", "id": query, "content": content} except Exception as e: # 捕获异常并记录错误 - logger.error(f"知识库搜索工具执行失败: {str(e)}") + logger.error(f"知识库搜索工具执行失败: {e!s}") # 在其他异常情况下,确保 id 仍然是 query (如果它被定义了) query_id = query if "query" in locals() else "unknown_query" - return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {str(e)}"} + return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {e!s}"} diff --git a/src/plugins/built_in/maizone_refactored/__init__.py b/src/plugins/built_in/maizone_refactored/__init__.py index 56a019c4b..dd094256f 100644 --- a/src/plugins/built_in/maizone_refactored/__init__.py +++ b/src/plugins/built_in/maizone_refactored/__init__.py @@ -1,9 +1,8 @@ -# -*- coding: utf-8 -*- """ 让框架能够发现并加载子目录中的组件。 """ -from .plugin import MaiZoneRefactoredPlugin as MaiZoneRefactoredPlugin -from .actions.send_feed_action import SendFeedAction as SendFeedAction from .actions.read_feed_action import ReadFeedAction as ReadFeedAction +from .actions.send_feed_action import SendFeedAction as SendFeedAction from .commands.send_feed_command import SendFeedCommand as SendFeedCommand +from .plugin import MaiZoneRefactoredPlugin as MaiZoneRefactoredPlugin diff --git a/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py b/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py index ee5a1b73a..6abef2141 100644 --- a/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py +++ b/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py @@ -1,14 +1,12 @@ -# -*- coding: utf-8 -*- """ 阅读说说动作组件 """ -from typing import Tuple - from src.common.logger import get_logger -from src.plugin_system import BaseAction, ActionActivationType, ChatMode +from src.plugin_system import ActionActivationType, BaseAction, ChatMode from src.plugin_system.apis import generator_api from src.plugin_system.apis.permission_api import permission_api + from ..services.manager import get_qzone_service logger = get_logger("MaiZone.ReadFeedAction") @@ -41,7 +39,7 @@ class ReadFeedAction(BaseAction): # 使用权限API检查用户是否有阅读说说的权限 return await permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed") - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """ 执行动作的核心逻辑。 """ diff --git a/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py b/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py index af8760c06..b242aae70 100644 --- a/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py +++ b/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py @@ -1,14 +1,12 @@ -# -*- coding: utf-8 -*- """ 发送说说动作组件 """ -from typing import Tuple - from src.common.logger import get_logger -from src.plugin_system import BaseAction, ActionActivationType, ChatMode +from src.plugin_system import ActionActivationType, BaseAction, ChatMode from src.plugin_system.apis import generator_api from src.plugin_system.apis.permission_api import permission_api + from ..services.manager import get_qzone_service logger = get_logger("MaiZone.SendFeedAction") @@ -41,7 +39,7 @@ class SendFeedAction(BaseAction): # 使用权限API检查用户是否有发送说说的权限 return await permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed") - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """ 执行动作的核心逻辑。 """ diff --git a/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py b/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py index 631ca430d..062252a99 100644 --- a/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py +++ b/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py @@ -1,15 +1,13 @@ -# -*- coding: utf-8 -*- """ 发送说说命令 await self.send_text(f"收到!正在为你生成关于"{topic or '随机'}"的说说,请稍候...【热重载测试成功】")件 """ -from typing import Tuple - from src.common.logger import get_logger -from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.base.command_args import CommandArgs +from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.utils.permission_decorators import require_permission -from ..services.manager import get_qzone_service, get_config_getter + +from ..services.manager import get_config_getter, get_qzone_service logger = get_logger("MaiZone.SendFeedCommand") @@ -28,7 +26,7 @@ class SendFeedCommand(PlusCommand): super().__init__(*args, **kwargs) @require_permission("plugin.maizone.send_feed") - async def execute(self, args: CommandArgs) -> Tuple[bool, str, bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str, bool]: """ 执行命令的核心逻辑。 """ diff --git a/src/plugins/built_in/maizone_refactored/plugin.py b/src/plugins/built_in/maizone_refactored/plugin.py index e8259b5cb..4ef92ff9e 100644 --- a/src/plugins/built_in/maizone_refactored/plugin.py +++ b/src/plugins/built_in/maizone_refactored/plugin.py @@ -1,28 +1,26 @@ -# -*- coding: utf-8 -*- """ MaiZone(麦麦空间)- 重构版 """ import asyncio from pathlib import Path -from typing import List, Tuple, Type from src.common.logger import get_logger from src.plugin_system import BasePlugin, ComponentInfo, register_plugin -from src.plugin_system.base.config_types import ConfigField from src.plugin_system.apis.permission_api import permission_api +from src.plugin_system.base.config_types import ConfigField from .actions.read_feed_action import ReadFeedAction from .actions.send_feed_action import SendFeedAction from .commands.send_feed_command import SendFeedCommand from .services.content_service import ContentService -from .services.image_service import ImageService -from .services.qzone_service import QZoneService -from .services.scheduler_service import SchedulerService -from .services.monitor_service import MonitorService from .services.cookie_service import CookieService -from .services.reply_tracker_service import ReplyTrackerService +from .services.image_service import ImageService from .services.manager import register_service +from .services.monitor_service import MonitorService +from .services.qzone_service import QZoneService +from .services.reply_tracker_service import ReplyTrackerService +from .services.scheduler_service import SchedulerService logger = get_logger("MaiZone.Plugin") @@ -35,8 +33,8 @@ class MaiZoneRefactoredPlugin(BasePlugin): plugin_description: str = "重构版的MaiZone插件" config_file_name: str = "config.toml" enable_plugin: bool = True - dependencies: List[str] = [] - python_dependencies: List[str] = [] + dependencies: list[str] = [] + python_dependencies: list[str] = [] config_schema: dict = { "plugin": {"enable": ConfigField(type=bool, default=True, description="是否启用插件")}, @@ -125,7 +123,7 @@ class MaiZoneRefactoredPlugin(BasePlugin): asyncio.create_task(monitor_service.start()) logger.info("MaiZone后台监控和定时任务已启动。") - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: return [ (SendFeedAction.get_action_info(), SendFeedAction), (ReadFeedAction.get_action_info(), ReadFeedAction), diff --git a/src/plugins/built_in/maizone_refactored/services/content_service.py b/src/plugins/built_in/maizone_refactored/services/content_service.py index 27f2a0ee9..553eb2a95 100644 --- a/src/plugins/built_in/maizone_refactored/services/content_service.py +++ b/src/plugins/built_in/maizone_refactored/services/content_service.py @@ -1,23 +1,23 @@ -# -*- coding: utf-8 -*- """ 内容服务模块 负责生成所有与QQ空间相关的文本内容,例如说说、评论等。 """ -from typing import Callable, Optional -import datetime - -import base64 -import aiohttp -from src.common.logger import get_logger -import imghdr import asyncio -from src.plugin_system.apis import llm_api, config_api, generator_api -from src.plugin_system.apis.cross_context_api import get_chat_history_by_group_name -from src.chat.message_receive.chat_stream import get_chat_manager +import base64 +import datetime +import imghdr +from collections.abc import Callable + +import aiohttp from maim_message import UserInfo -from src.llm_models.utils_model import LLMRequest + +from src.chat.message_receive.chat_stream import get_chat_manager +from src.common.logger import get_logger from src.config.api_ada_configs import TaskConfig +from src.llm_models.utils_model import LLMRequest +from src.plugin_system.apis import config_api, generator_api, llm_api +from src.plugin_system.apis.cross_context_api import get_chat_history_by_group_name # 导入旧的工具函数,我们稍后会考虑是否也需要重构它 from ..utils.history_utils import get_send_history @@ -38,7 +38,7 @@ class ContentService: """ self.get_config = get_config - async def generate_story(self, topic: str, context: Optional[str] = None) -> str: + async def generate_story(self, topic: str, context: str | None = None) -> str: """ 根据指定主题和可选的上下文生成一条QQ空间说说。 @@ -231,7 +231,7 @@ class ContentService: return "" return "" - async def _describe_image(self, image_url: str) -> Optional[str]: + async def _describe_image(self, image_url: str) -> str | None: """ 使用LLM识别图片内容。 """ diff --git a/src/plugins/built_in/maizone_refactored/services/cookie_service.py b/src/plugins/built_in/maizone_refactored/services/cookie_service.py index 9da05582c..c0a0b7ef9 100644 --- a/src/plugins/built_in/maizone_refactored/services/cookie_service.py +++ b/src/plugins/built_in/maizone_refactored/services/cookie_service.py @@ -1,14 +1,14 @@ -# -*- coding: utf-8 -*- """ Cookie服务模块 负责从多种来源获取、缓存和管理QZone的Cookie。 """ -import orjson +from collections.abc import Callable from pathlib import Path -from typing import Callable, Optional, Dict import aiohttp +import orjson + from src.common.logger import get_logger from src.plugin_system.apis import send_api @@ -29,28 +29,28 @@ class CookieService: """获取指定QQ账号的cookie文件路径""" return self.cookie_dir / f"cookies-{qq_account}.json" - def _save_cookies_to_file(self, qq_account: str, cookies: Dict[str, str]): + def _save_cookies_to_file(self, qq_account: str, cookies: dict[str, str]): """将Cookie保存到本地文件""" cookie_file_path = self._get_cookie_file_path(qq_account) try: with open(cookie_file_path, "w", encoding="utf-8") as f: f.write(orjson.dumps(cookies, option=orjson.OPT_INDENT_2).decode("utf-8")) logger.info(f"Cookie已成功缓存至: {cookie_file_path}") - except IOError as e: + except OSError as e: logger.error(f"无法写入Cookie文件 {cookie_file_path}: {e}") - def _load_cookies_from_file(self, qq_account: str) -> Optional[Dict[str, str]]: + def _load_cookies_from_file(self, qq_account: str) -> dict[str, str] | None: """从本地文件加载Cookie""" cookie_file_path = self._get_cookie_file_path(qq_account) if cookie_file_path.exists(): try: - with open(cookie_file_path, "r", encoding="utf-8") as f: + with open(cookie_file_path, encoding="utf-8") as f: return orjson.loads(f.read()) - except (IOError, orjson.JSONDecodeError) as e: + except (OSError, orjson.JSONDecodeError) as e: logger.error(f"无法读取或解析Cookie文件 {cookie_file_path}: {e}") return None - async def _get_cookies_from_adapter(self, stream_id: Optional[str]) -> Optional[Dict[str, str]]: + async def _get_cookies_from_adapter(self, stream_id: str | None) -> dict[str, str] | None: """通过Adapter API获取Cookie""" try: params = {"domain": "user.qzone.qq.com"} @@ -73,7 +73,7 @@ class CookieService: logger.error(f"通过Adapter获取Cookie时发生异常: {e}") return None - async def _get_cookies_from_http(self) -> Optional[Dict[str, str]]: + async def _get_cookies_from_http(self) -> dict[str, str] | None: """通过备用HTTP端点获取Cookie""" host = self.get_config("cookie.http_fallback_host", "172.20.130.55") port = self.get_config("cookie.http_fallback_port", "9999") @@ -110,7 +110,7 @@ class CookieService: logger.error(f"通过HTTP备用地址 {http_url} 获取Cookie失败: {e}") return None - async def get_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]: + async def get_cookies(self, qq_account: str, stream_id: str | None) -> dict[str, str] | None: """ 获取Cookie,按以下顺序尝试: 1. HTTP备用端点 (更稳定) diff --git a/src/plugins/built_in/maizone_refactored/services/image_service.py b/src/plugins/built_in/maizone_refactored/services/image_service.py index cbb411da7..58241ba7b 100644 --- a/src/plugins/built_in/maizone_refactored/services/image_service.py +++ b/src/plugins/built_in/maizone_refactored/services/image_service.py @@ -1,12 +1,11 @@ -# -*- coding: utf-8 -*- """ 图片服务模块 负责处理所有与图片相关的任务,特别是AI生成图片。 """ import base64 +from collections.abc import Callable from pathlib import Path -from typing import Callable import aiohttp diff --git a/src/plugins/built_in/maizone_refactored/services/manager.py b/src/plugins/built_in/maizone_refactored/services/manager.py index 74cbb844a..ec1588bd3 100644 --- a/src/plugins/built_in/maizone_refactored/services/manager.py +++ b/src/plugins/built_in/maizone_refactored/services/manager.py @@ -1,14 +1,15 @@ -# -*- coding: utf-8 -*- """ 服务管理器/定位器 这是一个独立的模块,用于注册和获取插件内的全局服务实例,以避免循环导入。 """ -from typing import Dict, Any, Callable +from collections.abc import Callable +from typing import Any + from .qzone_service import QZoneService # --- 全局服务注册表 --- -_services: Dict[str, Any] = {} +_services: dict[str, Any] = {} def register_service(name: str, instance: Any): diff --git a/src/plugins/built_in/maizone_refactored/services/monitor_service.py b/src/plugins/built_in/maizone_refactored/services/monitor_service.py index 114358ea3..b479f4183 100644 --- a/src/plugins/built_in/maizone_refactored/services/monitor_service.py +++ b/src/plugins/built_in/maizone_refactored/services/monitor_service.py @@ -1,13 +1,13 @@ -# -*- coding: utf-8 -*- """ 好友动态监控服务 """ import asyncio import traceback -from typing import Callable +from collections.abc import Callable from src.common.logger import get_logger + from .qzone_service import QZoneService logger = get_logger("MaiZone.MonitorService") diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index c0e00b80d..6220595cc 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -1,32 +1,33 @@ -# -*- coding: utf-8 -*- """ QQ空间服务模块 封装了所有与QQ空间API的直接交互,是插件的核心业务逻辑层。 """ import asyncio -import orjson +import base64 import os import random import time -import base64 +from collections.abc import Callable from pathlib import Path -from typing import Callable, Optional, Dict, Any, List, Tuple +from typing import Any import aiohttp import bs4 import json5 -from src.common.logger import get_logger -from src.plugin_system.apis import config_api, person_api +import orjson + from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import ( build_readable_messages_with_id, get_raw_msg_by_timestamp_with_chat, ) +from src.common.logger import get_logger +from src.plugin_system.apis import config_api, person_api from .content_service import ContentService -from .image_service import ImageService from .cookie_service import CookieService +from .image_service import ImageService from .reply_tracker_service import ReplyTrackerService logger = get_logger("MaiZone.QZoneService") @@ -64,7 +65,7 @@ class QZoneService: # --- Public Methods (High-Level Business Logic) --- - async def send_feed(self, topic: str, stream_id: Optional[str]) -> Dict[str, Any]: + async def send_feed(self, topic: str, stream_id: str | None) -> dict[str, Any]: """发送一条说说""" # --- 获取互通组上下文 --- context = await self._get_intercom_context(stream_id) if stream_id else None @@ -92,7 +93,7 @@ class QZoneService: logger.error(f"发布说说时发生异常: {e}", exc_info=True) return {"success": False, "message": f"发布说说异常: {e}"} - async def send_feed_from_activity(self, activity: str) -> Dict[str, Any]: + async def send_feed_from_activity(self, activity: str) -> dict[str, Any]: """根据日程活动发送一条说说""" story = await self.content_service.generate_story_from_activity(activity) if not story: @@ -118,7 +119,7 @@ class QZoneService: logger.error(f"根据活动发布说说时发生异常: {e}", exc_info=True) return {"success": False, "message": f"发布说说异常: {e}"} - async def read_and_process_feeds(self, target_name: str, stream_id: Optional[str]) -> Dict[str, Any]: + async def read_and_process_feeds(self, target_name: str, stream_id: str | None) -> dict[str, Any]: """读取并处理指定好友的说说""" target_person_id = await person_api.get_person_id_by_name(target_name) if not target_person_id: @@ -147,7 +148,7 @@ class QZoneService: logger.error(f"读取和处理说说时发生异常: {e}", exc_info=True) return {"success": False, "message": f"处理说说异常: {e}"} - async def monitor_feeds(self, stream_id: Optional[str] = None): + async def monitor_feeds(self, stream_id: str | None = None): """监控并处理所有好友的动态,包括回复自己说说的评论""" logger.info("开始执行好友动态监控...") qq_account = config_api.get_global_config("bot.qq_account", "") @@ -189,7 +190,7 @@ class QZoneService: # --- Internal Helper Methods --- - async def _get_intercom_context(self, stream_id: str) -> Optional[str]: + async def _get_intercom_context(self, stream_id: str) -> str | None: """ 根据 stream_id 查找其所属的互通组,并构建该组的聊天上下文。 @@ -247,7 +248,7 @@ class QZoneService: logger.debug(f"Stream ID '{stream_id}' 未在任何互通组中找到。") return None - async def _reply_to_own_feed_comments(self, feed: Dict, api_client: Dict): + async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict): """处理对自己说说的评论并进行回复""" qq_account = config_api.get_global_config("bot.qq_account", "") comments = feed.get("comments", []) @@ -309,7 +310,7 @@ class QZoneService: if comment_key in self.processing_comments: self.processing_comments.remove(comment_key) - async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: List[Dict]): + async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: list[dict]): """验证并清理已删除的回复记录""" # 获取当前记录中该说说的所有已回复评论ID recorded_replied_comments = self.reply_tracker.get_replied_comments(fid) @@ -333,7 +334,7 @@ class QZoneService: self.reply_tracker.remove_reply_record(fid, comment_tid) logger.debug(f"已清理删除的回复记录: feed_id={fid}, comment_id={comment_tid}") - async def _process_single_feed(self, feed: Dict, api_client: Dict, target_qq: str, target_name: str): + async def _process_single_feed(self, feed: dict, api_client: dict, target_qq: str, target_name: str): """处理单条说说,决定是否评论和点赞""" content = feed.get("content", "") fid = feed.get("tid", "") @@ -371,7 +372,7 @@ class QZoneService: if random.random() <= self.get_config("read.like_possibility", 1.0): await api_client["like"](target_qq, fid) - def _load_local_images(self, image_dir: str) -> List[bytes]: + def _load_local_images(self, image_dir: str) -> list[bytes]: """随机加载本地图片(不删除文件)""" images = [] if not image_dir or not os.path.exists(image_dir): @@ -432,7 +433,7 @@ class QZoneService: hash_val += (hash_val << 5) + ord(char) return str(hash_val & 2147483647) - async def _renew_and_load_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]: + async def _renew_and_load_cookies(self, qq_account: str, stream_id: str | None) -> dict[str, str] | None: cookie_dir = Path(__file__).resolve().parent.parent / "cookies" cookie_dir.mkdir(exist_ok=True) cookie_file_path = cookie_dir / f"cookies-{qq_account}.json" @@ -480,7 +481,7 @@ class QZoneService: logger.error("所有获取Cookie的方式均失败。") return None - async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> Optional[Dict]: + async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> dict | None: """通过HTTP服务器获取Cookie""" # 从配置中读取主机和端口,如果未提供则使用传入的参数 final_host = self.get_config("cookie.http_fallback_host", host) @@ -515,19 +516,19 @@ class QZoneService: except aiohttp.ClientError as e: if attempt < max_retries - 1: - logger.warning(f"无法连接到Napcat服务(尝试 {attempt + 1}/{max_retries}): {url},错误: {str(e)}") + logger.warning(f"无法连接到Napcat服务(尝试 {attempt + 1}/{max_retries}): {url},错误: {e!s}") await asyncio.sleep(retry_delay) retry_delay *= 2 continue - logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {str(e)}") + logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {e!s}") raise RuntimeError(f"无法连接到Napcat服务: {url}") from e except Exception as e: - logger.error(f"获取cookie异常: {str(e)}") + logger.error(f"获取cookie异常: {e!s}") raise raise RuntimeError(f"无法连接到Napcat服务: 超过最大重试次数({max_retries})") - async def _get_api_client(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict]: + async def _get_api_client(self, qq_account: str, stream_id: str | None) -> dict | None: cookies = await self.cookie_service.get_cookies(qq_account, stream_id) if not cookies: logger.error( @@ -559,7 +560,7 @@ class QZoneService: response.raise_for_status() return await response.text() - async def _publish(content: str, images: List[bytes]) -> Tuple[bool, str]: + async def _publish(content: str, images: list[bytes]) -> tuple[bool, str]: """发布说说""" try: post_data = { @@ -660,7 +661,7 @@ class QZoneService: return picbo, richval - async def _upload_image(image_bytes: bytes, index: int) -> Optional[Dict[str, str]]: + async def _upload_image(image_bytes: bytes, index: int) -> dict[str, str] | None: """上传图片到QQ空间(完全按照原版实现)""" try: upload_url = "https://up.qzone.qq.com/cgi-bin/upload/cgi_upload_image" @@ -745,7 +746,7 @@ class QZoneService: logger.error(f"上传图片 {index + 1} 异常: {e}", exc_info=True) return None - async def _list_feeds(t_qq: str, num: int) -> List[Dict]: + async def _list_feeds(t_qq: str, num: int) -> list[dict]: """获取指定用户说说列表 (统一接口)""" try: # 统一使用 format=json 获取完整评论 @@ -920,7 +921,7 @@ class QZoneService: logger.error(f"回复评论异常: {e}", exc_info=True) return False - async def _monitor_list_feeds(num: int) -> List[Dict]: + async def _monitor_list_feeds(num: int) -> list[dict]: """监控好友动态""" try: params = { diff --git a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py index 0fa7edb99..6baa30d21 100644 --- a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py +++ b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 评论回复跟踪服务 负责记录和管理已回复过的评论ID,避免重复回复 @@ -7,7 +6,8 @@ import json import time from pathlib import Path -from typing import Set, Dict, Any, Union +from typing import Any + from src.common.logger import get_logger logger = get_logger("MaiZone.ReplyTrackerService") @@ -27,7 +27,7 @@ class ReplyTrackerService: # 内存中的已回复评论记录 # 格式: {feed_id: {comment_id: timestamp, ...}, ...} - self.replied_comments: Dict[str, Dict[str, float]] = {} + self.replied_comments: dict[str, dict[str, float]] = {} # 数据清理配置 self.max_record_days = 30 # 保留30天的记录 @@ -64,7 +64,7 @@ class ReplyTrackerService: try: if self.reply_record_file.exists(): try: - with open(self.reply_record_file, "r", encoding="utf-8") as f: + with open(self.reply_record_file, encoding="utf-8") as f: file_content = f.read().strip() if not file_content: # 文件为空 logger.warning("回复记录文件为空,将创建新的记录") @@ -173,7 +173,7 @@ class ReplyTrackerService: if total_removed > 0: logger.info(f"清理了 {total_removed} 条超过{self.max_record_days}天的过期回复记录") - def has_replied(self, feed_id: str, comment_id: Union[str, int]) -> bool: + def has_replied(self, feed_id: str, comment_id: str | int) -> bool: """ 检查是否已经回复过指定的评论 @@ -190,7 +190,7 @@ class ReplyTrackerService: comment_id_str = str(comment_id) return feed_id in self.replied_comments and comment_id_str in self.replied_comments[feed_id] - def mark_as_replied(self, feed_id: str, comment_id: Union[str, int]): + def mark_as_replied(self, feed_id: str, comment_id: str | int): """ 标记指定评论为已回复 @@ -219,7 +219,7 @@ class ReplyTrackerService: else: logger.error(f"标记评论时数据验证失败: feed_id={feed_id}, comment_id={comment_id}") - def get_replied_comments(self, feed_id: str) -> Set[str]: + def get_replied_comments(self, feed_id: str) -> set[str]: """ 获取指定说说下所有已回复的评论ID @@ -234,7 +234,7 @@ class ReplyTrackerService: return {str(comment_id) for comment_id in self.replied_comments[feed_id].keys()} return set() - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: """ 获取回复记录统计信息 diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index 770ced8e6..7cf0e7c93 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 定时任务服务 根据日程表定时发送说说。 @@ -8,13 +7,14 @@ import asyncio import datetime import random import traceback -from typing import Callable +from collections.abc import Callable +from sqlalchemy import select + +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus from src.common.logger import get_logger from src.schedule.schedule_manager import schedule_manager -from src.common.database.sqlalchemy_database_api import get_db_session -from sqlalchemy import select -from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus from .qzone_service import QZoneService diff --git a/src/plugins/built_in/maizone_refactored/utils/history_utils.py b/src/plugins/built_in/maizone_refactored/utils/history_utils.py index 19b3e7baa..6f51a6c0d 100644 --- a/src/plugins/built_in/maizone_refactored/utils/history_utils.py +++ b/src/plugins/built_in/maizone_refactored/utils/history_utils.py @@ -1,15 +1,15 @@ -# -*- coding: utf-8 -*- """ 历史记录工具模块 提供用于获取QQ空间发送历史的功能。 """ -import orjson import os from pathlib import Path -from typing import Dict, Any, Optional, List +from typing import Any +import orjson import requests + from src.common.logger import get_logger logger = get_logger("MaiZone.HistoryUtils") @@ -26,11 +26,11 @@ class _CookieManager: return str(cookie_dir / f"cookies-{uin}.json") @staticmethod - def load_cookies(qq_account: str) -> Optional[Dict[str, str]]: + def load_cookies(qq_account: str) -> dict[str, str] | None: cookie_file = _CookieManager.get_cookie_file_path(qq_account) if os.path.exists(cookie_file): try: - with open(cookie_file, "r", encoding="utf-8") as f: + with open(cookie_file, encoding="utf-8") as f: return orjson.loads(f.read()) except Exception as e: logger.error(f"加载Cookie文件失败: {e}") @@ -42,7 +42,7 @@ class _SimpleQZoneAPI: LIST_URL = "https://user.qzone.qq.com/proxy/domain/taotao.qq.com/cgi-bin/emotion_cgi_msglist_v6" - def __init__(self, cookies_dict: Optional[Dict[str, str]] = None): + def __init__(self, cookies_dict: dict[str, str] | None = None): self.cookies = cookies_dict or {} self.gtk2 = "" p_skey = self.cookies.get("p_skey") or self.cookies.get("p_skey".upper()) @@ -55,7 +55,7 @@ class _SimpleQZoneAPI: hash_val += (hash_val << 5) + ord(char) return str(hash_val & 2147483647) - def get_feed_list(self, target_qq: str, num: int) -> List[Dict[str, Any]]: + def get_feed_list(self, target_qq: str, num: int) -> list[dict[str, Any]]: try: params = { "g_tk": self.gtk2, diff --git a/src/plugins/built_in/permission_management/plugin.py b/src/plugins/built_in/permission_management/plugin.py index fd8612348..d85ca8dd5 100644 --- a/src/plugins/built_in/permission_management/plugin.py +++ b/src/plugins/built_in/permission_management/plugin.py @@ -6,19 +6,17 @@ """ import re -from typing import List, Optional, Tuple, Type +from src.plugin_system.apis.logging_api import get_logger +from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.apis.plugin_register_api import register_plugin from src.plugin_system.base.base_plugin import BasePlugin -from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.base.command_args import CommandArgs -from src.plugin_system.apis.permission_api import permission_api -from src.plugin_system.apis.logging_api import get_logger -from src.plugin_system.base.component_types import PlusCommandInfo, ChatType +from src.plugin_system.base.component_types import ChatType, PlusCommandInfo from src.plugin_system.base.config_types import ConfigField +from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.utils.permission_decorators import require_permission - logger = get_logger("Permission") @@ -44,7 +42,7 @@ class PermissionCommand(PlusCommand): "plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True ) - async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: """执行权限管理命令""" if args.is_empty: await self._show_help() @@ -114,7 +112,7 @@ class PermissionCommand(PlusCommand): await self.send_text(help_text) @staticmethod - def _parse_user_mention(mention: str) -> Optional[str]: + def _parse_user_mention(mention: str) -> str | None: """解析用户提及,提取QQ号 支持的格式: @@ -134,7 +132,7 @@ class PermissionCommand(PlusCommand): return None @staticmethod - def parse_user_from_args(args: CommandArgs, index: int = 0) -> Optional[str]: + def parse_user_from_args(args: CommandArgs, index: int = 0) -> str | None: """从CommandArgs中解析用户ID Args: @@ -166,7 +164,7 @@ class PermissionCommand(PlusCommand): return None @require_permission("plugin.permission.manage", "❌ 你没有权限管理的权限") - async def _grant_permission(self, chat_stream, args: List[str]): + async def _grant_permission(self, chat_stream, args: list[str]): """授权用户权限""" if len(args) < 2: await self.send_text("❌ 用法: /permission grant <@用户|QQ号> <权限节点>") @@ -189,7 +187,7 @@ class PermissionCommand(PlusCommand): await self.send_text("❌ 授权失败,请检查权限节点是否存在") @require_permission("plugin.permission.manage", "❌ 你没有权限管理的权限") - async def _revoke_permission(self, chat_stream, args: List[str]): + async def _revoke_permission(self, chat_stream, args: list[str]): """撤销用户权限""" if len(args) < 2: await self.send_text("❌ 用法: /permission revoke <@用户|QQ号> <权限节点>") @@ -212,7 +210,7 @@ class PermissionCommand(PlusCommand): await self.send_text("❌ 撤销失败,请检查权限节点是否存在") @require_permission("plugin.permission.view", "❌ 你没有查看权限的权限") - async def _list_permissions(self, chat_stream, args: List[str]): + async def _list_permissions(self, chat_stream, args: list[str]): """列出用户权限""" target_user_id = None @@ -244,7 +242,7 @@ class PermissionCommand(PlusCommand): await self.send_text(response) @require_permission("plugin.permission.view", "❌ 你没有查看权限的权限") - async def _check_permission(self, chat_stream, args: List[str]): + async def _check_permission(self, chat_stream, args: list[str]): """检查用户权限""" if len(args) < 2: await self.send_text("❌ 用法: /permission check <@用户|QQ号> <权限节点>") @@ -273,7 +271,7 @@ class PermissionCommand(PlusCommand): await self.send_text(response) @require_permission("plugin.permission.view", "❌ 你没有查看权限的权限") - async def _list_nodes(self, chat_stream, args: List[str]): + async def _list_nodes(self, chat_stream, args: list[str]): """列出权限节点""" plugin_name = args[0] if args else None @@ -388,6 +386,6 @@ class PermissionManagerPlugin(BasePlugin): } } - def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]: + def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type[PlusCommand]]]: """返回插件的PlusCommand组件""" return [(PermissionCommand.get_plus_command_info(), PermissionCommand)] diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index 5061cf496..56199611e 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -1,19 +1,18 @@ import asyncio -from typing import List, Tuple, Type from src.plugin_system import ( BasePlugin, - ConfigField, - register_plugin, - plugin_manage_api, - component_manage_api, ComponentInfo, ComponentType, + ConfigField, + component_manage_api, + plugin_manage_api, + register_plugin, ) -from src.plugin_system.base.plus_command import PlusCommand -from src.plugin_system.base.command_args import CommandArgs -from src.plugin_system.base.component_types import PlusCommandInfo, ChatType from src.plugin_system.apis.permission_api import permission_api +from src.plugin_system.base.command_args import CommandArgs +from src.plugin_system.base.component_types import ChatType, PlusCommandInfo +from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.utils.permission_decorators import require_permission @@ -31,7 +30,7 @@ class ManagementCommand(PlusCommand): super().__init__(*args, **kwargs) @require_permission("plugin.management.admin", "❌ 你没有插件管理的权限") - async def execute(self, args: CommandArgs) -> Tuple[bool, str, bool]: + async def execute(self, args: CommandArgs) -> tuple[bool, str, bool]: """执行插件管理命令""" if args.is_empty: await self._show_help("all") @@ -51,7 +50,7 @@ class ManagementCommand(PlusCommand): await self.send_text(f"❌ 未知的子命令: {subcommand}\n使用 /pm help 查看帮助") return True, "未知子命令", True - async def _handle_plugin_commands(self, args: List[str]) -> Tuple[bool, str, bool]: + async def _handle_plugin_commands(self, args: list[str]) -> tuple[bool, str, bool]: """处理插件相关命令""" if not args: await self._show_help("plugin") @@ -83,7 +82,7 @@ class ManagementCommand(PlusCommand): return True, "插件命令执行完成", True - async def _handle_component_commands(self, args: List[str]) -> Tuple[bool, str, bool]: + async def _handle_component_commands(self, args: list[str]) -> tuple[bool, str, bool]: """处理组件相关命令""" if not args: await self._show_help("component") @@ -258,7 +257,7 @@ class ManagementCommand(PlusCommand): else: await self.send_text(f"❌ 插件强制重载失败: `{plugin_name}`") except Exception as e: - await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}") + await self.send_text(f"❌ 强制重载过程中发生错误: {e!s}") async def _add_dir(self, dir_path: str): """添加插件目录""" @@ -271,17 +270,17 @@ class ManagementCommand(PlusCommand): await self.send_text(f"❌ 插件目录添加失败: `{dir_path}`") @staticmethod - def _fetch_all_registered_components() -> List[ComponentInfo]: + def _fetch_all_registered_components() -> list[ComponentInfo]: all_plugin_info = component_manage_api.get_all_plugin_info() if not all_plugin_info: return [] - components_info: List[ComponentInfo] = [] + components_info: list[ComponentInfo] = [] for plugin_info in all_plugin_info.values(): components_info.extend(plugin_info.components) return components_info - def _fetch_locally_disabled_components(self) -> List[str]: + def _fetch_locally_disabled_components(self) -> list[str]: """获取本地禁用的组件列表""" stream_id = self.message.chat_stream.stream_id locally_disabled_components_actions = component_manage_api.get_locally_disabled_components( @@ -509,7 +508,7 @@ class PluginManagementPlugin(BasePlugin): False, ) - def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]: + def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type[PlusCommand]]]: """返回插件的PlusCommand组件""" components = [] if self.get_config("plugin.enabled", True): diff --git a/src/plugins/built_in/proactive_thinker/plugin.py b/src/plugins/built_in/proactive_thinker/plugin.py index 5e55e9101..e74c35c8b 100644 --- a/src/plugins/built_in/proactive_thinker/plugin.py +++ b/src/plugins/built_in/proactive_thinker/plugin.py @@ -1,14 +1,13 @@ -from typing import List, Tuple, Type - from src.common.logger import get_logger -from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system import ( + BaseEventHandler, BasePlugin, ConfigField, - register_plugin, EventHandlerInfo, - BaseEventHandler, + register_plugin, ) +from src.plugin_system.base.base_plugin import BasePlugin + from .proacive_thinker_event import ProactiveThinkerEventHandler logger = get_logger(__name__) @@ -33,9 +32,9 @@ class ProactiveThinkerPlugin(BasePlugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def get_plugin_components(self) -> List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]]: + def get_plugin_components(self) -> list[tuple[EventHandlerInfo, type[BaseEventHandler]]]: """返回插件的EventHandler组件""" - components: List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]] = [ + components: list[tuple[EventHandlerInfo, type[BaseEventHandler]]] = [ (ProactiveThinkerEventHandler.get_handler_info(), ProactiveThinkerEventHandler) ] return components diff --git a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py index 5ad560243..be818f037 100644 --- a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py +++ b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py @@ -2,17 +2,17 @@ import asyncio import random import time from datetime import datetime -from typing import List, Union from maim_message import UserInfo from src.chat.message_receive.chat_stream import get_chat_manager from src.common.logger import get_logger from src.config.config import global_config -from src.manager.async_task_manager import async_task_manager, AsyncTask -from src.plugin_system import EventType, BaseEventHandler +from src.manager.async_task_manager import AsyncTask, async_task_manager +from src.plugin_system import BaseEventHandler, EventType from src.plugin_system.apis import chat_api, person_api from src.plugin_system.base.base_event import HandlerResult + from .proactive_thinker_executor import ProactiveThinkerExecutor logger = get_logger(__name__) @@ -199,7 +199,7 @@ class ProactiveThinkerEventHandler(BaseEventHandler): handler_name: str = "proactive_thinker_on_start" handler_description: str = "主动思考插件的启动事件处理器" - init_subscribe: List[Union[EventType, str]] = [EventType.ON_START] + init_subscribe: list[EventType | str] = [EventType.ON_START] async def execute(self, kwargs: dict | None) -> "HandlerResult": """在机器人启动时执行,根据配置决定是否启动后台任务。""" diff --git a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py index ab3631450..2accabe5e 100644 --- a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py +++ b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py @@ -1,20 +1,21 @@ -import orjson -from typing import Optional, Dict, Any from datetime import datetime +from typing import Any + +import orjson from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.person_info.person_info import get_person_info_manager from src.plugin_system.apis import ( chat_api, + database_api, + generator_api, + llm_api, + message_api, person_api, schedule_api, send_api, - llm_api, - message_api, - generator_api, - database_api, ) -from src.config.config import global_config, model_config -from src.person_info.person_info import get_person_info_manager logger = get_logger(__name__) @@ -101,7 +102,7 @@ class ProactiveThinkerExecutor: logger.error(f"解析 stream_id ({stream_id}) 或获取 stream 失败: {e}") return None - async def _gather_context(self, stream_id: str) -> Optional[Dict[str, Any]]: + async def _gather_context(self, stream_id: str) -> dict[str, Any] | None: """ 收集构建提示词所需的所有上下文信息 """ @@ -165,7 +166,7 @@ class ProactiveThinkerExecutor: "current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), } - async def _make_decision(self, context: Dict[str, Any], start_mode: str) -> Optional[Dict[str, Any]]: + async def _make_decision(self, context: dict[str, Any], start_mode: str) -> dict[str, Any] | None: """ 决策模块:判断是否应该主动发起对话,以及聊什么话题 """ @@ -234,7 +235,7 @@ class ProactiveThinkerExecutor: logger.error(f"决策LLM返回的JSON格式无效: {response}") return {"should_reply": False, "reason": "决策模型返回格式错误"} - def _build_plan_prompt(self, context: Dict[str, Any], start_mode: str, topic: str, reason: str) -> str: + def _build_plan_prompt(self, context: dict[str, Any], start_mode: str, topic: str, reason: str) -> str: """ 根据启动模式和决策话题,构建最终的规划提示词 """ diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index a26879da7..71bb83767 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -1,24 +1,25 @@ -import re -from typing import List, Tuple, Type, Optional - -from src.plugin_system import ( - BasePlugin, - register_plugin, - BaseAction, - ComponentInfo, - ActionActivationType, - ConfigField, -) -from src.common.logger import get_logger -from .qq_emoji_list import qq_face -from src.plugin_system.base.component_types import ChatType -from src.person_info.person_info import get_person_info_manager -from dateutil.parser import parse as parse_datetime -from src.manager.async_task_manager import AsyncTask, async_task_manager -from src.plugin_system.apis import send_api, llm_api, generator_api -from src.chat.message_receive.chat_stream import ChatStream import asyncio import datetime +import re + +from dateutil.parser import parse as parse_datetime + +from src.chat.message_receive.chat_stream import ChatStream +from src.common.logger import get_logger +from src.manager.async_task_manager import AsyncTask, async_task_manager +from src.person_info.person_info import get_person_info_manager +from src.plugin_system import ( + ActionActivationType, + BaseAction, + BasePlugin, + ComponentInfo, + ConfigField, + register_plugin, +) +from src.plugin_system.apis import generator_api, llm_api, send_api +from src.plugin_system.base.component_types import ChatType + +from .qq_emoji_list import qq_face logger = get_logger("set_emoji_like_plugin") @@ -30,7 +31,7 @@ class ReminderTask(AsyncTask): self, delay: float, stream_id: str, - group_id: Optional[str], + group_id: str | None, is_group: bool, target_user_id: str, target_user_name: str, @@ -162,7 +163,7 @@ class PokeAction(BaseAction): """ associated_types = ["text"] - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """执行戳一戳的动作""" user_id = self.action_data.get("user_id") user_name = self.action_data.get("user_name") @@ -242,7 +243,7 @@ class SetEmojiLikeAction(BaseAction): if match: emoji_options.append(match.group(1)) - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """执行设置表情回应的动作""" message_id = None set_like = self.action_data.get("set", True) @@ -360,7 +361,7 @@ class RemindAction(BaseAction): "例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'", ] - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """执行设置提醒的动作""" user_name = self.action_data.get("user_name") remind_time_str = self.action_data.get("remind_time") @@ -386,14 +387,14 @@ class RemindAction(BaseAction): # 优先尝试直接解析 try: target_time = parse_datetime(remind_time_str, fuzzy=True) - except Exception: + except Exception as e: # 如果直接解析失败,调用 LLM 进行转换 logger.info(f"[ReminderPlugin] 直接解析时间 '{remind_time_str}' 失败,尝试使用 LLM 进行转换...") # 获取所有可用的模型配置 available_models = llm_api.get_available_models() if "utils_small" not in available_models: - raise ValueError("未找到 'utils_small' 模型配置,无法解析时间") + raise ValueError("未找到 'utils_small' 模型配置,无法解析时间") from e # 明确使用 'planner' 模型 model_to_use = available_models["utils_small"] @@ -419,7 +420,7 @@ class RemindAction(BaseAction): ) if not success or not response: - raise ValueError(f"LLM未能返回有效的时间字符串: {response}") + raise ValueError(f"LLM未能返回有效的时间字符串: {response}") from e converted_time_str = response.strip() logger.info(f"[ReminderPlugin] LLM 转换结果: '{converted_time_str}'") @@ -533,8 +534,8 @@ class SetEmojiLikePlugin(BasePlugin): # 插件基本信息 plugin_name: str = "social_toolkit_plugin" # 内部标识符 enable_plugin: bool = True - dependencies: List[str] = [] # 插件依赖列表 - python_dependencies: List[str] = [] # Python包依赖列表,现在使用内置API + dependencies: list[str] = [] # 插件依赖列表 + python_dependencies: list[str] = [] # Python包依赖列表,现在使用内置API config_file_name: str = "config.toml" # 配置文件名 # 配置节描述 @@ -555,7 +556,7 @@ class SetEmojiLikePlugin(BasePlugin): }, } - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: enable_components = [] if self.get_config("components.action_set_emoji_like"): enable_components.append((SetEmojiLikeAction.get_action_info(), SetEmojiLikeAction)) diff --git a/src/plugins/built_in/tts_plugin/plugin.py b/src/plugins/built_in/tts_plugin/plugin.py index fc625c093..8d1327a4f 100644 --- a/src/plugins/built_in/tts_plugin/plugin.py +++ b/src/plugins/built_in/tts_plugin/plugin.py @@ -1,10 +1,9 @@ +from src.common.logger import get_logger from src.plugin_system.apis.plugin_register_api import register_plugin +from src.plugin_system.base.base_action import ActionActivationType, BaseAction, ChatMode from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system.base.component_types import ComponentInfo -from src.common.logger import get_logger -from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode from src.plugin_system.base.config_types import ConfigField -from typing import Tuple, List, Type logger = get_logger("tts") @@ -44,7 +43,7 @@ class TTSAction(BaseAction): # 关联类型 associated_types = ["tts_text"] - async def execute(self) -> Tuple[bool, str]: + async def execute(self) -> tuple[bool, str]: """处理TTS文本转语音动作""" logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}") @@ -140,7 +139,7 @@ class TTSPlugin(BasePlugin): }, } - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: """返回插件包含的组件列表""" # 从配置获取组件启用状态 diff --git a/src/plugins/built_in/web_search_tool/engines/base.py b/src/plugins/built_in/web_search_tool/engines/base.py index 30d20a540..4fd2c452a 100644 --- a/src/plugins/built_in/web_search_tool/engines/base.py +++ b/src/plugins/built_in/web_search_tool/engines/base.py @@ -3,7 +3,7 @@ Base search engine interface """ from abc import ABC, abstractmethod -from typing import Dict, List, Any +from typing import Any class BaseSearchEngine(ABC): @@ -12,7 +12,7 @@ class BaseSearchEngine(ABC): """ @abstractmethod - async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: + async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]: """ 执行搜索 diff --git a/src/plugins/built_in/web_search_tool/engines/bing_engine.py b/src/plugins/built_in/web_search_tool/engines/bing_engine.py index ece747fbd..46431bff1 100644 --- a/src/plugins/built_in/web_search_tool/engines/bing_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/bing_engine.py @@ -6,11 +6,13 @@ import asyncio import functools import random import traceback -from typing import Dict, List, Any +from typing import Any + import requests from bs4 import BeautifulSoup from src.common.logger import get_logger + from .base import BaseSearchEngine logger = get_logger("bing_engine") @@ -68,7 +70,7 @@ class BingSearchEngine(BaseSearchEngine): """检查Bing搜索引擎是否可用""" return True # Bing是免费搜索引擎,总是可用 - async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: + async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]: """执行Bing搜索""" query = args["query"] num_results = args.get("num_results", 3) @@ -83,7 +85,7 @@ class BingSearchEngine(BaseSearchEngine): logger.error(f"Bing 搜索失败: {e}") return [] - def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]: + def _search_sync(self, keyword: str, num_results: int, time_range: str) -> list[dict[str, Any]]: """同步执行Bing搜索""" if not keyword: return [] @@ -113,7 +115,7 @@ class BingSearchEngine(BaseSearchEngine): return list_result[:num_results] if len(list_result) > num_results else list_result @staticmethod - def _parse_html(url: str) -> List[Dict[str, Any]]: + def _parse_html(url: str) -> list[dict[str, Any]]: """解析处理结果""" try: logger.debug(f"访问Bing搜索URL: {url}") @@ -141,11 +143,11 @@ class BingSearchEngine(BaseSearchEngine): try: res = session.get(url=url, timeout=(3.05, 6), verify=True, allow_redirects=True) except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e: - logger.warning(f"第一次请求超时,正在重试: {str(e)}") + logger.warning(f"第一次请求超时,正在重试: {e!s}") try: res = session.get(url=url, timeout=(5, 10), verify=False) except Exception as e2: - logger.error(f"第二次请求也失败: {str(e2)}") + logger.error(f"第二次请求也失败: {e2!s}") return [] res.encoding = "utf-8" @@ -175,7 +177,7 @@ class BingSearchEngine(BaseSearchEngine): try: root = BeautifulSoup(res.text, "html.parser") except Exception as e: - logger.error(f"HTML解析失败: {str(e)}") + logger.error(f"HTML解析失败: {e!s}") return [] list_data = [] @@ -262,6 +264,6 @@ class BingSearchEngine(BaseSearchEngine): return list_data except Exception as e: - logger.error(f"解析Bing页面时出错: {str(e)}") + logger.error(f"解析Bing页面时出错: {e!s}") logger.debug(traceback.format_exc()) return [] diff --git a/src/plugins/built_in/web_search_tool/engines/ddg_engine.py b/src/plugins/built_in/web_search_tool/engines/ddg_engine.py index 29f03b31a..eb73f6bcd 100644 --- a/src/plugins/built_in/web_search_tool/engines/ddg_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/ddg_engine.py @@ -2,10 +2,12 @@ DuckDuckGo search engine implementation """ -from typing import Dict, List, Any +from typing import Any + from asyncddgs import aDDGS from src.common.logger import get_logger + from .base import BaseSearchEngine logger = get_logger("ddg_engine") @@ -20,7 +22,7 @@ class DDGSearchEngine(BaseSearchEngine): """检查DuckDuckGo搜索引擎是否可用""" return True # DuckDuckGo不需要API密钥,总是可用 - async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: + async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]: """执行DuckDuckGo搜索""" query = args["query"] num_results = args.get("num_results", 3) diff --git a/src/plugins/built_in/web_search_tool/engines/exa_engine.py b/src/plugins/built_in/web_search_tool/engines/exa_engine.py index 269e32bd1..37655eb53 100644 --- a/src/plugins/built_in/web_search_tool/engines/exa_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/exa_engine.py @@ -5,13 +5,15 @@ Exa search engine implementation import asyncio import functools from datetime import datetime, timedelta -from typing import Dict, List, Any +from typing import Any + from exa_py import Exa from src.common.logger import get_logger from src.plugin_system.apis import config_api -from .base import BaseSearchEngine + from ..utils.api_key_manager import create_api_key_manager_from_config +from .base import BaseSearchEngine logger = get_logger("exa_engine") @@ -36,7 +38,7 @@ class ExaSearchEngine(BaseSearchEngine): """检查Exa搜索引擎是否可用""" return self.api_manager.is_available() - async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: + async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]: """执行Exa搜索""" if not self.is_available(): return [] diff --git a/src/plugins/built_in/web_search_tool/engines/tavily_engine.py b/src/plugins/built_in/web_search_tool/engines/tavily_engine.py index 2f929284f..acbe23d81 100644 --- a/src/plugins/built_in/web_search_tool/engines/tavily_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/tavily_engine.py @@ -4,13 +4,15 @@ Tavily search engine implementation import asyncio import functools -from typing import Dict, List, Any +from typing import Any + from tavily import TavilyClient from src.common.logger import get_logger from src.plugin_system.apis import config_api -from .base import BaseSearchEngine + from ..utils.api_key_manager import create_api_key_manager_from_config +from .base import BaseSearchEngine logger = get_logger("tavily_engine") @@ -37,7 +39,7 @@ class TavilySearchEngine(BaseSearchEngine): """检查Tavily搜索引擎是否可用""" return self.api_manager.is_available() - async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: + async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]: """执行Tavily搜索""" if not self.is_available(): return [] diff --git a/src/plugins/built_in/web_search_tool/plugin.py b/src/plugins/built_in/web_search_tool/plugin.py index fadc02a88..2b85104bc 100644 --- a/src/plugins/built_in/web_search_tool/plugin.py +++ b/src/plugins/built_in/web_search_tool/plugin.py @@ -4,14 +4,12 @@ Web Search Tool Plugin 一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。 """ -from typing import List, Tuple, Type - -from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency -from src.plugin_system.apis import config_api from src.common.logger import get_logger +from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin +from src.plugin_system.apis import config_api -from .tools.web_search import WebSurfingTool from .tools.url_parser import URLParserTool +from .tools.web_search import WebSurfingTool logger = get_logger("web_search_plugin") @@ -31,7 +29,7 @@ class WEBSEARCHPLUGIN(BasePlugin): # 插件基本信息 plugin_name: str = "web_search_tool" # 内部标识符 enable_plugin: bool = True - dependencies: List[str] = [] # 插件依赖列表 + dependencies: list[str] = [] # 插件依赖列表 def __init__(self, *args, **kwargs): """初始化插件,立即加载所有搜索引擎""" @@ -40,10 +38,10 @@ class WEBSEARCHPLUGIN(BasePlugin): # 立即初始化所有搜索引擎,触发API密钥管理器的日志输出 logger.info("🚀 正在初始化所有搜索引擎...") try: + from .engines.bing_engine import BingSearchEngine + from .engines.ddg_engine import DDGSearchEngine from .engines.exa_engine import ExaSearchEngine from .engines.tavily_engine import TavilySearchEngine - from .engines.ddg_engine import DDGSearchEngine - from .engines.bing_engine import BingSearchEngine # 实例化所有搜索引擎,这会触发API密钥管理器的初始化 exa_engine = ExaSearchEngine() @@ -71,7 +69,7 @@ class WEBSEARCHPLUGIN(BasePlugin): logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True) # Python包依赖列表 - python_dependencies: List[PythonDependency] = [ + python_dependencies: list[PythonDependency] = [ PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False), PythonDependency( package_name="exa_py", @@ -119,7 +117,7 @@ class WEBSEARCHPLUGIN(BasePlugin): }, } - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]: """ 获取插件组件列表 diff --git a/src/plugins/built_in/web_search_tool/tools/url_parser.py b/src/plugins/built_in/web_search_tool/tools/url_parser.py index 25338c35c..6e9bf5a03 100644 --- a/src/plugins/built_in/web_search_tool/tools/url_parser.py +++ b/src/plugins/built_in/web_search_tool/tools/url_parser.py @@ -4,19 +4,20 @@ URL parser tool implementation import asyncio import functools -from typing import Any, Dict -from exa_py import Exa +from typing import Any + import httpx from bs4 import BeautifulSoup +from exa_py import Exa +from src.common.cache_manager import tool_cache from src.common.logger import get_logger from src.plugin_system import BaseTool, ToolParamType, llm_api from src.plugin_system.apis import config_api -from src.common.cache_manager import tool_cache +from ..utils.api_key_manager import create_api_key_manager_from_config from ..utils.formatters import format_url_parse_results from ..utils.url_utils import parse_urls_from_input, validate_urls -from ..utils.api_key_manager import create_api_key_manager_from_config logger = get_logger("url_parser_tool") @@ -50,7 +51,7 @@ class URLParserTool(BaseTool): exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser" ) - async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]: + async def _local_parse_and_summarize(self, url: str) -> dict[str, Any]: """ 使用本地库(httpx, BeautifulSoup)解析URL,并调用LLM进行总结。 """ @@ -124,9 +125,9 @@ class URLParserTool(BaseTool): return {"error": f"请求失败,状态码: {e.response.status_code}"} except Exception as e: logger.error(f"本地解析或总结URL '{url}' 时发生未知异常: {e}", exc_info=True) - return {"error": f"发生未知错误: {str(e)}"} + return {"error": f"发生未知错误: {e!s}"} - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """ 执行URL内容提取和总结。优先使用Exa,失败后尝试本地解析。 """ diff --git a/src/plugins/built_in/web_search_tool/tools/web_search.py b/src/plugins/built_in/web_search_tool/tools/web_search.py index 3e4039cb8..9dcafc9a5 100644 --- a/src/plugins/built_in/web_search_tool/tools/web_search.py +++ b/src/plugins/built_in/web_search_tool/tools/web_search.py @@ -3,18 +3,18 @@ Web search tool implementation """ import asyncio -from typing import Any, Dict, List +from typing import Any +from src.common.cache_manager import tool_cache from src.common.logger import get_logger from src.plugin_system import BaseTool, ToolParamType from src.plugin_system.apis import config_api -from src.common.cache_manager import tool_cache +from ..engines.bing_engine import BingSearchEngine +from ..engines.ddg_engine import DDGSearchEngine from ..engines.exa_engine import ExaSearchEngine from ..engines.tavily_engine import TavilySearchEngine -from ..engines.ddg_engine import DDGSearchEngine -from ..engines.bing_engine import BingSearchEngine -from ..utils.formatters import format_search_results, deduplicate_results +from ..utils.formatters import deduplicate_results, format_search_results logger = get_logger("web_search_tool") @@ -51,7 +51,7 @@ class WebSurfingTool(BaseTool): "bing": BingSearchEngine(), } - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: query = function_args.get("query") if not query: return {"error": "搜索查询不能为空。"} @@ -88,8 +88,8 @@ class WebSurfingTool(BaseTool): return result async def _execute_parallel_search( - self, function_args: Dict[str, Any], enabled_engines: List[str] - ) -> Dict[str, Any]: + self, function_args: dict[str, Any], enabled_engines: list[str] + ) -> dict[str, Any]: """并行搜索策略:同时使用所有启用的搜索引擎""" search_tasks = [] @@ -124,11 +124,11 @@ class WebSurfingTool(BaseTool): except Exception as e: logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True) - return {"error": f"执行网络搜索时发生严重错误: {str(e)}"} + return {"error": f"执行网络搜索时发生严重错误: {e!s}"} async def _execute_fallback_search( - self, function_args: Dict[str, Any], enabled_engines: List[str] - ) -> Dict[str, Any]: + self, function_args: dict[str, Any], enabled_engines: list[str] + ) -> dict[str, Any]: """回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个""" for engine_name in enabled_engines: engine = self.engines.get(engine_name) @@ -154,7 +154,7 @@ class WebSurfingTool(BaseTool): return {"error": "所有搜索引擎都失败了。"} - async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: + async def _execute_single_search(self, function_args: dict[str, Any], enabled_engines: list[str]) -> dict[str, Any]: """单一搜索策略:只使用第一个可用的搜索引擎""" for engine_name in enabled_engines: engine = self.engines.get(engine_name) @@ -174,6 +174,6 @@ class WebSurfingTool(BaseTool): except Exception as e: logger.error(f"{engine_name} 搜索失败: {e}") - return {"error": f"{engine_name} 搜索失败: {str(e)}"} + return {"error": f"{engine_name} 搜索失败: {e!s}"} return {"error": "没有可用的搜索引擎。"} diff --git a/src/plugins/built_in/web_search_tool/utils/api_key_manager.py b/src/plugins/built_in/web_search_tool/utils/api_key_manager.py index 07757cdb1..e7aba03ce 100644 --- a/src/plugins/built_in/web_search_tool/utils/api_key_manager.py +++ b/src/plugins/built_in/web_search_tool/utils/api_key_manager.py @@ -3,7 +3,9 @@ API密钥管理器,提供轮询机制 """ import itertools -from typing import List, Optional, TypeVar, Generic, Callable +from collections.abc import Callable +from typing import Generic, TypeVar + from src.common.logger import get_logger logger = get_logger("api_key_manager") @@ -16,7 +18,7 @@ class APIKeyManager(Generic[T]): API密钥管理器,支持轮询机制 """ - def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"): + def __init__(self, api_keys: list[str], client_factory: Callable[[str], T], service_name: str = "Unknown"): """ 初始化API密钥管理器 @@ -26,8 +28,8 @@ class APIKeyManager(Generic[T]): service_name: 服务名称,用于日志记录 """ self.service_name = service_name - self.clients: List[T] = [] - self.client_cycle: Optional[itertools.cycle] = None + self.clients: list[T] = [] + self.client_cycle: itertools.cycle | None = None if api_keys: # 过滤有效的API密钥,排除None、空字符串、"None"字符串等 @@ -54,7 +56,7 @@ class APIKeyManager(Generic[T]): """检查是否有可用的客户端""" return bool(self.clients and self.client_cycle) - def get_next_client(self) -> Optional[T]: + def get_next_client(self) -> T | None: """获取下一个客户端(轮询)""" if not self.is_available(): return None @@ -66,7 +68,7 @@ class APIKeyManager(Generic[T]): def create_api_key_manager_from_config( - config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str + config_keys: list[str] | None, client_factory: Callable[[str], T], service_name: str ) -> APIKeyManager[T]: """ 从配置创建API密钥管理器的便捷函数 diff --git a/src/plugins/built_in/web_search_tool/utils/formatters.py b/src/plugins/built_in/web_search_tool/utils/formatters.py index df1e4ea18..6173b0bca 100644 --- a/src/plugins/built_in/web_search_tool/utils/formatters.py +++ b/src/plugins/built_in/web_search_tool/utils/formatters.py @@ -2,10 +2,10 @@ Formatters for web search results """ -from typing import List, Dict, Any +from typing import Any -def format_search_results(results: List[Dict[str, Any]]) -> str: +def format_search_results(results: list[dict[str, Any]]) -> str: """ 格式化搜索结果为字符串 """ @@ -26,7 +26,7 @@ def format_search_results(results: List[Dict[str, Any]]) -> str: return formatted_string -def format_url_parse_results(results: List[Dict[str, Any]]) -> str: +def format_url_parse_results(results: list[dict[str, Any]]) -> str: """ 将成功解析的URL结果列表格式化为一段简洁的文本。 """ @@ -45,7 +45,7 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str: return "\n---\n".join(formatted_parts) -def deduplicate_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +def deduplicate_results(results: list[dict[str, Any]]) -> list[dict[str, Any]]: """ 根据URL去重搜索结果 """ diff --git a/src/plugins/built_in/web_search_tool/utils/url_utils.py b/src/plugins/built_in/web_search_tool/utils/url_utils.py index 5bdde0a55..f96d4a04a 100644 --- a/src/plugins/built_in/web_search_tool/utils/url_utils.py +++ b/src/plugins/built_in/web_search_tool/utils/url_utils.py @@ -3,10 +3,9 @@ URL processing utilities """ import re -from typing import List -def parse_urls_from_input(urls_input) -> List[str]: +def parse_urls_from_input(urls_input) -> list[str]: """ 从输入中解析URL列表 """ @@ -29,7 +28,7 @@ def parse_urls_from_input(urls_input) -> List[str]: return urls -def validate_urls(urls: List[str]) -> List[str]: +def validate_urls(urls: list[str]) -> list[str]: """ 验证URL格式,返回有效的URL列表 """ diff --git a/src/schedule/database.py b/src/schedule/database.py index b33bfb953..ccaf92b7f 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -1,7 +1,8 @@ # mmc/src/schedule/database.py -from typing import List -from sqlalchemy import select, func, update, delete + +from sqlalchemy import delete, func, select, update + from src.common.database.sqlalchemy_models import MonthlyPlan, get_db_session from src.common.logger import get_logger from src.config.config import global_config @@ -9,7 +10,7 @@ from src.config.config import global_config logger = get_logger("schedule_database") -async def add_new_plans(plans: List[str], month: str): +async def add_new_plans(plans: list[str], month: str): """ 批量添加新生成的月度计划到数据库,并确保不超过上限。 @@ -55,7 +56,7 @@ async def add_new_plans(plans: List[str], month: str): raise -async def get_active_plans_for_month(month: str) -> List[MonthlyPlan]: +async def get_active_plans_for_month(month: str) -> list[MonthlyPlan]: """ 获取指定月份所有状态为 'active' 的计划。 @@ -75,7 +76,7 @@ async def get_active_plans_for_month(month: str) -> List[MonthlyPlan]: return [] -async def mark_plans_completed(plan_ids: List[int]): +async def mark_plans_completed(plan_ids: list[int]): """ 将指定ID的计划标记为已完成。 @@ -103,7 +104,7 @@ async def mark_plans_completed(plan_ids: List[int]): raise -async def delete_plans_by_ids(plan_ids: List[int]): +async def delete_plans_by_ids(plan_ids: list[int]): """ 根据ID列表从数据库中物理删除月度计划。 @@ -134,7 +135,7 @@ async def delete_plans_by_ids(plan_ids: List[int]): raise -async def update_plan_usage(plan_ids: List[int], used_date: str): +async def update_plan_usage(plan_ids: list[int], used_date: str): """ 更新计划的使用统计信息。 @@ -182,7 +183,7 @@ async def update_plan_usage(plan_ids: List[int], used_date: str): raise -async def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_days: int = 7) -> List[MonthlyPlan]: +async def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_days: int = 7) -> list[MonthlyPlan]: """ 智能抽取月度计划用于每日日程生成。 @@ -255,7 +256,7 @@ async def archive_active_plans_for_month(month: str): raise -async def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]: +async def get_archived_plans_for_month(month: str) -> list[MonthlyPlan]: """ 获取指定月份所有状态为 'archived' 的计划。 用于生成下个月计划时的参考。 diff --git a/src/schedule/llm_generator.py b/src/schedule/llm_generator.py index d3ec56bb6..b8f4c51bd 100644 --- a/src/schedule/llm_generator.py +++ b/src/schedule/llm_generator.py @@ -1,16 +1,18 @@ # mmc/src/schedule/llm_generator.py import asyncio -import orjson from datetime import datetime -from typing import List, Optional, Dict, Any -from lunar_python import Lunar +from typing import Any + +import orjson from json_repair import repair_json +from lunar_python import Lunar from src.common.database.sqlalchemy_models import MonthlyPlan +from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.common.logger import get_logger + from .schemas import ScheduleData logger = get_logger("schedule_llm_generator") @@ -37,7 +39,7 @@ class ScheduleLLMGenerator: def __init__(self): self.llm = LLMRequest(model_set=model_config.model_task_config.schedule_generator, request_type="schedule") - async def generate_schedule_with_llm(self, sampled_plans: List[MonthlyPlan]) -> Optional[List[Dict[str, Any]]]: + async def generate_schedule_with_llm(self, sampled_plans: list[MonthlyPlan]) -> list[dict[str, Any]] | None: now = datetime.now() today_str = now.strftime("%Y-%m-%d") weekday = now.strftime("%A") @@ -143,7 +145,7 @@ class MonthlyPlanLLMGenerator: def __init__(self): self.llm = LLMRequest(model_set=model_config.model_task_config.schedule_generator, request_type="monthly_plan") - async def generate_plans_with_llm(self, target_month: str, archived_plans: List[MonthlyPlan]) -> List[str]: + async def generate_plans_with_llm(self, target_month: str, archived_plans: list[MonthlyPlan]) -> list[str]: guidelines = global_config.planning_system.monthly_plan_guidelines or DEFAULT_MONTHLY_PLAN_GUIDELINES personality = global_config.personality.personality_core personality_side = global_config.personality.personality_side @@ -209,7 +211,7 @@ class MonthlyPlanLLMGenerator: return [] @staticmethod - def _parse_plans_response(response: str) -> List[str]: + def _parse_plans_response(response: str) -> list[str]: try: response = response.strip() lines = [line.strip() for line in response.split("\n") if line.strip()] diff --git a/src/schedule/monthly_plan_manager.py b/src/schedule/monthly_plan_manager.py index 7deaaf77d..22e19cd49 100644 --- a/src/schedule/monthly_plan_manager.py +++ b/src/schedule/monthly_plan_manager.py @@ -1,9 +1,9 @@ import asyncio from datetime import datetime, timedelta -from typing import Optional from src.common.logger import get_logger from src.manager.async_task_manager import AsyncTask, async_task_manager + from .plan_manager import PlanManager logger = get_logger("monthly_plan_manager") @@ -31,7 +31,7 @@ class MonthlyPlanManager: else: logger.info(" 每月月度计划生成任务已在运行中。") - async def ensure_and_generate_plans_if_needed(self, target_month: Optional[str] = None) -> bool: + async def ensure_and_generate_plans_if_needed(self, target_month: str | None = None) -> bool: return await self.plan_manager.ensure_and_generate_plans_if_needed(target_month) diff --git a/src/schedule/plan_manager.py b/src/schedule/plan_manager.py index 513a907d5..239bdf3c2 100644 --- a/src/schedule/plan_manager.py +++ b/src/schedule/plan_manager.py @@ -1,18 +1,18 @@ # mmc/src/schedule/plan_manager.py from datetime import datetime -from typing import List, Optional from src.common.logger import get_logger from src.config.config import global_config + from .database import ( add_new_plans, - get_archived_plans_for_month, archive_active_plans_for_month, - has_active_plans, - get_active_plans_for_month, delete_plans_by_ids, + get_active_plans_for_month, + get_archived_plans_for_month, get_smart_plans_for_daily_schedule, + has_active_plans, ) from .llm_generator import MonthlyPlanLLMGenerator @@ -24,7 +24,7 @@ class PlanManager: self.llm_generator = MonthlyPlanLLMGenerator() self.generation_running = False - async def ensure_and_generate_plans_if_needed(self, target_month: Optional[str] = None) -> bool: + async def ensure_and_generate_plans_if_needed(self, target_month: str | None = None) -> bool: if target_month is None: target_month = datetime.now().strftime("%Y-%m") @@ -48,7 +48,7 @@ class PlanManager: logger.info(f"当前月度计划内容:\n{plan_texts}") return True - async def _generate_monthly_plans_logic(self, target_month: Optional[str] = None) -> bool: + async def _generate_monthly_plans_logic(self, target_month: str | None = None) -> bool: if self.generation_running: logger.info("月度计划生成任务已在运行中,跳过重复启动") return False @@ -90,7 +90,7 @@ class PlanManager: except Exception: return "1900-01" - async def archive_current_month_plans(self, target_month: Optional[str] = None): + async def archive_current_month_plans(self, target_month: str | None = None): try: if target_month is None: target_month = datetime.now().strftime("%Y-%m") @@ -100,6 +100,6 @@ class PlanManager: except Exception as e: logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}") - async def get_plans_for_schedule(self, month: str, max_count: int) -> List: + async def get_plans_for_schedule(self, month: str, max_count: int) -> list: avoid_days = global_config.planning_system.avoid_repetition_days return await get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) diff --git a/src/schedule/schedule_manager.py b/src/schedule/schedule_manager.py index 115480381..9f1133df6 100644 --- a/src/schedule/schedule_manager.py +++ b/src/schedule/schedule_manager.py @@ -1,14 +1,15 @@ -import orjson import asyncio from datetime import datetime, time, timedelta -from typing import Optional, List, Dict, Any +from typing import Any +import orjson from sqlalchemy import select from src.common.database.sqlalchemy_models import Schedule, get_db_session -from src.config.config import global_config from src.common.logger import get_logger +from src.config.config import global_config from src.manager.async_task_manager import AsyncTask, async_task_manager + from .database import update_plan_usage from .llm_generator import ScheduleLLMGenerator from .plan_manager import PlanManager @@ -19,7 +20,7 @@ logger = get_logger("schedule_manager") class ScheduleManager: def __init__(self): - self.today_schedule: Optional[List[Dict[str, Any]]] = None + self.today_schedule: list[dict[str, Any]] | None = None self.llm_generator = ScheduleLLMGenerator() self.plan_manager = PlanManager() self.daily_task_started = False @@ -63,7 +64,7 @@ class ScheduleManager: logger.info("尝试生成日程作为备用方案...") await self.generate_and_save_schedule() - async def _load_schedule_from_db(self, date_str: str) -> Optional[List[Dict[str, Any]]]: + async def _load_schedule_from_db(self, date_str: str) -> list[dict[str, Any]] | None: async with get_db_session() as session: result = await session.execute(select(Schedule).filter(Schedule.date == date_str)) schedule_record = result.scalars().first() @@ -118,7 +119,7 @@ class ScheduleManager: logger.info("日程生成任务结束") @staticmethod - async def _save_schedule_to_db(date_str: str, schedule_data: List[Dict[str, Any]]): + async def _save_schedule_to_db(date_str: str, schedule_data: list[dict[str, Any]]): async with get_db_session() as session: schedule_json = orjson.dumps(schedule_data).decode("utf-8") result = await session.execute(select(Schedule).filter(Schedule.date == date_str)) @@ -132,13 +133,13 @@ class ScheduleManager: await session.commit() @staticmethod - def _log_generated_schedule(date_str: str, schedule_data: List[Dict[str, Any]]): + def _log_generated_schedule(date_str: str, schedule_data: list[dict[str, Any]]): schedule_str = f"✅ 成功生成并保存今天的日程 ({date_str}):\n" for item in schedule_data: schedule_str += f" - {item.get('time_range', '未知时间')}: {item.get('activity', '未知活动')}\n" logger.info(schedule_str) - def get_current_activity(self) -> Optional[str]: + def get_current_activity(self) -> str | None: if not global_config.planning_system.schedule_enable or not self.today_schedule: return None now = datetime.now().time() diff --git a/src/schedule/schemas.py b/src/schedule/schemas.py index a733731be..00508e4d8 100644 --- a/src/schedule/schemas.py +++ b/src/schedule/schemas.py @@ -1,7 +1,7 @@ # mmc/src/schedule/schemas.py from datetime import datetime, time -from typing import List + from pydantic import BaseModel, validator @@ -41,7 +41,7 @@ class ScheduleItem(BaseModel): class ScheduleData(BaseModel): """完整日程数据的Pydantic模型""" - schedule: List[ScheduleItem] + schedule: list[ScheduleItem] @validator("schedule") def validate_schedule_completeness(cls, v): @@ -67,7 +67,7 @@ class ScheduleData(BaseModel): return v @staticmethod - def _check_24_hour_coverage(time_ranges: List[tuple]) -> bool: + def _check_24_hour_coverage(time_ranges: list[tuple]) -> bool: """检查时间段是否覆盖24小时""" if not time_ranges: return False diff --git a/src/utils/message_chunker.py b/src/utils/message_chunker.py index 66a2964e1..2e98adcf1 100644 --- a/src/utils/message_chunker.py +++ b/src/utils/message_chunker.py @@ -3,10 +3,12 @@ MaiBot 端的消息切片处理模块 用于接收和重组来自 Napcat-Adapter 的切片消息 """ -import orjson -import time import asyncio -from typing import Dict, Any, Optional +import time +from typing import Any + +import orjson + from src.common.logger import get_logger logger = get_logger("message_chunker") @@ -17,7 +19,7 @@ class MessageReassembler: def __init__(self, timeout: int = 30): self.timeout = timeout - self.chunk_buffers: Dict[str, Dict[str, Any]] = {} + self.chunk_buffers: dict[str, dict[str, Any]] = {} self._cleanup_task = None async def start_cleanup_task(self): @@ -59,7 +61,7 @@ class MessageReassembler: logger.error(f"清理过期切片时出错: {e}") @staticmethod - def is_chunk_message(message: Dict[str, Any]) -> bool: + def is_chunk_message(message: dict[str, Any]) -> bool: """检查是否是来自 Ada 的切片消息""" return ( isinstance(message, dict) @@ -68,7 +70,7 @@ class MessageReassembler: and "__mmc_is_chunked__" in message ) - async def process_chunk(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]: + async def process_chunk(self, message: dict[str, Any]) -> dict[str, Any] | None: """ 处理切片消息,如果切片完整则返回重组后的消息 @@ -144,7 +146,7 @@ class MessageReassembler: logger.error(f"处理切片消息时出错: {e}") return None - def get_pending_chunks_info(self) -> Dict[str, Any]: + def get_pending_chunks_info(self) -> dict[str, Any]: """获取待处理切片信息""" info = {} for chunk_id, buffer in self.chunk_buffers.items(): diff --git a/src/utils/timing_utils.py b/src/utils/timing_utils.py index b4084d6af..36fb1f870 100644 --- a/src/utils/timing_utils.py +++ b/src/utils/timing_utils.py @@ -10,10 +10,10 @@ - 快速筛选:使用NumPy布尔索引进行高效过滤 """ -import numpy as np -from typing import Optional from functools import lru_cache +import numpy as np + @lru_cache(maxsize=128) def _calculate_sigma_bounds(base_interval: int, sigma_percentage: float, use_3sigma_rule: bool) -> tuple: @@ -35,8 +35,8 @@ def _calculate_sigma_bounds(base_interval: int, sigma_percentage: float, use_3si def get_normal_distributed_interval( base_interval: int, sigma_percentage: float = 0.1, - min_interval: Optional[int] = None, - max_interval: Optional[int] = None, + min_interval: int | None = None, + max_interval: int | None = None, use_3sigma_rule: bool = True, ) -> int: """ @@ -120,8 +120,8 @@ def get_normal_distributed_interval( def _generate_pure_random_interval( sigma_percentage: float, - min_interval: Optional[int] = None, - max_interval: Optional[int] = None, + min_interval: int | None = None, + max_interval: int | None = None, use_3sigma_rule: bool = True, ) -> int: """ diff --git a/ui_log_adapter.py b/ui_log_adapter.py index 58ae14f80..3fb474620 100644 --- a/ui_log_adapter.py +++ b/ui_log_adapter.py @@ -3,9 +3,9 @@ Bot服务UI日志适配器 在最小侵入的情况下捕获Bot的日志并发送到UI """ -import sys -import os import logging +import os +import sys import threading import time