From b303a95f61fbf6aaf4110a85881c821bb32f3dc5 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sat, 12 Jul 2025 00:34:49 +0800 Subject: [PATCH] =?UTF-8?q?=E9=83=A8=E5=88=86=E7=B1=BB=E5=9E=8B=E6=B3=A8?= =?UTF-8?q?=E8=A7=A3=E4=BF=AE=E5=A4=8D=EF=BC=8C=E4=BC=98=E5=8C=96import?= =?UTF-8?q?=E9=A1=BA=E5=BA=8F=EF=BC=8C=E5=88=A0=E9=99=A4=E6=97=A0=E7=94=A8?= =?UTF-8?q?API=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/__init__.py | 0 src/api/apiforgui.py | 26 -- src/api/basic_info_api.py | 169 ---------- src/api/config_api.py | 317 ------------------ src/api/maigraphql/__init__.py | 22 -- src/api/maigraphql/schema.py | 1 - src/api/main.py | 112 ------- src/api/reload_config.py | 24 -- src/chat/emoji_system/emoji_manager.py | 45 +-- src/chat/express/expression_selector.py | 17 +- src/chat/express/exprssion_learner.py | 20 +- src/chat/focus_chat/focus_loop_info.py | 4 +- src/chat/focus_chat/heartFC_chat.py | 36 +- src/chat/focus_chat/hfc_performance_logger.py | 1 + src/chat/focus_chat/hfc_utils.py | 9 +- src/chat/heart_flow/heartflow.py | 8 +- .../heart_flow/heartflow_message_processor.py | 30 +- src/chat/heart_flow/sub_heartflow.py | 12 +- src/chat/memory_system/Hippocampus.py | 120 +++---- src/chat/memory_system/memory_activator.py | 13 +- src/chat/memory_system/sample_distribution.py | 42 --- src/chat/message_receive/bot.py | 26 +- src/chat/message_receive/chat_stream.py | 32 +- src/chat/message_receive/message.py | 55 ++- .../message_receive/normal_message_sender.py | 35 +- src/chat/message_receive/storage.py | 31 +- .../message_receive/uni_message_sender.py | 23 +- src/chat/normal_chat/normal_chat.py | 8 +- src/chat/normal_chat/priority_manager.py | 2 +- .../normal_chat/willing/mode_classical.py | 4 +- .../normal_chat/willing/willing_manager.py | 18 +- src/chat/planner_actions/action_manager.py | 68 ++-- src/chat/planner_actions/action_modifier.py | 28 +- src/chat/planner_actions/planner.py | 7 +- src/chat/replyer/default_generator.py | 58 ++-- src/chat/replyer/replyer_manager.py | 12 +- src/chat/utils/chat_message_builder.py | 61 ++-- src/chat/utils/utils_image.py | 8 +- src/person_info/person_info.py | 4 +- src/plugin_system/base/base_action.py | 4 +- src/plugin_system/base/base_command.py | 2 +- src/plugin_system/base/base_plugin.py | 4 +- src/plugin_system/base/component_types.py | 4 +- src/plugin_system/core/component_registry.py | 49 +-- 44 files changed, 405 insertions(+), 1166 deletions(-) delete mode 100644 src/api/__init__.py delete mode 100644 src/api/apiforgui.py delete mode 100644 src/api/basic_info_api.py delete mode 100644 src/api/config_api.py delete mode 100644 src/api/maigraphql/__init__.py delete mode 100644 src/api/maigraphql/schema.py delete mode 100644 src/api/main.py delete mode 100644 src/api/reload_config.py diff --git a/src/api/__init__.py b/src/api/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/api/apiforgui.py b/src/api/apiforgui.py deleted file mode 100644 index 058c6fc96..000000000 --- a/src/api/apiforgui.py +++ /dev/null @@ -1,26 +0,0 @@ -from src.chat.heart_flow.heartflow import heartflow -from src.chat.heart_flow.sub_heartflow import ChatState -from src.common.logger import get_logger - -logger = get_logger("api") - - -async def get_all_subheartflow_ids() -> list: - """获取所有子心流的ID列表""" - all_subheartflows = heartflow.subheartflow_manager.get_all_subheartflows() - return [subheartflow.subheartflow_id for subheartflow in all_subheartflows] - - -async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatState) -> bool: - """强制改变子心流的状态""" - subheartflow = await heartflow.get_or_create_subheartflow(subheartflow_id) - if subheartflow: - return await heartflow.force_change_subheartflow_status(subheartflow_id, status) - return False - - -async def get_all_states(): - """获取所有状态""" - all_states = await heartflow.api_get_all_states() - logger.debug(f"所有状态: {all_states}") - return all_states diff --git a/src/api/basic_info_api.py b/src/api/basic_info_api.py deleted file mode 100644 index 4e5fa4c7d..000000000 --- a/src/api/basic_info_api.py +++ /dev/null @@ -1,169 +0,0 @@ -import platform -import psutil -import sys -import os - - -def get_system_info(): - """获取操作系统信息""" - return { - "system": platform.system(), - "release": platform.release(), - "version": platform.version(), - "machine": platform.machine(), - "processor": platform.processor(), - } - - -def get_python_version(): - """获取 Python 版本信息""" - return sys.version - - -def get_cpu_usage(): - """获取系统总CPU使用率""" - return psutil.cpu_percent(interval=1) - - -def get_process_cpu_usage(): - """获取当前进程CPU使用率""" - process = psutil.Process(os.getpid()) - return process.cpu_percent(interval=1) - - -def get_memory_usage(): - """获取系统内存使用情况 (单位 MB)""" - mem = psutil.virtual_memory() - bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa - return { - "total_mb": bytes_to_mb(mem.total), - "available_mb": bytes_to_mb(mem.available), - "percent": mem.percent, - "used_mb": bytes_to_mb(mem.used), - "free_mb": bytes_to_mb(mem.free), - } - - -def get_process_memory_usage(): - """获取当前进程内存使用情况 (单位 MB)""" - process = psutil.Process(os.getpid()) - mem_info = process.memory_info() - bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa - return { - "rss_mb": bytes_to_mb(mem_info.rss), # Resident Set Size: 实际使用物理内存 - "vms_mb": bytes_to_mb(mem_info.vms), # Virtual Memory Size: 虚拟内存大小 - "percent": process.memory_percent(), # 进程内存使用百分比 - } - - -def get_disk_usage(path="/"): - """获取指定路径磁盘使用情况 (单位 GB)""" - disk = psutil.disk_usage(path) - bytes_to_gb = lambda x: round(x / (1024 * 1024 * 1024), 2) # noqa - return { - "total_gb": bytes_to_gb(disk.total), - "used_gb": bytes_to_gb(disk.used), - "free_gb": bytes_to_gb(disk.free), - "percent": disk.percent, - } - - -def get_all_basic_info(): - """获取所有基本信息并封装返回""" - # 对于进程CPU使用率,需要先初始化 - process = psutil.Process(os.getpid()) - process.cpu_percent(interval=None) # 初始化调用 - process_cpu = process.cpu_percent(interval=0.1) # 短暂间隔获取 - - return { - "system_info": get_system_info(), - "python_version": get_python_version(), - "cpu_usage_percent": get_cpu_usage(), - "process_cpu_usage_percent": process_cpu, - "memory_usage": get_memory_usage(), - "process_memory_usage": get_process_memory_usage(), - "disk_usage_root": get_disk_usage("/"), - } - - -def get_all_basic_info_string() -> str: - """获取所有基本信息并以带解释的字符串形式返回""" - info = get_all_basic_info() - - sys_info = info["system_info"] - mem_usage = info["memory_usage"] - proc_mem_usage = info["process_memory_usage"] - disk_usage = info["disk_usage_root"] - - # 对进程内存使用百分比进行格式化,保留两位小数 - proc_mem_percent = round(proc_mem_usage["percent"], 2) - - output_string = f"""[系统信息] - - 操作系统: {sys_info["system"]} (例如: Windows, Linux) - - 发行版本: {sys_info["release"]} (例如: 11, Ubuntu 20.04) - - 详细版本: {sys_info["version"]} - - 硬件架构: {sys_info["machine"]} (例如: AMD64) - - 处理器信息: {sys_info["processor"]} - -[Python 环境] - - Python 版本: {info["python_version"]} - -[CPU 状态] - - 系统总 CPU 使用率: {info["cpu_usage_percent"]}% - - 当前进程 CPU 使用率: {info["process_cpu_usage_percent"]}% - -[系统内存使用情况] - - 总物理内存: {mem_usage["total_mb"]} MB - - 可用物理内存: {mem_usage["available_mb"]} MB - - 物理内存使用率: {mem_usage["percent"]}% - - 已用物理内存: {mem_usage["used_mb"]} MB - - 空闲物理内存: {mem_usage["free_mb"]} MB - -[当前进程内存使用情况] - - 实际使用物理内存 (RSS): {proc_mem_usage["rss_mb"]} MB - - 占用虚拟内存 (VMS): {proc_mem_usage["vms_mb"]} MB - - 进程内存使用率: {proc_mem_percent}% - -[磁盘使用情况 (根目录)] - - 总空间: {disk_usage["total_gb"]} GB - - 已用空间: {disk_usage["used_gb"]} GB - - 可用空间: {disk_usage["free_gb"]} GB - - 磁盘使用率: {disk_usage["percent"]}% -""" - return output_string - - -if __name__ == "__main__": - print(f"System Info: {get_system_info()}") - print(f"Python Version: {get_python_version()}") - print(f"CPU Usage: {get_cpu_usage()}%") - # 第一次调用 process.cpu_percent() 会返回0.0或一个无意义的值,需要间隔一段时间再调用 - # 或者在初始化Process对象后,先调用一次cpu_percent(interval=None),然后再调用cpu_percent(interval=1) - current_process = psutil.Process(os.getpid()) - current_process.cpu_percent(interval=None) # 初始化 - print(f"Process CPU Usage: {current_process.cpu_percent(interval=1)}%") # 实际获取 - - memory_usage_info = get_memory_usage() - print( - f"Memory Usage: Total={memory_usage_info['total_mb']}MB, Used={memory_usage_info['used_mb']}MB, Percent={memory_usage_info['percent']}%" - ) - - process_memory_info = get_process_memory_usage() - print( - f"Process Memory Usage: RSS={process_memory_info['rss_mb']}MB, VMS={process_memory_info['vms_mb']}MB, Percent={process_memory_info['percent']}%" - ) - - disk_usage_info = get_disk_usage("/") - print( - f"Disk Usage (Root): Total={disk_usage_info['total_gb']}GB, Used={disk_usage_info['used_gb']}GB, Percent={disk_usage_info['percent']}%" - ) - - print("\n--- All Basic Info (JSON) ---") - all_info = get_all_basic_info() - import json - - print(json.dumps(all_info, indent=4, ensure_ascii=False)) - - print("\n--- All Basic Info (String with Explanations) ---") - info_string = get_all_basic_info_string() - print(info_string) diff --git a/src/api/config_api.py b/src/api/config_api.py deleted file mode 100644 index 07f36a9d8..000000000 --- a/src/api/config_api.py +++ /dev/null @@ -1,317 +0,0 @@ -from typing import List, Optional, Dict, Any -import strawberry - -# from packaging.version import Version -import os - -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) - - -@strawberry.type -class APIBotConfig: - """机器人配置类""" - - INNER_VERSION: str # 配置文件内部版本号(toml为字符串) - MAI_VERSION: str # 硬编码的版本信息 - - # bot - BOT_QQ: Optional[int] # 机器人QQ号 - BOT_NICKNAME: Optional[str] # 机器人昵称 - BOT_ALIAS_NAMES: List[str] # 机器人别名列表 - - # group - talk_allowed_groups: List[int] # 允许回复消息的群号列表 - talk_frequency_down_groups: List[int] # 降低回复频率的群号列表 - ban_user_id: List[int] # 禁止回复和读取消息的QQ号列表 - - # personality - personality_core: str # 人格核心特点描述 - personality_sides: List[str] # 人格细节描述列表 - - # identity - identity_detail: List[str] # 身份特点列表 - age: int # 年龄(岁) - gender: str # 性别 - appearance: str # 外貌特征描述 - - # platforms - platforms: Dict[str, str] # 平台信息 - - # chat - allow_focus_mode: bool # 是否允许专注聊天状态 - base_normal_chat_num: int # 最多允许多少个群进行普通聊天 - base_focused_chat_num: int # 最多允许多少个群进行专注聊天 - observation_context_size: int # 观察到的最长上下文大小 - message_buffer: bool # 是否启用消息缓冲 - ban_words: List[str] # 禁止词列表 - ban_msgs_regex: List[str] # 禁止消息的正则表达式列表 - - # normal_chat - model_reasoning_probability: float # 推理模型概率 - model_normal_probability: float # 普通模型概率 - emoji_chance: float # 表情符号出现概率 - thinking_timeout: int # 思考超时时间 - willing_mode: str # 意愿模式 - response_interested_rate_amplifier: float # 回复兴趣率放大器 - emoji_response_penalty: float # 表情回复惩罚 - mentioned_bot_inevitable_reply: bool # 提及 bot 必然回复 - at_bot_inevitable_reply: bool # @bot 必然回复 - - # focus_chat - reply_trigger_threshold: float # 回复触发阈值 - default_decay_rate_per_second: float # 默认每秒衰减率 - - # compressed - compressed_length: int # 压缩长度 - compress_length_limit: int # 压缩长度限制 - - # emoji - max_emoji_num: int # 最大表情符号数量 - max_reach_deletion: bool # 达到最大数量时是否删除 - check_interval: int # 检查表情包的时间间隔(分钟) - save_emoji: bool # 是否保存表情包 - steal_emoji: bool # 是否偷取表情包 - enable_check: bool # 是否启用表情包过滤 - check_prompt: str # 表情包过滤要求 - - # memory - build_memory_interval: int # 记忆构建间隔 - build_memory_distribution: List[float] # 记忆构建分布 - build_memory_sample_num: int # 采样数量 - build_memory_sample_length: int # 采样长度 - memory_compress_rate: float # 记忆压缩率 - forget_memory_interval: int # 记忆遗忘间隔 - memory_forget_time: int # 记忆遗忘时间(小时) - memory_forget_percentage: float # 记忆遗忘比例 - consolidate_memory_interval: int # 记忆整合间隔 - consolidation_similarity_threshold: float # 相似度阈值 - consolidation_check_percentage: float # 检查节点比例 - memory_ban_words: List[str] # 记忆禁止词列表 - - # mood - mood_update_interval: float # 情绪更新间隔 - mood_decay_rate: float # 情绪衰减率 - mood_intensity_factor: float # 情绪强度因子 - - # keywords_reaction - keywords_reaction_enable: bool # 是否启用关键词反应 - keywords_reaction_rules: List[Dict[str, Any]] # 关键词反应规则 - - # chinese_typo - chinese_typo_enable: bool # 是否启用中文错别字 - chinese_typo_error_rate: float # 中文错别字错误率 - chinese_typo_min_freq: int # 中文错别字最小频率 - chinese_typo_tone_error_rate: float # 中文错别字声调错误率 - chinese_typo_word_replace_rate: float # 中文错别字单词替换率 - - # response_splitter - enable_response_splitter: bool # 是否启用回复分割器 - response_max_length: int # 回复最大长度 - response_max_sentence_num: int # 回复最大句子数 - enable_kaomoji_protection: bool # 是否启用颜文字保护 - - model_max_output_length: int # 模型最大输出长度 - - # remote - remote_enable: bool # 是否启用远程功能 - - # experimental - enable_friend_chat: bool # 是否启用好友聊天 - talk_allowed_private: List[int] # 允许私聊的QQ号列表 - pfc_chatting: bool # 是否启用PFC聊天 - - # 模型配置 - llm_reasoning: Dict[str, Any] # 推理模型配置 - llm_normal: Dict[str, Any] # 普通模型配置 - llm_topic_judge: Dict[str, Any] # 主题判断模型配置 - summary: Dict[str, Any] # 总结模型配置 - vlm: Dict[str, Any] # VLM模型配置 - llm_heartflow: Dict[str, Any] # 心流模型配置 - llm_observation: Dict[str, Any] # 观察模型配置 - llm_sub_heartflow: Dict[str, Any] # 子心流模型配置 - llm_plan: Optional[Dict[str, Any]] # 计划模型配置 - embedding: Dict[str, Any] # 嵌入模型配置 - llm_PFC_action_planner: Optional[Dict[str, Any]] # PFC行动计划模型配置 - llm_PFC_chat: Optional[Dict[str, Any]] # PFC聊天模型配置 - llm_PFC_reply_checker: Optional[Dict[str, Any]] # PFC回复检查模型配置 - llm_tool_use: Optional[Dict[str, Any]] # 工具使用模型配置 - - api_urls: Optional[Dict[str, str]] # API地址配置 - - @staticmethod - def validate_config(config: dict): - """ - 校验传入的 toml 配置字典是否合法。 - :param config: toml库load后的配置字典 - :raises: ValueError, KeyError, TypeError - """ - # 检查主层级 - required_sections = [ - "inner", - "bot", - "groups", - "personality", - "identity", - "platforms", - "chat", - "normal_chat", - "focus_chat", - "emoji", - "memory", - "mood", - "keywords_reaction", - "chinese_typo", - "response_splitter", - "remote", - "experimental", - "model", - ] - for section in required_sections: - if section not in config: - raise KeyError(f"缺少配置段: [{section}]") - - # 检查部分关键字段 - if "version" not in config["inner"]: - raise KeyError("缺少 inner.version 字段") - if not isinstance(config["inner"]["version"], str): - raise TypeError("inner.version 必须为字符串") - - if "qq" not in config["bot"]: - raise KeyError("缺少 bot.qq 字段") - if not isinstance(config["bot"]["qq"], int): - raise TypeError("bot.qq 必须为整数") - - if "personality_core" not in config["personality"]: - raise KeyError("缺少 personality.personality_core 字段") - if not isinstance(config["personality"]["personality_core"], str): - raise TypeError("personality.personality_core 必须为字符串") - - if "identity_detail" not in config["identity"]: - raise KeyError("缺少 identity.identity_detail 字段") - if not isinstance(config["identity"]["identity_detail"], list): - raise TypeError("identity.identity_detail 必须为列表") - - # 可继续添加更多字段的类型和值检查 - # ... - - # 检查模型配置 - model_keys = [ - "llm_reasoning", - "llm_normal", - "llm_topic_judge", - "summary", - "vlm", - "llm_heartflow", - "llm_observation", - "llm_sub_heartflow", - "embedding", - ] - if "model" not in config: - raise KeyError("缺少 [model] 配置段") - for key in model_keys: - if key not in config["model"]: - raise KeyError(f"缺少 model.{key} 配置") - - # 检查通过 - return True - - -@strawberry.type -class APIEnvConfig: - """环境变量配置""" - - HOST: str # 服务主机地址 - PORT: int # 服务端口 - - PLUGINS: List[str] # 插件列表 - - MONGODB_HOST: str # MongoDB 主机地址 - MONGODB_PORT: int # MongoDB 端口 - DATABASE_NAME: str # 数据库名称 - - CHAT_ANY_WHERE_BASE_URL: str # ChatAnywhere 基础URL - SILICONFLOW_BASE_URL: str # SiliconFlow 基础URL - DEEP_SEEK_BASE_URL: str # DeepSeek 基础URL - - DEEP_SEEK_KEY: Optional[str] # DeepSeek API Key - CHAT_ANY_WHERE_KEY: Optional[str] # ChatAnywhere API Key - SILICONFLOW_KEY: Optional[str] # SiliconFlow API Key - - SIMPLE_OUTPUT: Optional[bool] # 是否简化输出 - CONSOLE_LOG_LEVEL: Optional[str] # 控制台日志等级 - FILE_LOG_LEVEL: Optional[str] # 文件日志等级 - DEFAULT_CONSOLE_LOG_LEVEL: Optional[str] # 默认控制台日志等级 - DEFAULT_FILE_LOG_LEVEL: Optional[str] # 默认文件日志等级 - - @strawberry.field - def get_env(self) -> str: - return "env" - - @staticmethod - def validate_config(config: dict): - """ - 校验环境变量配置字典是否合法。 - :param config: 环境变量配置字典 - :raises: KeyError, TypeError - """ - required_fields = [ - "HOST", - "PORT", - "PLUGINS", - "MONGODB_HOST", - "MONGODB_PORT", - "DATABASE_NAME", - "CHAT_ANY_WHERE_BASE_URL", - "SILICONFLOW_BASE_URL", - "DEEP_SEEK_BASE_URL", - ] - for field in required_fields: - if field not in config: - raise KeyError(f"缺少环境变量配置字段: {field}") - - if not isinstance(config["HOST"], str): - raise TypeError("HOST 必须为字符串") - if not isinstance(config["PORT"], int): - raise TypeError("PORT 必须为整数") - if not isinstance(config["PLUGINS"], list): - raise TypeError("PLUGINS 必须为列表") - if not isinstance(config["MONGODB_HOST"], str): - raise TypeError("MONGODB_HOST 必须为字符串") - if not isinstance(config["MONGODB_PORT"], int): - raise TypeError("MONGODB_PORT 必须为整数") - if not isinstance(config["DATABASE_NAME"], str): - raise TypeError("DATABASE_NAME 必须为字符串") - if not isinstance(config["CHAT_ANY_WHERE_BASE_URL"], str): - raise TypeError("CHAT_ANY_WHERE_BASE_URL 必须为字符串") - if not isinstance(config["SILICONFLOW_BASE_URL"], str): - raise TypeError("SILICONFLOW_BASE_URL 必须为字符串") - if not isinstance(config["DEEP_SEEK_BASE_URL"], str): - raise TypeError("DEEP_SEEK_BASE_URL 必须为字符串") - - # 可选字段类型检查 - optional_str_fields = [ - "DEEP_SEEK_KEY", - "CHAT_ANY_WHERE_KEY", - "SILICONFLOW_KEY", - "CONSOLE_LOG_LEVEL", - "FILE_LOG_LEVEL", - "DEFAULT_CONSOLE_LOG_LEVEL", - "DEFAULT_FILE_LOG_LEVEL", - ] - for field in optional_str_fields: - if field in config and config[field] is not None and not isinstance(config[field], str): - raise TypeError(f"{field} 必须为字符串或None") - - if ( - "SIMPLE_OUTPUT" in config - and config["SIMPLE_OUTPUT"] is not None - and not isinstance(config["SIMPLE_OUTPUT"], bool) - ): - raise TypeError("SIMPLE_OUTPUT 必须为布尔值或None") - - # 检查通过 - return True - - -print("当前路径:") -print(ROOT_PATH) diff --git a/src/api/maigraphql/__init__.py b/src/api/maigraphql/__init__.py deleted file mode 100644 index c414911de..000000000 --- a/src/api/maigraphql/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -import strawberry - -from fastapi import FastAPI -from strawberry.fastapi import GraphQLRouter - -from src.common.server import get_global_server - - -@strawberry.type -class Query: - @strawberry.field - def hello(self) -> str: - return "Hello World" - - -schema = strawberry.Schema(Query) - -graphql_app = GraphQLRouter(schema) - -fast_api_app: FastAPI = get_global_server().get_app() - -fast_api_app.include_router(graphql_app, prefix="/graphql") diff --git a/src/api/maigraphql/schema.py b/src/api/maigraphql/schema.py deleted file mode 100644 index 2ae28399f..000000000 --- a/src/api/maigraphql/schema.py +++ /dev/null @@ -1 +0,0 @@ -pass diff --git a/src/api/main.py b/src/api/main.py deleted file mode 100644 index 598b8aec5..000000000 --- a/src/api/main.py +++ /dev/null @@ -1,112 +0,0 @@ -from fastapi import APIRouter -from strawberry.fastapi import GraphQLRouter -import os -import sys - -# from src.chat.heart_flow.heartflow import heartflow -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) -# from src.config.config import BotConfig -from src.common.logger import get_logger -from src.api.reload_config import reload_config as reload_config_func -from src.common.server import get_global_server -from src.api.apiforgui import ( - get_all_subheartflow_ids, - forced_change_subheartflow_status, - get_subheartflow_cycle_info, - get_all_states, -) -from src.chat.heart_flow.sub_heartflow import ChatState -from src.api.basic_info_api import get_all_basic_info # 新增导入 - - -router = APIRouter() - - -logger = get_logger("api") - -logger.info("麦麦API服务器已启动") -graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema - -router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"]) - - -@router.post("/config/reload") -async def reload_config(): - return await reload_config_func() - - -@router.get("/gui/subheartflow/get/all") -async def get_subheartflow_ids(): - """获取所有子心流的ID列表""" - return await get_all_subheartflow_ids() - - -@router.post("/gui/subheartflow/forced_change_status") -async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): # noqa - """强制改变子心流的状态""" - # 参数检查 - if not isinstance(status, ChatState): - logger.warning(f"无效的状态参数: {status}") - return {"status": "failed", "reason": "invalid status"} - logger.info(f"尝试将子心流 {subheartflow_id} 状态更改为 {status.value}") - success = await forced_change_subheartflow_status(subheartflow_id, status) - if success: - logger.info(f"子心流 {subheartflow_id} 状态更改为 {status.value} 成功") - return {"status": "success"} - else: - logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败") - return {"status": "failed"} - - -@router.get("/stop") -async def force_stop_maibot(): - """强制停止MAI Bot""" - from bot import request_shutdown - - success = await request_shutdown() - if success: - logger.info("MAI Bot已强制停止") - return {"status": "success"} - else: - logger.error("MAI Bot强制停止失败") - return {"status": "failed"} - - -@router.get("/gui/subheartflow/cycleinfo") -async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int): - """获取子心流的循环信息""" - cycle_info = await get_subheartflow_cycle_info(subheartflow_id, history_len) - if cycle_info: - return {"status": "success", "data": cycle_info} - else: - logger.warning(f"子心流 {subheartflow_id} 循环信息未找到") - return {"status": "failed", "reason": "subheartflow not found"} - - -@router.get("/gui/get_all_states") -async def get_all_states_api(): - """获取所有状态""" - all_states = await get_all_states() - if all_states: - return {"status": "success", "data": all_states} - else: - logger.warning("获取所有状态失败") - return {"status": "failed", "reason": "failed to get all states"} - - -@router.get("/info") -async def get_system_basic_info(): - """获取系统基本信息""" - logger.info("请求系统基本信息") - try: - info = get_all_basic_info() - return {"status": "success", "data": info} - except Exception as e: - logger.error(f"获取系统基本信息失败: {e}") - return {"status": "failed", "reason": str(e)} - - -def start_api_server(): - """启动API服务器""" - get_global_server().register_router(router, prefix="/api/v1") - # pass diff --git a/src/api/reload_config.py b/src/api/reload_config.py deleted file mode 100644 index 087c47e4f..000000000 --- a/src/api/reload_config.py +++ /dev/null @@ -1,24 +0,0 @@ -from fastapi import HTTPException -from rich.traceback import install -from src.config.config import get_config_dir, load_config -from src.common.logger import get_logger -import os - -install(extra_lines=3) - -logger = get_logger("api") - - -async def reload_config(): - try: - from src.config import config as config_module - - logger.debug("正在重载配置文件...") - bot_config_path = os.path.join(get_config_dir(), "bot_config.toml") - config_module.global_config = load_config(config_path=bot_config_path) - logger.debug("配置文件重载成功") - return {"status": "reloaded"} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e - except Exception as e: - raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 3511d938b..11fb0f62d 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -5,20 +5,19 @@ import os import random import time import traceback -from typing import Optional, Tuple, List, Any -from PIL import Image import io import re - -# from gradio_client import file +import binascii +from typing import Optional, Tuple, List, Any +from PIL import Image +from rich.traceback import install from src.common.database.database_model import Emoji from src.common.database.database import db as peewee_db +from src.common.logger import get_logger from src.config.config import global_config from src.chat.utils.utils_image import image_path_to_base64, get_image_manager from src.llm_models.utils_model import LLMRequest -from src.common.logger import get_logger -from rich.traceback import install install(extra_lines=3) @@ -26,7 +25,7 @@ logger = get_logger("emoji") BASE_DIR = os.path.join("data") EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录 -EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录 +EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录 MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中 """ @@ -85,7 +84,7 @@ class MaiEmoji: logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}") try: with Image.open(io.BytesIO(image_bytes)) as img: - self.format = img.format.lower() + self.format = img.format.lower() # type: ignore logger.debug(f"[初始化] 格式获取成功: {self.format}") except Exception as pil_error: logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}") @@ -100,7 +99,7 @@ class MaiEmoji: logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}") self.is_deleted = True return None - except base64.binascii.Error as b64_error: + except (binascii.Error, ValueError) as b64_error: logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}") self.is_deleted = True return None @@ -113,7 +112,7 @@ class MaiEmoji: async def register_to_db(self) -> bool: """ 注册表情包 - 将表情包对应的文件,从当前路径移动到EMOJI_REGISTED_DIR目录下 + 将表情包对应的文件,从当前路径移动到EMOJI_REGISTERED_DIR目录下 并修改对应的实例属性,然后将表情包信息保存到数据库中 """ try: @@ -122,7 +121,7 @@ class MaiEmoji: # 源路径是当前实例的完整路径 self.full_path source_full_path = self.full_path # 目标完整路径 - destination_full_path = os.path.join(EMOJI_REGISTED_DIR, self.filename) + destination_full_path = os.path.join(EMOJI_REGISTERED_DIR, self.filename) # 检查源文件是否存在 if not os.path.exists(source_full_path): @@ -139,7 +138,7 @@ class MaiEmoji: logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}") # 更新实例的路径属性为新路径 self.full_path = destination_full_path - self.path = EMOJI_REGISTED_DIR + self.path = EMOJI_REGISTERED_DIR # self.filename 保持不变 except Exception as move_error: logger.error(f"[错误] 移动文件失败: {str(move_error)}") @@ -202,7 +201,7 @@ class MaiEmoji: try: will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash) result = will_delete_emoji.delete_instance() # Returns the number of rows deleted. - except Emoji.DoesNotExist: + except Emoji.DoesNotExist: # type: ignore logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") result = 0 # Indicate no DB record was deleted @@ -298,7 +297,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]: def _ensure_emoji_dir() -> None: """确保表情存储目录存在""" os.makedirs(EMOJI_DIR, exist_ok=True) - os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True) + os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True) async def clear_temp_emoji() -> None: @@ -331,10 +330,10 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") return removed_count + cleaned_count = 0 try: # 获取内存中所有有效表情包的完整路径集合 tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted} - cleaned_count = 0 # 遍历指定目录中的所有文件 for file_name in os.listdir(emoji_dir): @@ -358,11 +357,11 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r else: logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。") - return removed_count + cleaned_count - except Exception as e: logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}") + return removed_count + cleaned_count + class EmojiManager: _instance = None @@ -414,7 +413,7 @@ class EmojiManager: emoji_update.usage_count += 1 emoji_update.last_used_time = time.time() # Update last used time emoji_update.save() # Persist changes to DB - except Emoji.DoesNotExist: + except Emoji.DoesNotExist: # type: ignore logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") @@ -570,8 +569,8 @@ class EmojiManager: if objects_to_remove: self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove] - # 清理 EMOJI_REGISTED_DIR 目录中未被追踪的文件 - removed_count = await clean_unused_emojis(EMOJI_REGISTED_DIR, self.emoji_objects, removed_count) + # 清理 EMOJI_REGISTERED_DIR 目录中未被追踪的文件 + removed_count = await clean_unused_emojis(EMOJI_REGISTERED_DIR, self.emoji_objects, removed_count) # 输出清理结果 if removed_count > 0: @@ -850,11 +849,13 @@ class EmojiManager: if isinstance(image_base64, str): image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_bytes = base64.b64decode(image_base64) - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore # 调用AI获取描述 if image_format == "gif" or image_format == "GIF": - image_base64 = get_image_manager().transform_gif(image_base64) + image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore + if not image_base64: + raise RuntimeError("GIF表情包转换失败") prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg") else: diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index b85f53b79..0b1eaef7a 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -1,14 +1,16 @@ -from .exprssion_learner import get_expression_learner -import random -from typing import List, Dict, Tuple -from json_repair import repair_json import json import os import time +import random + +from typing import List, Dict, Tuple, Optional +from json_repair import repair_json + from src.llm_models.utils_model import LLMRequest from src.config.config import global_config from src.common.logger import get_logger from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from .exprssion_learner import get_expression_learner logger = get_logger("expression_selector") @@ -165,7 +167,12 @@ class ExpressionSelector: logger.error(f"批量更新表达方式count失败 for {file_path}: {e}") async def select_suitable_expressions_llm( - self, chat_id: str, chat_info: str, max_num: int = 10, min_num: int = 5, target_message: str = None + self, + chat_id: str, + chat_info: str, + max_num: int = 10, + min_num: int = 5, + target_message: Optional[str] = None, ) -> List[Dict[str, str]]: """使用LLM选择适合的表达方式""" diff --git a/src/chat/express/exprssion_learner.py b/src/chat/express/exprssion_learner.py index 9b170d9a3..738a88b95 100644 --- a/src/chat/express/exprssion_learner.py +++ b/src/chat/express/exprssion_learner.py @@ -1,14 +1,16 @@ import time import random +import json +import os + from typing import List, Dict, Optional, Any, Tuple + from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest from src.config.config import global_config from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -import os from src.chat.message_receive.chat_stream import get_chat_manager -import json MAX_EXPRESSION_COUNT = 300 @@ -74,7 +76,8 @@ class ExpressionLearner: ) self.llm_model = None - def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: + def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]: + # sourcery skip: extract-duplicate-method, remove-unnecessary-cast """ 获取指定chat_id的style和grammar表达方式 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作 @@ -119,10 +122,10 @@ class ExpressionLearner: min_len = min(len(s1), len(s2)) if min_len < 5: return False - same = sum(1 for a, b in zip(s1, s2) if a == b) + same = sum(a == b for a, b in zip(s1, s2)) return same / min_len > 0.8 - async def learn_and_store_expression(self) -> List[Tuple[str, str, str]]: + async def learn_and_store_expression(self) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]: """ 学习并存储表达方式,分别学习语言风格和句法特点 同时对所有已存储的表达方式进行全局衰减 @@ -158,12 +161,12 @@ class ExpressionLearner: for _ in range(3): learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25) if not learnt_style: - return [] + return [], [] for _ in range(1): learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10) if not learnt_grammar: - return [] + return [], [] return learnt_style, learnt_grammar @@ -214,6 +217,7 @@ class ExpressionLearner: return result async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]: + # sourcery skip: use-join """ 选择从当前到最近1小时内的随机num条消息,然后学习这些消息的表达方式 type: "style" or "grammar" @@ -249,7 +253,7 @@ class ExpressionLearner: return [] # 按chat_id分组 - chat_dict: Dict[str, List[Dict[str, str]]] = {} + chat_dict: Dict[str, List[Dict[str, Any]]] = {} for chat_id, situation, style in learnt_expressions: if chat_id not in chat_dict: chat_dict[chat_id] = [] diff --git a/src/chat/focus_chat/focus_loop_info.py b/src/chat/focus_chat/focus_loop_info.py index 342368df7..827c544a2 100644 --- a/src/chat/focus_chat/focus_loop_info.py +++ b/src/chat/focus_chat/focus_loop_info.py @@ -1,10 +1,10 @@ # 定义了来自外部世界的信息 # 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体 from datetime import datetime +from typing import List + from src.common.logger import get_logger from src.chat.focus_chat.hfc_utils import CycleDetail -from typing import List -# Import the new utility function logger = get_logger("loop_info") diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index 70cda57c6..05600c256 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -8,7 +8,7 @@ from rich.traceback import install from src.config.config import global_config from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.timer_calculator import Timer from src.chat.planner_actions.planner import ActionPlanner @@ -49,7 +49,9 @@ class HeartFChatting: """ # 基础属性 self.stream_id: str = chat_id # 聊天流ID - self.chat_stream = get_chat_manager().get_stream(self.stream_id) + self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore + if not self.chat_stream: + raise ValueError(f"无法找到聊天流: {self.stream_id}") self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]" self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id) @@ -171,7 +173,7 @@ class HeartFChatting: # 执行规划和处理阶段 try: async with self._get_cycle_context(): - thinking_id = "tid" + str(round(time.time(), 2)) + thinking_id = f"tid{str(round(time.time(), 2))}" self._current_cycle_detail.set_thinking_id(thinking_id) # 使用异步上下文管理器处理消息 @@ -245,7 +247,7 @@ class HeartFChatting: logger.info( f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考," - f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " + f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "") ) @@ -256,7 +258,7 @@ class HeartFChatting: cycle_performance_data = { "cycle_id": self._current_cycle_detail.cycle_id, "action_type": action_result.get("action_type", "unknown"), - "total_time": self._current_cycle_detail.end_time - self._current_cycle_detail.start_time, + "total_time": self._current_cycle_detail.end_time - self._current_cycle_detail.start_time, # type: ignore "step_times": cycle_timers.copy(), "reasoning": action_result.get("reasoning", ""), "success": self._current_cycle_detail.loop_action_info.get("action_taken", False), @@ -447,11 +449,8 @@ class HeartFChatting: # 处理动作并获取结果 result = await action_handler.handle_action() - if len(result) == 3: - success, reply_text, command = result - else: - success, reply_text = result - command = "" + success, reply_text = result + command = "" # 检查action_data中是否有系统命令,优先使用系统命令 if "_system_command" in action_data: @@ -478,15 +477,14 @@ class HeartFChatting: ) # 设置系统命令,在下次循环检查时触发退出 command = "stop_focus_chat" - else: - if reply_text == "timeout": - self.reply_timeout_count += 1 - if self.reply_timeout_count > 5: - logger.warning( - f"[{self.log_prefix} ] 连续回复超时次数过多,{global_config.chat.thinking_timeout}秒 内大模型没有返回有效内容,请检查你的api是否速度过慢或配置错误。建议不要使用推理模型,推理模型生成速度过慢。或者尝试拉高thinking_timeout参数,这可能导致回复时间过长。" - ) - logger.warning(f"{self.log_prefix} 回复生成超时{global_config.chat.thinking_timeout}s,已跳过") - return False, "", "" + elif reply_text == "timeout": + self.reply_timeout_count += 1 + if self.reply_timeout_count > 5: + logger.warning( + f"[{self.log_prefix} ] 连续回复超时次数过多,{global_config.chat.thinking_timeout}秒 内大模型没有返回有效内容,请检查你的api是否速度过慢或配置错误。建议不要使用推理模型,推理模型生成速度过慢。或者尝试拉高thinking_timeout参数,这可能导致回复时间过长。" + ) + logger.warning(f"{self.log_prefix} 回复生成超时{global_config.chat.thinking_timeout}s,已跳过") + return False, "", "" return success, reply_text, command diff --git a/src/chat/focus_chat/hfc_performance_logger.py b/src/chat/focus_chat/hfc_performance_logger.py index 64e65ff85..702a8445f 100644 --- a/src/chat/focus_chat/hfc_performance_logger.py +++ b/src/chat/focus_chat/hfc_performance_logger.py @@ -2,6 +2,7 @@ import json from datetime import datetime from typing import Dict, Any from pathlib import Path + from src.common.logger import get_logger logger = get_logger("hfc_performance") diff --git a/src/chat/focus_chat/hfc_utils.py b/src/chat/focus_chat/hfc_utils.py index 11b04c801..0393c2175 100644 --- a/src/chat/focus_chat/hfc_utils.py +++ b/src/chat/focus_chat/hfc_utils.py @@ -1,11 +1,12 @@ import time -from typing import Optional +import json + +from typing import Optional, Dict, Any + from src.chat.message_receive.message import MessageRecv, BaseMessageInfo from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.message import UserInfo from src.common.logger import get_logger -import json -from typing import Dict, Any logger = get_logger(__name__) @@ -117,7 +118,7 @@ async def create_empty_anchor_message( placeholder_msg_info = BaseMessageInfo( message_id=placeholder_id, platform=platform, - group_info=group_info, + group_info=group_info, # type: ignore user_info=placeholder_user, time=time.time(), ) diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py index fdcfba6a3..4c5285259 100644 --- a/src/chat/heart_flow/heartflow.py +++ b/src/chat/heart_flow/heartflow.py @@ -1,7 +1,7 @@ -from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState +from typing import Any, Optional, Dict + from src.common.logger import get_logger -from typing import Any, Optional -from typing import Dict +from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState from src.chat.message_receive.chat_stream import get_chat_manager logger = get_logger("heartflow") @@ -34,7 +34,7 @@ class Heartflow: logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True) return None - async def force_change_subheartflow_status(self, subheartflow_id: str, status: ChatState) -> None: + async def force_change_subheartflow_status(self, subheartflow_id: str, status: ChatState) -> bool: """强制改变子心流的状态""" # 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据 return await self.force_change_state(subheartflow_id, status) diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index d01775168..aa8bfdbf0 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -1,21 +1,21 @@ -from src.chat.memory_system.Hippocampus import hippocampus_manager -from src.config.config import global_config import asyncio +import re +import math +import traceback + +from typing import Tuple + +from src.config.config import global_config +from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.storage import MessageStorage from src.chat.heart_flow.heartflow import heartflow from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.timer_calculator import Timer from src.common.logger import get_logger -import re -import math -import traceback -from typing import Tuple - from src.person_info.relationship_manager import get_relationship_manager from src.mood.mood_manager import mood_manager - logger = get_logger("chat") @@ -26,16 +26,16 @@ async def _process_relationship(message: MessageRecv) -> None: message: 消息对象,包含用户信息 """ platform = message.message_info.platform - user_id = message.message_info.user_info.user_id - nickname = message.message_info.user_info.user_nickname - cardname = message.message_info.user_info.user_cardname or nickname + user_id = message.message_info.user_info.user_id # type: ignore + nickname = message.message_info.user_info.user_nickname # type: ignore + cardname = message.message_info.user_info.user_cardname or nickname # type: ignore relationship_manager = get_relationship_manager() is_known = await relationship_manager.is_known_some_one(platform, user_id) if not is_known: logger.info(f"首次认识用户: {nickname}") - await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) + await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: @@ -105,9 +105,9 @@ class HeartFCMessageReceiver: # 2. 兴趣度计算与更新 interested_rate, is_mentioned = await _calculate_interest(message) - subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) + subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) # type: ignore - chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) + chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) # type: ignore asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate)) # 3. 日志记录 @@ -119,7 +119,7 @@ class HeartFCMessageReceiver: picid_pattern = r"\[picid:([^\]]+)\]" processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text) - logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") + logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]") diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py index 9f6a49895..fc230e255 100644 --- a/src/chat/heart_flow/sub_heartflow.py +++ b/src/chat/heart_flow/sub_heartflow.py @@ -1,16 +1,18 @@ import asyncio import time -from typing import Optional, List, Dict, Tuple import traceback + +from typing import Optional, List, Dict, Tuple +from rich.traceback import install + from src.common.logger import get_logger +from src.config.config import global_config from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.focus_chat.heartFC_chat import HeartFChatting from src.chat.normal_chat.normal_chat import NormalChat from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo from src.chat.utils.utils import get_chat_type_and_target_info -from src.config.config import global_config -from rich.traceback import install logger = get_logger("sub_heartflow") @@ -40,7 +42,7 @@ class SubHeartflow: self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id) self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id # 兴趣消息集合 - self.interest_dict: Dict[str, tuple[MessageRecv, float, bool]] = {} + self.interest_dict: Dict[str, Tuple[MessageRecv, float, bool]] = {} # focus模式退出冷却时间管理 self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间 @@ -297,7 +299,7 @@ class SubHeartflow: ) def add_message_to_normal_chat_cache(self, message: MessageRecv, interest_value: float, is_mentioned: bool): - self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned) + self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned) # type: ignore # 如果字典长度超过10,删除最旧的消息 if len(self.interest_dict) > 30: oldest_key = next(iter(self.interest_dict)) diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 29a26f64c..a3ee46a7a 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -42,7 +42,7 @@ def calculate_information_content(text): return entropy -def cosine_similarity(v1, v2): +def cosine_similarity(v1, v2): # sourcery skip: assign-if-exp, reintroduce-else """计算余弦相似度""" dot_product = np.dot(v1, v2) norm1 = np.linalg.norm(v1) @@ -89,14 +89,13 @@ class MemoryGraph: if not isinstance(self.G.nodes[concept]["memory_items"], list): self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]] self.G.nodes[concept]["memory_items"].append(memory) - # 更新最后修改时间 - self.G.nodes[concept]["last_modified"] = current_time else: self.G.nodes[concept]["memory_items"] = [memory] # 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time if "created_time" not in self.G.nodes[concept]: self.G.nodes[concept]["created_time"] = current_time - self.G.nodes[concept]["last_modified"] = current_time + # 更新最后修改时间 + self.G.nodes[concept]["last_modified"] = current_time else: # 如果是新节点,创建新的记忆列表 self.G.add_node( @@ -108,11 +107,7 @@ class MemoryGraph: def get_dot(self, concept): # 检查节点是否存在于图中 - if concept in self.G: - # 从图中获取节点数据 - node_data = self.G.nodes[concept] - return concept, node_data - return None + return (concept, self.G.nodes[concept]) if concept in self.G else None def get_related_item(self, topic, depth=1): if topic not in self.G: @@ -139,8 +134,7 @@ class MemoryGraph: if depth >= 2: # 获取相邻节点的记忆项 for neighbor in neighbors: - node_data = self.get_dot(neighbor) - if node_data: + if node_data := self.get_dot(neighbor): concept, data = node_data if "memory_items" in data: memory_items = data["memory_items"] @@ -194,9 +188,9 @@ class MemoryGraph: class Hippocampus: def __init__(self): self.memory_graph = MemoryGraph() - self.model_summary = None - self.entorhinal_cortex = None - self.parahippocampal_gyrus = None + self.model_summary: LLMRequest = None # type: ignore + self.entorhinal_cortex: EntorhinalCortex = None # type: ignore + self.parahippocampal_gyrus: ParahippocampalGyrus = None # type: ignore def initialize(self): # 初始化子组件 @@ -218,7 +212,7 @@ class Hippocampus: memory_items = [memory_items] if memory_items else [] # 使用集合来去重,避免排序 - unique_items = set(str(item) for item in memory_items) + unique_items = {str(item) for item in memory_items} # 使用frozenset来保证顺序一致性 content = f"{concept}:{frozenset(unique_items)}" return hash(content) @@ -231,6 +225,7 @@ class Hippocampus: @staticmethod def find_topic_llm(text, topic_num): + # sourcery skip: inline-immediately-returned-variable prompt = ( f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来," f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。" @@ -240,6 +235,7 @@ class Hippocampus: @staticmethod def topic_what(text, topic): + # sourcery skip: inline-immediately-returned-variable # 不再需要 time_info 参数 prompt = ( f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,' @@ -480,9 +476,7 @@ class Hippocampus: top_memories = memory_similarities[:max_memory_length] # 添加到结果中 - for memory, similarity in top_memories: - all_memories.append((node, [memory], similarity)) - # logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})") + all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories) else: logger.info("节点没有记忆") @@ -646,9 +640,7 @@ class Hippocampus: top_memories = memory_similarities[:max_memory_length] # 添加到结果中 - for memory, similarity in top_memories: - all_memories.append((node, [memory], similarity)) - # logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})") + all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories) else: logger.info("节点没有记忆") @@ -823,11 +815,11 @@ class EntorhinalCortex: logger.debug(f"回忆往事: {readable_timestamp}") chat_samples = [] for timestamp in timestamps: - # 调用修改后的 random_get_msg_snippet - messages = self.random_get_msg_snippet( - timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg - ) - if messages: + if messages := self.random_get_msg_snippet( + timestamp, + global_config.memory.memory_build_sample_length, + max_memorized_time_per_msg, + ): time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600 logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条") chat_samples.append(messages) @@ -838,6 +830,7 @@ class EntorhinalCortex: @staticmethod def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None: + # sourcery skip: invert-any-all, use-any, use-named-expression, use-next """从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)""" try_count = 0 time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟 @@ -847,22 +840,21 @@ class EntorhinalCortex: timestamp_start = target_timestamp timestamp_end = target_timestamp + time_window_seconds - chosen_message = get_raw_msg_by_timestamp( - timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest" - ) + if chosen_message := get_raw_msg_by_timestamp( + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + limit=1, + limit_mode="earliest", + ): + chat_id: str = chosen_message[0].get("chat_id") # type: ignore - if chosen_message: - chat_id = chosen_message[0].get("chat_id") - - messages = get_raw_msg_by_timestamp_with_chat( + if messages := get_raw_msg_by_timestamp_with_chat( timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest", chat_id=chat_id, - ) - - if messages: + ): # 检查获取到的所有消息是否都未达到最大记忆次数 all_valid = True for message in messages: @@ -975,7 +967,7 @@ class EntorhinalCortex: ).execute() if nodes_to_delete: - GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() + GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore # 处理边的信息 db_edges = list(GraphEdges.select()) @@ -1114,7 +1106,7 @@ class EntorhinalCortex: node_start = time.time() if nodes_data: batch_size = 500 # 增加批量大小 - with GraphNodes._meta.database.atomic(): + with GraphNodes._meta.database.atomic(): # type: ignore for i in range(0, len(nodes_data), batch_size): batch = nodes_data[i : i + batch_size] GraphNodes.insert_many(batch).execute() @@ -1125,7 +1117,7 @@ class EntorhinalCortex: edge_start = time.time() if edges_data: batch_size = 500 # 增加批量大小 - with GraphEdges._meta.database.atomic(): + with GraphEdges._meta.database.atomic(): # type: ignore for i in range(0, len(edges_data), batch_size): batch = edges_data[i : i + batch_size] GraphEdges.insert_many(batch).execute() @@ -1489,32 +1481,30 @@ class ParahippocampalGyrus: # --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 --- last_modified = node_data.get("last_modified", current_time) # 条件1:检查是否长时间未修改 (超过24小时) - if current_time - last_modified > 3600 * 24: - # 条件2:再次确认节点包含记忆项(理论上已确认,但作为保险) - if memory_items: - current_count = len(memory_items) - # 如果列表非空,才进行随机选择 - if current_count > 0: - removed_item = random.choice(memory_items) - try: - memory_items.remove(removed_item) + if current_time - last_modified > 3600 * 24 and memory_items: + current_count = len(memory_items) + # 如果列表非空,才进行随机选择 + if current_count > 0: + removed_item = random.choice(memory_items) + try: + memory_items.remove(removed_item) - # 条件3:检查移除后 memory_items 是否变空 - if memory_items: # 如果移除后列表不为空 - # self.memory_graph.G.nodes[node]["memory_items"] = memory_items # 直接修改列表即可 - self.memory_graph.G.nodes[node]["last_modified"] = current_time # 更新修改时间 - node_changes["reduced"].append(f"{node} (数量: {current_count} -> {len(memory_items)})") - else: # 如果移除后列表为空 - # 尝试移除节点,处理可能的错误 - try: - self.memory_graph.G.remove_node(node) - node_changes["removed"].append(f"{node}(遗忘清空)") # 标记为遗忘清空 - logger.debug(f"[遗忘] 节点 {node} 因移除最后一项而被清空。") - except nx.NetworkXError as e: - logger.warning(f"[遗忘] 尝试移除节点 {node} 时发生错误(可能已被移除):{e}") - except ValueError: - # 这个错误理论上不应发生,因为 removed_item 来自 memory_items - logger.warning(f"[遗忘] 尝试从节点 '{node}' 移除不存在的项目 '{removed_item[:30]}...'") + # 条件3:检查移除后 memory_items 是否变空 + if memory_items: # 如果移除后列表不为空 + # self.memory_graph.G.nodes[node]["memory_items"] = memory_items # 直接修改列表即可 + self.memory_graph.G.nodes[node]["last_modified"] = current_time # 更新修改时间 + node_changes["reduced"].append(f"{node} (数量: {current_count} -> {len(memory_items)})") + else: # 如果移除后列表为空 + # 尝试移除节点,处理可能的错误 + try: + self.memory_graph.G.remove_node(node) + node_changes["removed"].append(f"{node}(遗忘清空)") # 标记为遗忘清空 + logger.debug(f"[遗忘] 节点 {node} 因移除最后一项而被清空。") + except nx.NetworkXError as e: + logger.warning(f"[遗忘] 尝试移除节点 {node} 时发生错误(可能已被移除):{e}") + except ValueError: + # 这个错误理论上不应发生,因为 removed_item 来自 memory_items + logger.warning(f"[遗忘] 尝试从节点 '{node}' 移除不存在的项目 '{removed_item[:30]}...'") node_check_end = time.time() logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}秒") @@ -1669,7 +1659,7 @@ class ParahippocampalGyrus: class HippocampusManager: def __init__(self): - self._hippocampus = None + self._hippocampus: Hippocampus = None # type: ignore self._initialized = False def initialize(self): diff --git a/src/chat/memory_system/memory_activator.py b/src/chat/memory_system/memory_activator.py index 560fe01a6..66ff89755 100644 --- a/src/chat/memory_system/memory_activator.py +++ b/src/chat/memory_system/memory_activator.py @@ -13,7 +13,7 @@ from json_repair import repair_json logger = get_logger("memory_activator") -def get_keywords_from_json(json_str): +def get_keywords_from_json(json_str) -> List: """ 从JSON字符串中提取关键词列表 @@ -28,15 +28,8 @@ def get_keywords_from_json(json_str): fixed_json = repair_json(json_str) # 如果repair_json返回的是字符串,需要解析为Python对象 - if isinstance(fixed_json, str): - result = json.loads(fixed_json) - else: - # 如果repair_json直接返回了字典对象,直接使用 - result = fixed_json - - # 提取关键词 - keywords = result.get("keywords", []) - return keywords + result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json + return result.get("keywords", []) except Exception as e: logger.error(f"解析关键词JSON失败: {e}") return [] diff --git a/src/chat/memory_system/sample_distribution.py b/src/chat/memory_system/sample_distribution.py index b3b84eb4c..69f23a770 100644 --- a/src/chat/memory_system/sample_distribution.py +++ b/src/chat/memory_system/sample_distribution.py @@ -1,52 +1,10 @@ import numpy as np -from scipy import stats from datetime import datetime, timedelta from rich.traceback import install install(extra_lines=3) -class DistributionVisualizer: - def __init__(self, mean=0, std=1, skewness=0, sample_size=10): - """ - 初始化分布可视化器 - - 参数: - mean (float): 期望均值 - std (float): 标准差 - skewness (float): 偏度 - sample_size (int): 样本大小 - """ - self.mean = mean - self.std = std - self.skewness = skewness - self.sample_size = sample_size - self.samples = None - - def generate_samples(self): - """生成具有指定参数的样本""" - if self.skewness == 0: - # 对于无偏度的情况,直接使用正态分布 - self.samples = np.random.normal(loc=self.mean, scale=self.std, size=self.sample_size) - else: - # 使用 scipy.stats 生成具有偏度的分布 - self.samples = stats.skewnorm.rvs(a=self.skewness, loc=self.mean, scale=self.std, size=self.sample_size) - - def get_weighted_samples(self): - """获取加权后的样本数列""" - if self.samples is None: - self.generate_samples() - # 将样本值乘以样本大小 - return self.samples * self.sample_size - - def get_statistics(self): - """获取分布的统计信息""" - if self.samples is None: - self.generate_samples() - - return {"均值": np.mean(self.samples), "标准差": np.std(self.samples), "实际偏度": stats.skew(self.samples)} - - class MemoryBuildScheduler: def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50): """ diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index b460ad99b..3d1f1e341 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -1,23 +1,25 @@ import traceback import os +import re + from typing import Dict, Any +from maim_message import UserInfo from src.common.logger import get_logger +from src.config.config import global_config from src.mood.mood_manager import mood_manager # 导入情绪管理器 -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream from src.chat.message_receive.message import MessageRecv -from src.experimental.only_message_process import MessageProcessor from src.chat.message_receive.storage import MessageStorage -from src.experimental.PFC.pfc_manager import PFCManager from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.config.config import global_config +from src.experimental.only_message_process import MessageProcessor +from src.experimental.PFC.pfc_manager import PFCManager from src.plugin_system.core.component_registry import component_registry # 导入新插件系统 from src.plugin_system.base.base_command import BaseCommand from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor -from maim_message import UserInfo -from src.chat.message_receive.chat_stream import ChatStream -import re + + # 定义日志配置 # 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录) @@ -184,8 +186,8 @@ class ChatBot: get_chat_manager().register_message(message) chat = await get_chat_manager().get_or_create_stream( - platform=message.message_info.platform, - user_info=user_info, + platform=message.message_info.platform, # type: ignore + user_info=user_info, # type: ignore group_info=group_info, ) @@ -195,8 +197,10 @@ class ChatBot: await message.process() # 过滤检查 - if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( - message.raw_message, chat, user_info + if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore + message.raw_message, # type: ignore + chat, + user_info, # type: ignore ): return diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 355cca1e6..8b71314a6 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -3,18 +3,17 @@ import hashlib import time import copy from typing import Dict, Optional, TYPE_CHECKING - - -from ...common.database.database import db -from ...common.database.database_model import ChatStreams # 新增导入 +from rich.traceback import install from maim_message import GroupInfo, UserInfo +from src.common.logger import get_logger +from src.common.database.database import db +from src.common.database.database_model import ChatStreams # 新增导入 + # 避免循环导入,使用TYPE_CHECKING进行类型提示 if TYPE_CHECKING: from .message import MessageRecv -from src.common.logger import get_logger -from rich.traceback import install install(extra_lines=3) @@ -28,7 +27,7 @@ class ChatMessageContext: def __init__(self, message: "MessageRecv"): self.message = message - def get_template_name(self) -> str: + def get_template_name(self) -> Optional[str]: """获取模板名称""" if self.message.message_info.template_info and not self.message.message_info.template_info.template_default: return self.message.message_info.template_info.template_name @@ -41,10 +40,10 @@ class ChatMessageContext: def check_types(self, types: list) -> bool: # sourcery skip: invert-any-all, use-any, use-next """检查消息类型""" - if not self.message.message_info.format_info.accept_format: + if not self.message.message_info.format_info.accept_format: # type: ignore return False for t in types: - if t not in self.message.message_info.format_info.accept_format: + if t not in self.message.message_info.format_info.accept_format: # type: ignore return False return True @@ -68,7 +67,7 @@ class ChatStream: platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None, - data: dict = None, + data: Optional[dict] = None, ): self.stream_id = stream_id self.platform = platform @@ -77,7 +76,7 @@ class ChatStream: self.create_time = data.get("create_time", time.time()) if data else time.time() self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time self.saved = False - self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息 + self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息 def to_dict(self) -> dict: """转换为字典格式""" @@ -99,7 +98,7 @@ class ChatStream: return cls( stream_id=data["stream_id"], platform=data["platform"], - user_info=user_info, + user_info=user_info, # type: ignore group_info=group_info, data=data, ) @@ -163,8 +162,8 @@ class ChatManager: def register_message(self, message: "MessageRecv"): """注册消息到聊天流""" stream_id = self._generate_stream_id( - message.message_info.platform, - message.message_info.user_info, + message.message_info.platform, # type: ignore + message.message_info.user_info, # type: ignore message.message_info.group_info, ) self.last_messages[stream_id] = message @@ -185,10 +184,7 @@ class ChatManager: def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str: """获取聊天流ID""" - if is_group: - components = [platform, str(id)] - else: - components = [platform, str(id), "private"] + components = [platform, id] if is_group else [platform, id, "private"] key = "_".join(components) return hashlib.md5(key.encode()).hexdigest() diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 7575e0e53..f8d917574 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,17 +1,15 @@ import time -from abc import abstractmethod -from dataclasses import dataclass -from typing import Optional, Any, TYPE_CHECKING - import urllib3 -from src.common.logger import get_logger - -if TYPE_CHECKING: - from .chat_stream import ChatStream -from ..utils.utils_image import get_image_manager -from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase +from abc import abstractmethod +from dataclasses import dataclass from rich.traceback import install +from typing import Optional, Any +from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase + +from src.common.logger import get_logger +from src.chat.utils.utils_image import get_image_manager +from .chat_stream import ChatStream install(extra_lines=3) @@ -27,7 +25,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @dataclass class Message(MessageBase): - chat_stream: "ChatStream" = None + chat_stream: "ChatStream" = None # type: ignore reply: Optional["Message"] = None processed_plain_text: str = "" memorized_times: int = 0 @@ -55,7 +53,7 @@ class Message(MessageBase): ) # 调用父类初始化 - super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) + super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore self.chat_stream = chat_stream # 文本处理相关属性 @@ -66,6 +64,7 @@ class Message(MessageBase): self.reply = reply async def _process_message_segments(self, segment: Seg) -> str: + # sourcery skip: remove-unnecessary-else, swap-if-else-branches """递归处理消息段,转换为文字描述 Args: @@ -78,13 +77,13 @@ class Message(MessageBase): # 处理消息段列表 segments_text = [] for seg in segment.data: - processed = await self._process_message_segments(seg) + processed = await self._process_message_segments(seg) # type: ignore if processed: segments_text.append(processed) return " ".join(segments_text) else: # 处理单个消息段 - return await self._process_single_segment(segment) + return await self._process_single_segment(segment) # type: ignore @abstractmethod async def _process_single_segment(self, segment): @@ -138,7 +137,7 @@ class MessageRecv(Message): if segment.type == "text": self.is_picid = False self.is_emoji = False - return segment.data + return segment.data # type: ignore elif segment.type == "image": # 如果是base64图片数据 if isinstance(segment.data, str): @@ -160,7 +159,7 @@ class MessageRecv(Message): elif segment.type == "mention_bot": self.is_picid = False self.is_emoji = False - self.is_mentioned = float(segment.data) + self.is_mentioned = float(segment.data) # type: ignore return "" elif segment.type == "priority_info": self.is_picid = False @@ -186,7 +185,7 @@ class MessageRecv(Message): """生成详细文本,包含时间和用户信息""" timestamp = self.message_info.time user_info = self.message_info.user_info - name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" + name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore return f"[{timestamp}] {name}: {self.processed_plain_text}\n" @@ -234,7 +233,7 @@ class MessageProcessBase(Message): """ try: if seg.type == "text": - return seg.data + return seg.data # type: ignore elif seg.type == "image": # 如果是base64图片数据 if isinstance(seg.data, str): @@ -250,7 +249,7 @@ class MessageProcessBase(Message): if self.reply and hasattr(self.reply, "processed_plain_text"): # print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}") # print(f"reply: {self.reply}") - return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" + return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore return None else: return f"[{seg.type}:{str(seg.data)}]" @@ -264,7 +263,7 @@ class MessageProcessBase(Message): timestamp = self.message_info.time user_info = self.message_info.user_info - name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" + name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n" @@ -313,7 +312,7 @@ class MessageSending(MessageProcessBase): is_emoji: bool = False, thinking_start_time: float = 0, apply_set_reply_logic: bool = False, - reply_to: str = None, + reply_to: str = None, # type: ignore ): # 调用父类初始化 super().__init__( @@ -344,7 +343,7 @@ class MessageSending(MessageProcessBase): self.message_segment = Seg( type="seglist", data=[ - Seg(type="reply", data=self.reply.message_info.message_id), + Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore self.message_segment, ], ) @@ -364,10 +363,10 @@ class MessageSending(MessageProcessBase): ) -> "MessageSending": """从思考状态消息创建发送状态消息""" return cls( - message_id=thinking.message_info.message_id, + message_id=thinking.message_info.message_id, # type: ignore chat_stream=thinking.chat_stream, message_segment=message_segment, - bot_user_info=thinking.message_info.user_info, + bot_user_info=thinking.message_info.user_info, # type: ignore reply=thinking.reply, is_head=is_head, is_emoji=is_emoji, @@ -399,13 +398,11 @@ class MessageSet: if not isinstance(message, MessageSending): raise TypeError("MessageSet只能添加MessageSending类型的消息") self.messages.append(message) - self.messages.sort(key=lambda x: x.message_info.time) + self.messages.sort(key=lambda x: x.message_info.time) # type: ignore def get_message_by_index(self, index: int) -> Optional[MessageSending]: """通过索引获取消息""" - if 0 <= index < len(self.messages): - return self.messages[index] - return None + return self.messages[index] if 0 <= index < len(self.messages) else None def get_message_by_time(self, target_time: float) -> Optional[MessageSending]: """获取最接近指定时间的消息""" @@ -415,7 +412,7 @@ class MessageSet: left, right = 0, len(self.messages) - 1 while left < right: mid = (left + right) // 2 - if self.messages[mid].message_info.time < target_time: + if self.messages[mid].message_info.time < target_time: # type: ignore left = mid + 1 else: right = mid diff --git a/src/chat/message_receive/normal_message_sender.py b/src/chat/message_receive/normal_message_sender.py index aa6721db3..95d296473 100644 --- a/src/chat/message_receive/normal_message_sender.py +++ b/src/chat/message_receive/normal_message_sender.py @@ -1,21 +1,16 @@ -# src/plugins/chat/message_sender.py import asyncio import time from asyncio import Task from typing import Union -from src.common.message.api import get_global_api - -# from ...common.database import db # 数据库依赖似乎不需要了,注释掉 -from .message import MessageSending, MessageThinking, MessageSet - -from src.chat.message_receive.storage import MessageStorage -from ..utils.utils import truncate_message, calculate_typing_time, count_messages_between - -from src.common.logger import get_logger from rich.traceback import install -install(extra_lines=3) +from src.common.logger import get_logger +from src.common.message.api import get_global_api +from src.chat.message_receive.storage import MessageStorage +from src.chat.utils.utils import truncate_message, calculate_typing_time, count_messages_between +from .message import MessageSending, MessageThinking, MessageSet +install(extra_lines=3) logger = get_logger("sender") @@ -79,9 +74,10 @@ class MessageContainer: def count_thinking_messages(self) -> int: """计算当前容器中思考消息的数量""" - return sum(1 for msg in self.messages if isinstance(msg, MessageThinking)) + return sum(isinstance(msg, MessageThinking) for msg in self.messages) def get_timeout_sending_messages(self) -> list[MessageSending]: + # sourcery skip: merge-nested-ifs """获取所有超时的MessageSending对象(思考时间超过20秒),按thinking_start_time排序 - 从旧 sender 合并""" current_time = time.time() timeout_messages = [] @@ -230,9 +226,7 @@ class MessageManager: f"[{message.chat_stream.stream_id}] 处理发送消息 {getattr(message.message_info, 'message_id', 'N/A')} 时出错: {e}" ) logger.exception("详细错误信息:") - # 考虑是否移除出错的消息,防止无限循环 - removed = container.remove_message(message) - if removed: + if container.remove_message(message): logger.warning(f"[{message.chat_stream.stream_id}] 已移除处理出错的消息。") async def _process_chat_messages(self, chat_id: str): @@ -261,10 +255,7 @@ class MessageManager: # --- 处理发送消息 --- await self._handle_sending_message(container, message_earliest) - # --- 处理超时发送消息 (来自旧 sender) --- - # 在处理完最早的消息后,检查是否有超时的发送消息 - timeout_sending_messages = container.get_timeout_sending_messages() - if timeout_sending_messages: + if timeout_sending_messages := container.get_timeout_sending_messages(): logger.debug(f"[{chat_id}] 发现 {len(timeout_sending_messages)} 条超时的发送消息") for msg in timeout_sending_messages: # 确保不是刚刚处理过的最早消息 (虽然理论上应该已被移除,但以防万一) @@ -274,6 +265,7 @@ class MessageManager: await self._handle_sending_message(container, msg) # 复用处理逻辑 async def _start_processor_loop(self): + # sourcery skip: list-comprehension, move-assign-in-block, use-named-expression """消息处理器主循环""" while self._running: tasks = [] @@ -282,10 +274,7 @@ class MessageManager: # 创建 keys 的快照以安全迭代 chat_ids = list(self.containers.keys()) - for chat_id in chat_ids: - # 为每个 chat_id 创建一个处理任务 - tasks.append(asyncio.create_task(self._process_chat_messages(chat_id))) - + tasks.extend(asyncio.create_task(self._process_chat_messages(chat_id)) for chat_id in chat_ids) if tasks: try: # 等待当前批次的所有任务完成 diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index c40c4eb75..d5fc7b514 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,11 +1,10 @@ import re from typing import Union -# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance -from .message import MessageSending, MessageRecv -from .chat_stream import ChatStream -from ...common.database.database_model import Messages, RecalledMessages, Images # Import Peewee models +from src.common.database.database_model import Messages, RecalledMessages, Images from src.common.logger import get_logger +from .chat_stream import ChatStream +from .message import MessageSending, MessageRecv logger = get_logger("message_storage") @@ -44,7 +43,7 @@ class MessageStorage: reply_to = "" chat_info_dict = chat_stream.to_dict() - user_info_dict = message.message_info.user_info.to_dict() + user_info_dict = message.message_info.user_info.to_dict() # type: ignore # message_id 现在是 TextField,直接使用字符串值 msg_id = message.message_info.message_id @@ -56,7 +55,7 @@ class MessageStorage: Messages.create( message_id=msg_id, - time=float(message.message_info.time), + time=float(message.message_info.time), # type: ignore chat_id=chat_stream.stream_id, # Flattened chat_info reply_to=reply_to, @@ -103,7 +102,7 @@ class MessageStorage: try: # Assuming input 'time' is a string timestamp that can be converted to float current_time_float = float(time) - RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() + RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() # type: ignore except Exception: logger.exception("删除撤回消息失败") @@ -115,22 +114,19 @@ class MessageStorage: """更新最新一条匹配消息的message_id""" try: if message.message_segment.type == "notify": - mmc_message_id = message.message_segment.data.get("echo") - qq_message_id = message.message_segment.data.get("actual_id") + mmc_message_id = message.message_segment.data.get("echo") # type: ignore + qq_message_id = message.message_segment.data.get("actual_id") # type: ignore else: logger.info(f"更新消息ID错误,seg类型为{message.message_segment.type}") return if not qq_message_id: logger.info("消息不存在message_id,无法更新") return - # 查询最新一条匹配消息 - matched_message = ( + if matched_message := ( Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first() - ) - - if matched_message: + ): # 更新找到的消息记录 - Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() + Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") else: logger.debug("未找到匹配的消息") @@ -155,10 +151,7 @@ class MessageStorage: image_record = ( Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first() ) - if image_record: - return f"[picid:{image_record.image_id}]" - else: - return match.group(0) # 保持原样 + return f"[picid:{image_record.image_id}]" if image_record else match.group(0) except Exception: return match.group(0) diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 0efcf16d8..663bf23a8 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -1,16 +1,17 @@ import asyncio -from typing import Dict, Optional # 重新导入类型 -from src.chat.message_receive.message import MessageSending, MessageThinking -from src.common.message.api import get_global_api -from src.chat.message_receive.storage import MessageStorage -from src.chat.utils.utils import truncate_message -from src.common.logger import get_logger -from src.chat.utils.utils import calculate_typing_time -from rich.traceback import install import traceback -install(extra_lines=3) +from typing import Dict, Optional +from rich.traceback import install +from src.common.message.api import get_global_api +from src.common.logger import get_logger +from src.chat.message_receive.message import MessageSending, MessageThinking +from src.chat.message_receive.storage import MessageStorage +from src.chat.utils.utils import truncate_message +from src.chat.utils.utils import calculate_typing_time + +install(extra_lines=3) logger = get_logger("sender") @@ -86,10 +87,10 @@ class HeartFCSender: """ if not message.chat_stream: logger.error("消息缺少 chat_stream,无法发送") - raise Exception("消息缺少 chat_stream,无法发送") + raise ValueError("消息缺少 chat_stream,无法发送") if not message.message_info or not message.message_info.message_id: logger.error("消息缺少 message_info 或 message_id,无法发送") - raise Exception("消息缺少 message_info 或 message_id,无法发送") + raise ValueError("消息缺少 message_info 或 message_id,无法发送") chat_id = message.chat_stream.stream_id message_id = message.message_info.message_id diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 63e394c7c..414d607a1 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -1,6 +1,7 @@ import asyncio import time import traceback + from random import random from typing import List, Optional, Dict from maim_message import UserInfo, Seg @@ -40,7 +41,7 @@ class NormalChat: def __init__( self, chat_stream: ChatStream, - interest_dict: dict = None, + interest_dict: Optional[Dict] = None, on_switch_to_focus_callback=None, get_cooldown_progress_callback=None, ): @@ -147,10 +148,7 @@ class NormalChat: while not self._disabled: try: if not self.priority_manager.is_empty(): - # 获取最高优先级的消息 - message = self.priority_manager.get_highest_priority_message() - - if message: + if message := self.priority_manager.get_highest_priority_message(): logger.info( f"[{self.stream_name}] 从队列中取出消息进行处理: User {message.message_info.user_info.user_id}, Time: {time.strftime('%H:%M:%S', time.localtime(message.message_info.time))}" ) diff --git a/src/chat/normal_chat/priority_manager.py b/src/chat/normal_chat/priority_manager.py index 0296017ff..8c1c0e731 100644 --- a/src/chat/normal_chat/priority_manager.py +++ b/src/chat/normal_chat/priority_manager.py @@ -53,7 +53,7 @@ class PriorityManager: """ 添加新消息到合适的队列中。 """ - user_id = message.message_info.user_info.user_id + user_id = message.message_info.user_info.user_id # type: ignore is_vip = message.priority_info.get("message_type") == "vip" if message.priority_info else False message_priority = message.priority_info.get("message_priority", 0.0) if message.priority_info else 0.0 diff --git a/src/chat/normal_chat/willing/mode_classical.py b/src/chat/normal_chat/willing/mode_classical.py index 0b296bbf4..7539274c1 100644 --- a/src/chat/normal_chat/willing/mode_classical.py +++ b/src/chat/normal_chat/willing/mode_classical.py @@ -35,9 +35,7 @@ class ClassicalWillingManager(BaseWillingManager): self.chat_reply_willing[chat_id] = min(current_willing, 3.0) - reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1) - - return reply_probability + return min(max((current_willing - 0.5), 0.01) * 2, 1) async def before_generate_reply_handle(self, message_id): chat_id = self.ongoing_messages[message_id].chat_id diff --git a/src/chat/normal_chat/willing/willing_manager.py b/src/chat/normal_chat/willing/willing_manager.py index 0fa701f94..f797bc3e0 100644 --- a/src/chat/normal_chat/willing/willing_manager.py +++ b/src/chat/normal_chat/willing/willing_manager.py @@ -1,14 +1,16 @@ -from src.common.logger import get_logger +import importlib +import asyncio + +from abc import ABC, abstractmethod +from typing import Dict, Optional +from rich.traceback import install from dataclasses import dataclass + +from src.common.logger import get_logger from src.config.config import global_config from src.chat.message_receive.chat_stream import ChatStream, GroupInfo from src.chat.message_receive.message import MessageRecv from src.person_info.person_info import PersonInfoManager, get_person_info_manager -from abc import ABC, abstractmethod -import importlib -from typing import Dict, Optional -import asyncio -from rich.traceback import install install(extra_lines=3) @@ -92,8 +94,8 @@ class BaseWillingManager(ABC): self.logger = logger def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float): - person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) - self.ongoing_messages[message.message_info.message_id] = WillingInfo( + person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) # type: ignore + self.ongoing_messages[message.message_info.message_id] = WillingInfo( # type: ignore message=message, chat=chat, person_info_manager=get_person_info_manager(), diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 45bdfd72d..ed045436f 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -27,14 +27,11 @@ class ActionManager: # 当前正在使用的动作集合,默认加载默认动作 self._using_actions: Dict[str, ActionInfo] = {} - # 默认动作集,仅作为快照,用于恢复默认 - self._default_actions: Dict[str, ActionInfo] = {} - # 加载插件动作 self._load_plugin_actions() # 初始化时将默认动作加载到使用中的动作 - self._using_actions = self._default_actions.copy() + self._using_actions = component_registry.get_default_actions() def _load_plugin_actions(self) -> None: """ @@ -52,7 +49,7 @@ class ActionManager: """从插件系统的component_registry加载Action组件""" try: # 获取所有Action组件 - action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) + action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore for action_name, action_info in action_components.items(): if action_name in self._registered_actions: @@ -61,10 +58,6 @@ class ActionManager: self._registered_actions[action_name] = action_info - # 如果启用,也添加到默认动作集 - if action_info.enabled: - self._default_actions[action_name] = action_info - logger.debug( f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})" ) @@ -106,7 +99,9 @@ class ActionManager: """ try: # 获取组件类 - 明确指定查询Action类型 - component_class = component_registry.get_component_class(action_name, ComponentType.ACTION) + component_class: Type[BaseAction] = component_registry.get_component_class( + action_name, ComponentType.ACTION + ) # type: ignore if not component_class: logger.warning(f"{log_prefix} 未找到Action组件: {action_name}") return None @@ -146,10 +141,6 @@ class ActionManager: """获取所有已注册的动作集""" return self._registered_actions.copy() - def get_default_actions(self) -> Dict[str, ActionInfo]: - """获取默认动作集""" - return self._default_actions.copy() - def get_using_actions(self) -> Dict[str, ActionInfo]: """获取当前正在使用的动作集合""" return self._using_actions.copy() @@ -217,31 +208,31 @@ class ActionManager: logger.debug(f"已从使用集中移除动作 {action_name}") return True - def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool: - """ - 添加新的动作到注册集 + # def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool: + # """ + # 添加新的动作到注册集 - Args: - action_name: 动作名称 - description: 动作描述 - parameters: 动作参数定义,默认为空字典 - require: 动作依赖项,默认为空列表 + # Args: + # action_name: 动作名称 + # description: 动作描述 + # parameters: 动作参数定义,默认为空字典 + # require: 动作依赖项,默认为空列表 - Returns: - bool: 添加是否成功 - """ - if action_name in self._registered_actions: - return False + # Returns: + # bool: 添加是否成功 + # """ + # if action_name in self._registered_actions: + # return False - if parameters is None: - parameters = {} - if require is None: - require = [] + # if parameters is None: + # parameters = {} + # if require is None: + # require = [] - action_info = {"description": description, "parameters": parameters, "require": require} + # action_info = {"description": description, "parameters": parameters, "require": require} - self._registered_actions[action_name] = action_info - return True + # self._registered_actions[action_name] = action_info + # return True def remove_action(self, action_name: str) -> bool: """从注册集移除指定动作""" @@ -260,10 +251,9 @@ class ActionManager: def restore_actions(self) -> None: """恢复到默认动作集""" - logger.debug( - f"恢复动作集: 从 {list(self._using_actions.keys())} 恢复到默认动作集 {list(self._default_actions.keys())}" - ) - self._using_actions = self._default_actions.copy() + actions_to_restore = list(self._using_actions.keys()) + self._using_actions = component_registry.get_default_actions() + logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") def add_system_action_if_needed(self, action_name: str) -> bool: """ @@ -293,4 +283,4 @@ class ActionManager: """ from src.plugin_system.core.component_registry import component_registry - return component_registry.get_component_class(action_name) + return component_registry.get_component_class(action_name) # type: ignore diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 8aaafc201..21a4ce06e 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -2,7 +2,7 @@ import random import asyncio import hashlib import time -from typing import List, Any, Dict +from typing import List, Any, Dict, TYPE_CHECKING from src.common.logger import get_logger from src.config.config import global_config @@ -13,6 +13,9 @@ from src.chat.planner_actions.action_manager import ActionManager from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages from src.plugin_system.base.component_types import ChatMode, ActionInfo, ActionActivationType +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream + logger = get_logger("action_manager") @@ -27,7 +30,7 @@ class ActionModifier: def __init__(self, action_manager: ActionManager, chat_id: str): """初始化动作处理器""" self.chat_id = chat_id - self.chat_stream = get_chat_manager().get_stream(self.chat_id) + self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]" self.action_manager = action_manager @@ -142,7 +145,7 @@ class ActionModifier: async def _get_deactivated_actions_by_type( self, actions_with_info: Dict[str, ActionInfo], - mode: str = "focus", + mode: ChatMode = ChatMode.FOCUS, chat_content: str = "", ) -> List[tuple[str, str]]: """ @@ -270,7 +273,7 @@ class ActionModifier: task_results = await asyncio.gather(*tasks, return_exceptions=True) # 处理结果并更新缓存 - for _, (action_name, result) in enumerate(zip(task_names, task_results)): + for action_name, result in zip(task_names, task_results): if isinstance(result, Exception): logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}") results[action_name] = False @@ -286,7 +289,7 @@ class ActionModifier: except Exception as e: logger.error(f"{self.log_prefix}并行LLM判定失败: {e}") # 如果并行执行失败,为所有任务返回False - for action_name in tasks_to_run.keys(): + for action_name in tasks_to_run: results[action_name] = False # 清理过期缓存 @@ -297,10 +300,11 @@ class ActionModifier: def _cleanup_expired_cache(self, current_time: float): """清理过期的缓存条目""" expired_keys = [] - for cache_key, cache_data in self._llm_judge_cache.items(): - if current_time - cache_data["timestamp"] > self._cache_expiry_time: - expired_keys.append(cache_key) - + expired_keys.extend( + cache_key + for cache_key, cache_data in self._llm_judge_cache.items() + if current_time - cache_data["timestamp"] > self._cache_expiry_time + ) for key in expired_keys: del self._llm_judge_cache[key] @@ -379,7 +383,7 @@ class ActionModifier: def _check_keyword_activation( self, action_name: str, - action_info: Dict[str, Any], + action_info: ActionInfo, chat_content: str = "", ) -> bool: """ @@ -396,8 +400,8 @@ class ActionModifier: bool: 是否应该激活此action """ - activation_keywords = action_info.get("activation_keywords", []) - case_sensitive = action_info.get("keyword_case_sensitive", False) + activation_keywords = action_info.activation_keywords + case_sensitive = action_info.keyword_case_sensitive if not activation_keywords: logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词") diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index f4c8a9a4a..850f43d12 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -70,7 +70,7 @@ class ActionPlanner: self.last_obs_time_mark = 0.0 - async def plan(self) -> Dict[str, Any]: + async def plan(self) -> Dict[str, Any]: # sourcery skip: dict-comprehension """ 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 """ @@ -162,7 +162,6 @@ class ActionPlanner: reasoning = parsed_json.get("reasoning", "未提供原因") # 将所有其他属性添加到action_data - action_data = {} for key, value in parsed_json.items(): if key not in ["action", "reasoning"]: action_data[key] = value @@ -285,7 +284,7 @@ class ActionPlanner: identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:" planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") - prompt = planner_prompt_template.format( + return planner_prompt_template.format( time_block=time_block, by_what=by_what, chat_context_description=chat_context_description, @@ -295,8 +294,6 @@ class ActionPlanner: moderation_prompt=moderation_prompt_block, identity_block=identity_block, ) - return prompt - except Exception as e: logger.error(f"构建 Planner 提示词时出错: {e}") logger.error(traceback.format_exc()) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 6cb526d11..084dfd58c 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -130,9 +130,7 @@ class DefaultReplyer: # 提取权重,如果模型配置中没有'weight'键,则默认为1.0 weights = [config.get("weight", 1.0) for config in configs] - # random.choices 返回一个列表,我们取第一个元素 - selected_config = random.choices(population=configs, weights=weights, k=1)[0] - return selected_config + return random.choices(population=configs, weights=weights, k=1)[0] async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str): """创建思考消息 (尝试锚定到 anchor_message)""" @@ -314,8 +312,7 @@ class DefaultReplyer: logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID,跳过信息提取") return f"你完全不认识{sender},不理解ta的相关信息。" - relation_info = await relationship_fetcher.build_relation_info(person_id, text, chat_history) - return relation_info + return await relationship_fetcher.build_relation_info(person_id, text, chat_history) async def build_expression_habits(self, chat_history, target): if not global_config.expression.enable_expression: @@ -363,15 +360,13 @@ class DefaultReplyer: target_message=target, chat_history_prompt=chat_history ) - if running_memories: - memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" - for running_memory in running_memories: - memory_str += f"- {running_memory['content']}\n" - memory_block = memory_str - else: - memory_block = "" + if not running_memories: + return "" - return memory_block + memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" + for running_memory in running_memories: + memory_str += f"- {running_memory['content']}\n" + return memory_str async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True): """构建工具信息块 @@ -453,7 +448,7 @@ class DefaultReplyer: for name, content in result.groupdict().items(): reaction = reaction.replace(f"[{name}]", content) logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}") - keywords_reaction_prompt += reaction + "," + keywords_reaction_prompt += f"{reaction}," break except re.error as e: logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}") @@ -477,7 +472,7 @@ class DefaultReplyer: available_actions: Optional[Dict[str, ActionInfo]] = None, enable_timeout: bool = False, enable_tool: bool = True, - ) -> str: + ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if """ 构建回复器上下文 @@ -612,7 +607,7 @@ class DefaultReplyer: short_impression = ["友好活泼", "人类"] personality = short_impression[0] identity = short_impression[1] - prompt_personality = personality + "," + identity + prompt_personality = f"{personality},{identity}" identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" moderation_prompt_block = ( @@ -660,7 +655,7 @@ class DefaultReplyer: "chat_target_private2", sender_name=chat_target_name ) - prompt = await global_prompt_manager.format_prompt( + return await global_prompt_manager.format_prompt( template_name, expression_habits_block=expression_habits_block, chat_target=chat_target_1, @@ -683,8 +678,6 @@ class DefaultReplyer: mood_state=mood_prompt, ) - return prompt - async def build_prompt_rewrite_context( self, reply_data: Dict[str, Any], @@ -745,7 +738,7 @@ class DefaultReplyer: short_impression = ["友好活泼", "人类"] personality = short_impression[0] identity = short_impression[1] - prompt_personality = personality + "," + identity + prompt_personality = f"{personality},{identity}" identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" moderation_prompt_block = ( @@ -790,7 +783,7 @@ class DefaultReplyer: template_name = "default_expressor_prompt" - prompt = await global_prompt_manager.format_prompt( + return await global_prompt_manager.format_prompt( template_name, expression_habits_block=expression_habits_block, relation_info_block=relation_info, @@ -807,8 +800,6 @@ class DefaultReplyer: moderation_prompt=moderation_prompt_block, ) - return prompt - async def send_response_messages( self, anchor_message: Optional[MessageRecv], @@ -816,6 +807,7 @@ class DefaultReplyer: thinking_id: str = "", display_message: str = "", ) -> Optional[MessageSending]: + # sourcery skip: assign-if-exp, boolean-if-exp-identity, remove-unnecessary-cast """发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender""" chat = self.chat_stream chat_id = self.chat_stream.stream_id @@ -849,16 +841,16 @@ class DefaultReplyer: for i, msg_text in enumerate(response_set): # 为每个消息片段生成唯一ID - type = msg_text[0] + msg_type = msg_text[0] data = msg_text[1] - if global_config.debug.debug_show_chat_mode and type == "text": + if global_config.debug.debug_show_chat_mode and msg_type == "text": data += "ᶠ" part_message_id = f"{thinking_id}_{i}" - message_segment = Seg(type=type, data=data) + message_segment = Seg(type=msg_type, data=data) - if type == "emoji": + if msg_type == "emoji": is_emoji = True else: is_emoji = False @@ -871,7 +863,6 @@ class DefaultReplyer: display_message=display_message, reply_to=reply_to, is_emoji=is_emoji, - thinking_id=thinking_id, thinking_start_time=thinking_start_time, ) @@ -895,7 +886,7 @@ class DefaultReplyer: reply_message_ids.append(part_message_id) # 记录我们生成的ID - sent_msg_list.append((type, sent_msg)) + sent_msg_list.append((msg_type, sent_msg)) except Exception as e: logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}") @@ -930,12 +921,9 @@ class DefaultReplyer: ) # await anchor_message.process() - if anchor_message: - sender_info = anchor_message.message_info.user_info - else: - sender_info = None + sender_info = anchor_message.message_info.user_info if anchor_message else None - bot_message = MessageSending( + return MessageSending( message_id=message_id, # 使用片段的唯一ID chat_stream=self.chat_stream, bot_user_info=bot_user_info, @@ -948,8 +936,6 @@ class DefaultReplyer: display_message=display_message, ) - return bot_message - def weighted_sample_no_replacement(items, weights, k) -> list: """ diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py index 6a73b7d4b..a2a2aaaa0 100644 --- a/src/chat/replyer/replyer_manager.py +++ b/src/chat/replyer/replyer_manager.py @@ -1,4 +1,5 @@ from typing import Dict, Any, Optional, List + from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.replyer.default_generator import DefaultReplyer from src.common.logger import get_logger @@ -8,7 +9,7 @@ logger = get_logger("ReplyerManager") class ReplyerManager: def __init__(self): - self._replyers: Dict[str, DefaultReplyer] = {} + self._repliers: Dict[str, DefaultReplyer] = {} def get_replyer( self, @@ -29,17 +30,16 @@ class ReplyerManager: return None # 如果已有缓存实例,直接返回 - if stream_id in self._replyers: + if stream_id in self._repliers: logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。") - return self._replyers[stream_id] + return self._repliers[stream_id] # 如果没有缓存,则创建新实例(首次初始化) logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。") target_stream = chat_stream if not target_stream: - chat_manager = get_chat_manager() - if chat_manager: + if chat_manager := get_chat_manager(): target_stream = chat_manager.get_stream(stream_id) if not target_stream: @@ -52,7 +52,7 @@ class ReplyerManager: model_configs=model_configs, # 可以是None,此时使用默认模型 request_type=request_type, ) - self._replyers[stream_id] = replyer + self._repliers[stream_id] = replyer return replyer diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index ab97f395b..06044defb 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -1,14 +1,15 @@ -from src.config.config import global_config -from typing import List, Dict, Any, Tuple # 确保类型提示被导入 import time # 导入 time 模块以获取当前时间 import random import re -from src.common.message_repository import find_messages, count_messages -from src.person_info.person_info import PersonInfoManager, get_person_info_manager -from src.chat.utils.utils import translate_timestamp_to_human_readable +from typing import List, Dict, Any, Tuple, Optional from rich.traceback import install + +from src.config.config import global_config +from src.common.message_repository import find_messages, count_messages from src.common.database.database_model import ActionRecords from src.common.database.database_model import Images +from src.person_info.person_info import PersonInfoManager, get_person_info_manager +from src.chat.utils.utils import translate_timestamp_to_human_readable install(extra_lines=3) @@ -135,7 +136,7 @@ def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float = None) -> int: +def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: """ 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。 @@ -172,7 +173,7 @@ def _build_readable_messages_internal( merge_messages: bool = False, timestamp_mode: str = "relative", truncate: bool = False, - pic_id_mapping: Dict[str, str] = None, + pic_id_mapping: Optional[Dict[str, str]] = None, pic_counter: int = 1, show_pic: bool = True, ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: @@ -194,7 +195,7 @@ def _build_readable_messages_internal( if not messages: return "", [], pic_id_mapping or {}, pic_counter - message_details_raw: List[Tuple[float, str, str]] = [] + message_details_raw: List[Tuple[float, str, str, bool]] = [] # 使用传入的映射字典,如果没有则创建新的 if pic_id_mapping is None: @@ -225,7 +226,7 @@ def _build_readable_messages_internal( # 检查是否是动作记录 if msg.get("is_action_record", False): is_action = True - timestamp = msg.get("time") + timestamp: float = msg.get("time") # type: ignore content = msg.get("display_message", "") # 对于动作记录,也处理图片ID content = process_pic_ids(content) @@ -249,9 +250,10 @@ def _build_readable_messages_internal( user_nickname = user_info.get("user_nickname") user_cardname = user_info.get("user_cardname") - timestamp = msg.get("time") + timestamp: float = msg.get("time") # type: ignore + content: str if msg.get("display_message"): - content = msg.get("display_message") + content = msg.get("display_message", "") else: content = msg.get("processed_plain_text", "") # 默认空字符串 @@ -271,6 +273,7 @@ def _build_readable_messages_internal( person_id = PersonInfoManager.get_person_id(platform, user_id) person_info_manager = get_person_info_manager() # 根据 replace_bot_name 参数决定是否替换机器人名称 + person_name: str if replace_bot_name and user_id == global_config.bot.qq_account: person_name = f"{global_config.bot.nickname}(你)" else: @@ -289,12 +292,10 @@ def _build_readable_messages_internal( reply_pattern = r"回复<([^:<>]+):([^:<>]+)>" match = re.search(reply_pattern, content) if match: - aaa = match.group(1) - bbb = match.group(2) + aaa: str = match[1] + bbb: str = match[2] reply_person_id = PersonInfoManager.get_person_id(platform, bbb) - reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") - if not reply_person_name: - reply_person_name = aaa + reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") or aaa # 在内容前加上回复信息 content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1) @@ -309,18 +310,15 @@ def _build_readable_messages_internal( aaa = m.group(1) bbb = m.group(2) at_person_id = PersonInfoManager.get_person_id(platform, bbb) - at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") - if not at_person_name: - at_person_name = aaa + at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") or aaa new_content += f"@{at_person_name}" last_end = m.end() new_content += content[last_end:] content = new_content target_str = "这是QQ的一个功能,用于提及某人,但没那么明显" - if target_str in content: - if random.random() < 0.6: - content = content.replace(target_str, "") + if target_str in content and random.random() < 0.6: + content = content.replace(target_str, "") if content != "": message_details_raw.append((timestamp, person_name, content, False)) @@ -470,6 +468,7 @@ def _build_readable_messages_internal( def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: + # sourcery skip: use-contextlib-suppress """ 构建图片映射信息字符串,显示图片的具体描述内容 @@ -518,9 +517,7 @@ async def build_readable_messages_with_list( messages, replace_bot_name, merge_messages, timestamp_mode, truncate ) - # 生成图片映射信息并添加到最前面 - pic_mapping_info = build_pic_mapping_info(pic_id_mapping) - if pic_mapping_info: + if pic_mapping_info := build_pic_mapping_info(pic_id_mapping): formatted_string = f"{pic_mapping_info}\n\n{formatted_string}" return formatted_string, details_list @@ -535,7 +532,7 @@ def build_readable_messages( truncate: bool = False, show_actions: bool = False, show_pic: bool = True, -) -> str: +) -> str: # sourcery skip: extract-method """ 将消息列表转换为可读的文本格式。 如果提供了 read_mark,则在相应位置插入已读标记。 @@ -658,9 +655,7 @@ def build_readable_messages( # 组合结果 result_parts = [] if pic_mapping_info: - result_parts.append(pic_mapping_info) - result_parts.append("\n") - + result_parts.extend((pic_mapping_info, "\n")) if formatted_before and formatted_after: result_parts.extend([formatted_before, read_mark_line, formatted_after]) elif formatted_before: @@ -733,8 +728,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: platform = msg.get("chat_info_platform") user_id = msg.get("user_id") _timestamp = msg.get("time") + content: str = "" if msg.get("display_message"): - content = msg.get("display_message") + content = msg.get("display_message", "") else: content = msg.get("processed_plain_text", "") @@ -829,10 +825,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: if not all([platform, user_id]) or user_id == global_config.bot.qq_account: continue - person_id = PersonInfoManager.get_person_id(platform, user_id) - - # 只有当获取到有效 person_id 时才添加 - if person_id: + if person_id := PersonInfoManager.get_person_id(platform, user_id): person_ids_set.add(person_id) return list(person_ids_set) # 将集合转换为列表返回 diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 17cfb2323..5579ccf84 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -103,7 +103,7 @@ class ImageManager: image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore # 查询缓存的描述 cached_description = self._get_description_from_db(image_hash, "emoji") @@ -154,7 +154,7 @@ class ImageManager: img_obj.description = description img_obj.timestamp = current_timestamp img_obj.save() - except Images.DoesNotExist: + except Images.DoesNotExist: # type: ignore Images.create( emoji_hash=image_hash, path=file_path, @@ -204,7 +204,7 @@ class ImageManager: return f"[图片:{cached_description}]" # 调用AI获取描述 - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来,请留意其主题,直观感受,输出为一段平文本,最多50字" description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) @@ -491,7 +491,7 @@ class ImageManager: return # 获取图片格式 - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore # 构建prompt prompt = """请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本""" diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 7f22fc2d4..f44a88225 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -3,7 +3,7 @@ from src.common.database.database import db from src.common.database.database_model import PersonInfo # 新增导入 import copy import hashlib -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Union import datetime import asyncio from src.llm_models.utils_model import LLMRequest @@ -84,7 +84,7 @@ class PersonInfoManager: logger.error(f"从 Peewee 加载 person_name_list 失败: {e}") @staticmethod - def get_person_id(platform: str, user_id: int): + def get_person_id(platform: str, user_id: Union[int, str]) -> str: """获取唯一id""" if "-" in platform: platform = platform.split("-")[1] diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 42e36b64d..73c883e0a 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -32,10 +32,10 @@ class BaseAction(ABC): reasoning: str, cycle_timers: dict, thinking_id: str, - chat_stream: ChatStream = None, + chat_stream: ChatStream, log_prefix: str = "", shutting_down: bool = False, - plugin_config: dict = None, + plugin_config: Optional[dict] = None, **kwargs, ): """初始化Action组件 diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 8977c5e70..2c2ddf81e 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -29,7 +29,7 @@ class BaseCommand(ABC): command_examples: List[str] = [] intercept_message: bool = True # 默认拦截消息,不继续处理 - def __init__(self, message: MessageRecv, plugin_config: dict = None): + def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): """初始化Command组件 Args: diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index b8112a490..fe3813b88 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -66,7 +66,7 @@ class BasePlugin(ABC): config_section_descriptions: Dict[str, str] = {} - def __init__(self, plugin_dir: str = None): + def __init__(self, plugin_dir: str): """初始化插件 Args: @@ -526,7 +526,7 @@ class BasePlugin(ABC): # 从配置中更新 enable_plugin if "plugin" in self.config and "enabled" in self.config["plugin"]: - self.enable_plugin = self.config["plugin"]["enabled"] + self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}") else: logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml") diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index bc66100d9..2bac36e5c 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -81,7 +81,9 @@ class ComponentInfo: class ActionInfo(ComponentInfo): """动作组件信息""" - action_parameters: Dict[str, str] = field(default_factory=dict) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"} + action_parameters: Dict[str, str] = field( + default_factory=dict + ) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"} action_require: List[str] = field(default_factory=list) # 动作需求说明 associated_types: List[str] = field(default_factory=list) # 关联的消息类型 # 激活类型相关 diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 2ec77c7b7..b152a1abc 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Any, Pattern, Union +from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type import re from src.common.logger import get_logger from src.plugin_system.base.component_types import ( @@ -28,25 +28,25 @@ class ComponentRegistry: ComponentType.ACTION: {}, ComponentType.COMMAND: {}, } - self._component_classes: Dict[str, Union[BaseCommand, BaseAction]] = {} # 组件名 -> 组件类 + self._component_classes: Dict[str, Union[Type[BaseCommand], Type[BaseAction]]] = {} # 组件名 -> 组件类 # 插件注册表 self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息 # Action特定注册表 - self._action_registry: Dict[str, BaseAction] = {} # action名 -> action类 - # self._action_descriptions: Dict[str, str] = {} # 启用的action名 -> 描述 + self._action_registry: Dict[str, Type[BaseAction]] = {} # action名 -> action类 + self._default_actions: Dict[str, ActionInfo] = {} # 默认动作集,即启用的Action集,用于重置ActionManager状态 # Command特定注册表 - self._command_registry: Dict[str, BaseCommand] = {} # command名 -> command类 - self._command_patterns: Dict[Pattern, BaseCommand] = {} # 编译后的正则 -> command类 + self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类 + self._command_patterns: Dict[Pattern, Type[BaseCommand]] = {} # 编译后的正则 -> command类 logger.info("组件注册中心初始化完成") # === 通用组件注册方法 === def register_component( - self, component_info: ComponentInfo, component_class: Union[BaseCommand, BaseAction] + self, component_info: ComponentInfo, component_class: Union[Type[BaseCommand], Type[BaseAction]] ) -> bool: """注册组件 @@ -88,9 +88,9 @@ class ComponentRegistry: # 根据组件类型进行特定注册(使用原始名称) if component_type == ComponentType.ACTION: - self._register_action_component(component_info, component_class) + self._register_action_component(component_info, component_class) # type: ignore elif component_type == ComponentType.COMMAND: - self._register_command_component(component_info, component_class) + self._register_command_component(component_info, component_class) # type: ignore logger.debug( f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' " @@ -98,7 +98,7 @@ class ComponentRegistry: ) return True - def _register_action_component(self, action_info: ActionInfo, action_class: BaseAction): + def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]): # -------------------------------- NEED REFACTORING -------------------------------- # -------------------------------- LOGIC ERROR ------------------------------------- """注册Action组件到Action特定注册表""" @@ -106,11 +106,10 @@ class ComponentRegistry: self._action_registry[action_name] = action_class # 如果启用,添加到默认动作集 - # ---- HERE ---- - # if action_info.enabled: - # self._action_descriptions[action_name] = action_info.description + if action_info.enabled: + self._default_actions[action_name] = action_info - def _register_command_component(self, command_info: CommandInfo, command_class: BaseCommand): + def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]): """注册Command组件到Command特定注册表""" command_name = command_info.name self._command_registry[command_name] = command_class @@ -122,7 +121,7 @@ class ComponentRegistry: # === 组件查询方法 === - def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]: + def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]: # type: ignore # sourcery skip: class-extract-method """获取组件信息,支持自动命名空间解析 @@ -170,8 +169,10 @@ class ComponentRegistry: return None def get_component_class( - self, component_name: str, component_type: ComponentType = None - ) -> Optional[Union[BaseCommand, BaseAction]]: + self, + component_name: str, + component_type: ComponentType = None, # type: ignore + ) -> Optional[Union[Type[BaseCommand], Type[BaseAction]]]: """获取组件类,支持自动命名空间解析 Args: @@ -230,7 +231,7 @@ class ComponentRegistry: # === Action特定查询方法 === - def get_action_registry(self) -> Dict[str, BaseAction]: + def get_action_registry(self) -> Dict[str, Type[BaseAction]]: """获取Action注册表(用于兼容现有系统)""" return self._action_registry.copy() @@ -239,13 +240,17 @@ class ComponentRegistry: info = self.get_component_info(action_name, ComponentType.ACTION) return info if isinstance(info, ActionInfo) else None + def get_default_actions(self) -> Dict[str, ActionInfo]: + """获取默认动作集""" + return self._default_actions.copy() + # === Command特定查询方法 === - def get_command_registry(self) -> Dict[str, BaseCommand]: + def get_command_registry(self) -> Dict[str, Type[BaseCommand]]: """获取Command注册表(用于兼容现有系统)""" return self._command_registry.copy() - def get_command_patterns(self) -> Dict[Pattern, BaseCommand]: + def get_command_patterns(self) -> Dict[Pattern, Type[BaseCommand]]: """获取Command模式注册表(用于兼容现有系统)""" return self._command_patterns.copy() @@ -254,7 +259,7 @@ class ComponentRegistry: info = self.get_component_info(command_name, ComponentType.COMMAND) return info if isinstance(info, CommandInfo) else None - def find_command_by_text(self, text: str) -> Optional[tuple[BaseCommand, dict, bool, str]]: + def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]: # sourcery skip: use-named-expression, use-next """根据文本查找匹配的命令 @@ -262,7 +267,7 @@ class ComponentRegistry: text: 输入文本 Returns: - Optional[tuple[BaseCommand, dict, bool, str]]: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None + Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None """ for pattern, command_class in self._command_patterns.items():