From bed9c2bf6be1106d11e86d3511dbc3bf3cfd020f Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Mon, 7 Jul 2025 12:13:33 +0800 Subject: [PATCH 01/13] =?UTF-8?q?plugin=5Fmanager=20=E9=87=8D=E6=96=B0?= =?UTF-8?q?=E6=8B=86=E5=88=86=EF=BC=8C=E5=A2=9E=E5=8A=A0=E6=89=A9=E5=B1=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_system/base/base_plugin.py | 3 - src/plugin_system/core/plugin_manager.py | 634 +++++++------------ src/plugin_system/core/plugin_manager_bak.py | 570 +++++++++++++++++ src/plugin_system/events/__init__.py | 9 + src/plugin_system/events/events.py | 14 + 5 files changed, 827 insertions(+), 403 deletions(-) create mode 100644 src/plugin_system/core/plugin_manager_bak.py create mode 100644 src/plugin_system/events/__init__.py create mode 100644 src/plugin_system/events/events.py diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index 5c7edd23b..4044c12e9 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -568,9 +568,6 @@ class BasePlugin(ABC): def register_plugin(self) -> bool: """注册插件及其所有组件""" - if not self.enable_plugin: - logger.info(f"{self.log_prefix} 插件已禁用,跳过注册") - return False components = self.get_plugin_components() diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 3fc263a0d..fbd5de8cc 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -1,25 +1,24 @@ -from typing import Dict, List, Optional, Any, TYPE_CHECKING, Tuple +from typing import Dict, List, Optional, Callable, Tuple, Type, Any import os -import importlib -import importlib.util +from importlib.util import spec_from_file_location, module_from_spec +from inspect import getmodule from pathlib import Path import traceback -if TYPE_CHECKING: - from src.plugin_system.base.base_plugin import BasePlugin - from src.common.logger import get_logger +from src.plugin_system.events.events import EventType from src.plugin_system.core.component_registry import component_registry -from src.plugin_system.core.dependency_manager import dependency_manager -from src.plugin_system.base.component_types import ComponentType, PluginInfo +from src.plugin_system.base.base_plugin import BasePlugin +from src.plugin_system.utils.manifest_utils import VersionComparator logger = get_logger("plugin_manager") class PluginManager: - """插件管理器 + """ + 插件管理器类 - 负责加载、初始化和管理所有插件及其组件 + 负责加载,重载和卸载插件,同时管理插件的所有组件 """ def __init__(self): @@ -27,38 +26,42 @@ class PluginManager: self.loaded_plugins: Dict[str, "BasePlugin"] = {} self.failed_plugins: Dict[str, str] = {} self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射 + self.events_subscriptions: Dict[EventType, List[Callable]] = {} + self.plugin_classes: Dict[str, Type[BasePlugin]] = {} # 全局插件类注册表 # 确保插件目录存在 self._ensure_plugin_directories() logger.info("插件管理器初始化完成") - def _ensure_plugin_directories(self): - """确保所有插件目录存在,如果不存在则创建""" + def _ensure_plugin_directories(self) -> None: + """确保所有插件根目录存在,如果不存在则创建""" default_directories = ["src/plugins/built_in", "plugins"] for directory in default_directories: if not os.path.exists(directory): os.makedirs(directory, exist_ok=True) - logger.info(f"创建插件目录: {directory}") + logger.info(f"创建插件根目录: {directory}") if directory not in self.plugin_directories: self.plugin_directories.append(directory) - logger.debug(f"已添加插件目录: {directory}") + logger.debug(f"已添加插件根目录: {directory}") else: - logger.warning(f"插件不可重复加载: {directory}") + logger.warning(f"根目录不可重复加载: {directory}") - def add_plugin_directory(self, directory: str): + def add_plugin_directory(self, directory: str) -> bool: """添加插件目录""" if os.path.exists(directory): if directory not in self.plugin_directories: self.plugin_directories.append(directory) logger.debug(f"已添加插件目录: {directory}") + return True else: logger.warning(f"插件不可重复加载: {directory}") else: logger.warning(f"插件目录不存在: {directory}") + return False - def load_all_plugins(self) -> tuple[int, int]: - """加载所有插件目录中的插件 + def load_all_plugins(self) -> Tuple[int, int]: + """加载所有插件 Returns: tuple[int, int]: (插件数量, 组件数量) @@ -76,202 +79,102 @@ class PluginManager: logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}") - # 第二阶段:实例化所有已注册的插件类 - from src.plugin_system.base.base_plugin import get_registered_plugin_classes - - plugin_classes = get_registered_plugin_classes() total_registered = 0 total_failed_registration = 0 - for plugin_name, plugin_class in plugin_classes.items(): - try: - # 使用记录的插件目录路径 - plugin_dir = self.plugin_paths.get(plugin_name) - - # 如果没有记录,则尝试查找(fallback) - if not plugin_dir: - plugin_dir = self._find_plugin_directory(plugin_class) - if plugin_dir: - self.plugin_paths[plugin_name] = plugin_dir # 实例化插件(可能因为缺少manifest而失败) - plugin_instance = plugin_class(plugin_dir=plugin_dir) - - # 检查插件是否启用 - if not plugin_instance.enable_plugin: - logger.info(f"插件 {plugin_name} 已禁用,跳过加载") - continue - - # 检查版本兼容性 - is_compatible, compatibility_error = self.check_plugin_version_compatibility( - plugin_name, plugin_instance.manifest_data - ) - if not is_compatible: - total_failed_registration += 1 - self.failed_plugins[plugin_name] = compatibility_error - logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}") - continue - - if plugin_instance.register_plugin(): - total_registered += 1 - self.loaded_plugins[plugin_name] = plugin_instance - - # 📊 显示插件详细信息 - plugin_info = component_registry.get_plugin_info(plugin_name) - if plugin_info: - component_types = {} - for comp in plugin_info.components: - comp_type = comp.component_type.name - component_types[comp_type] = component_types.get(comp_type, 0) + 1 - - components_str = ", ".join([f"{count}个{ctype}" for ctype, count in component_types.items()]) - - # 显示manifest信息 - manifest_info = "" - if plugin_info.license: - manifest_info += f" [{plugin_info.license}]" - if plugin_info.keywords: - manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词 - if len(plugin_info.keywords) > 3: - manifest_info += "..." - - logger.info( - f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}" - ) - else: - logger.info(f"✅ 插件加载成功: {plugin_name}") - else: - total_failed_registration += 1 - self.failed_plugins[plugin_name] = "插件注册失败" - logger.error(f"❌ 插件注册失败: {plugin_name}") - - except FileNotFoundError as e: - # manifest文件缺失 + for plugin_name in self.plugin_classes.keys(): + if self.load_registered_plugin_classes(plugin_name): + total_registered += 1 + else: total_failed_registration += 1 - error_msg = f"缺少manifest文件: {str(e)}" - self.failed_plugins[plugin_name] = error_msg - logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") - except ValueError as e: - # manifest文件格式错误或验证失败 - traceback.print_exc() - total_failed_registration += 1 - error_msg = f"manifest验证失败: {str(e)}" - self.failed_plugins[plugin_name] = error_msg - logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") - - except Exception as e: - # 其他错误 - total_failed_registration += 1 - error_msg = f"未知错误: {str(e)}" - self.failed_plugins[plugin_name] = error_msg - logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") - logger.debug("详细错误信息: ", exc_info=True) - - # 获取组件统计信息 - stats = component_registry.get_registry_stats() - action_count = stats.get("action_components", 0) - command_count = stats.get("command_components", 0) - total_components = stats.get("total_components", 0) - - # 📋 显示插件加载总览 - if total_registered > 0: - logger.info("🎉 插件系统加载完成!") - logger.info( - f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count})" - ) - - # 显示详细的插件列表 logger.info("📋 已加载插件详情:") - for plugin_name, _plugin_class in self.loaded_plugins.items(): - plugin_info = component_registry.get_plugin_info(plugin_name) - if plugin_info: - # 插件基本信息 - version_info = f"v{plugin_info.version}" if plugin_info.version else "" - author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown" - license_info = f"[{plugin_info.license}]" if plugin_info.license else "" - info_parts = [part for part in [version_info, author_info, license_info] if part] - extra_info = f" ({', '.join(info_parts)})" if info_parts else "" - - logger.info(f" 📦 {plugin_name}{extra_info}") - - # Manifest信息 - if plugin_info.manifest_data: - if plugin_info.keywords: - logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}") - if plugin_info.categories: - logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}") - if plugin_info.homepage_url: - logger.info(f" 🌐 主页: {plugin_info.homepage_url}") - - # 组件列表 - if plugin_info.components: - action_components = [c for c in plugin_info.components if c.component_type.name == "ACTION"] - command_components = [c for c in plugin_info.components if c.component_type.name == "COMMAND"] - - if action_components: - action_names = [c.name for c in action_components] - logger.info(f" 🎯 Action组件: {', '.join(action_names)}") - - if command_components: - command_names = [c.name for c in command_components] - logger.info(f" ⚡ Command组件: {', '.join(command_names)}") - - # 版本兼容性信息 - if plugin_info.min_host_version or plugin_info.max_host_version: - version_range = "" - if plugin_info.min_host_version: - version_range += f">={plugin_info.min_host_version}" - if plugin_info.max_host_version: - if version_range: - version_range += f", <={plugin_info.max_host_version}" - else: - version_range += f"<={plugin_info.max_host_version}" - logger.info(f" 📋 兼容版本: {version_range}") - - # 依赖信息 - if plugin_info.dependencies: - logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}") - - # 配置文件信息 - if plugin_info.config_file: - config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌" - logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}") - - # 显示目录统计 - logger.info("📂 加载目录统计:") - for directory in self.plugin_directories: - if os.path.exists(directory): - plugins_in_dir = [] - for plugin_name in self.loaded_plugins.keys(): - plugin_path = self.plugin_paths.get(plugin_name, "") - if plugin_path.startswith(directory): - plugins_in_dir.append(plugin_name) - - if plugins_in_dir: - logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})") - else: - logger.info(f" 📁 {directory}: 0个插件") - - # 失败信息 - if total_failed_registration > 0: - logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败") - for failed_plugin, error in self.failed_plugins.items(): - logger.info(f" ❌ {failed_plugin}: {error}") - else: - logger.warning("😕 没有成功加载任何插件") - - # 返回插件数量和组件数量 - return total_registered, total_components - - def _find_plugin_directory(self, plugin_class) -> Optional[str]: - """查找插件类对应的目录路径""" + def load_registered_plugin_classes(self, plugin_name: str) -> bool: + # sourcery skip: extract-duplicate-method, extract-method + """ + 加载已经注册的插件类 + """ + plugin_class: Type[BasePlugin] = self.plugin_classes.get(plugin_name) + if not plugin_class: + logger.error(f"插件 {plugin_name} 的插件类未注册或不存在") + return False try: - import inspect + # 使用记录的插件目录路径 + plugin_dir = self.plugin_paths.get(plugin_name) + + # 如果没有记录,则尝试查找(fallback) + if not plugin_dir: + plugin_dir = self._find_plugin_directory(plugin_class) + if plugin_dir: + self.plugin_paths[plugin_name] = plugin_dir # 更新路径 + plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败) + # 检查插件是否启用 + if not plugin_instance.enable_plugin: + logger.info(f"插件 {plugin_name} 已禁用,跳过加载") + return False + + # 检查版本兼容性 + is_compatible, compatibility_error = self._check_plugin_version_compatibility( + plugin_name, plugin_instance.manifest_data + ) + if not is_compatible: + self.failed_plugins[plugin_name] = compatibility_error + logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}") + return False + if plugin_instance.register_plugin(): + self.loaded_plugins[plugin_name] = plugin_instance + self._show_plugin_components(plugin_name) + return True + else: + self.failed_plugins[plugin_name] = "插件注册失败" + logger.error(f"❌ 插件注册失败: {plugin_name}") + return False + + except FileNotFoundError as e: + # manifest文件缺失 + error_msg = f"缺少manifest文件: {str(e)}" + self.failed_plugins[plugin_name] = error_msg + logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") + return False + + except ValueError as e: + # manifest文件格式错误或验证失败 + traceback.print_exc() + error_msg = f"manifest验证失败: {str(e)}" + self.failed_plugins[plugin_name] = error_msg + logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") + return False - module = inspect.getmodule(plugin_class) - if module and hasattr(module, "__file__") and module.__file__: - return os.path.dirname(module.__file__) except Exception as e: - logger.debug(f"通过inspect获取插件目录失败: {e}") - return None + # 其他错误 + error_msg = f"未知错误: {str(e)}" + self.failed_plugins[plugin_name] = error_msg + logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") + logger.debug("详细错误信息: ", exc_info=True) + return False + + def unload_registered_plugin_module(self, plugin_name: str) -> None: + """ + 卸载插件模块 + """ + pass + + def reload_registered_plugin_module(self, plugin_name: str) -> None: + """ + 重载插件模块 + """ + self.unload_registered_plugin_module(plugin_name) + self.load_registered_plugin_classes(plugin_name) + + def rescan_plugin_directory(self) -> None: + """ + 重新扫描插件根目录 + """ + for directory in self.plugin_directories: + if os.path.exists(directory): + logger.debug(f"重新扫描插件根目录: {directory}") + self._load_plugin_modules_from_directory(directory) + else: + logger.warning(f"插件根目录不存在: {directory}") def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]: """从指定目录加载插件模块""" @@ -279,10 +182,10 @@ class PluginManager: failed_count = 0 if not os.path.exists(directory): - logger.warning(f"插件目录不存在: {directory}") - return loaded_count, failed_count + logger.warning(f"插件根目录不存在: {directory}") + return 0, 1 - logger.debug(f"正在扫描插件目录: {directory}") + logger.debug(f"正在扫描插件根目录: {directory}") # 遍历目录中的所有Python文件和包 for item in os.listdir(directory): @@ -308,7 +211,18 @@ class PluginManager: return loaded_count, failed_count + def _find_plugin_directory(self, plugin_class: str) -> Optional[str]: + """查找插件类对应的目录路径""" + try: + module = getmodule(plugin_class) + if module and hasattr(module, "__file__") and module.__file__: + return os.path.dirname(module.__file__) + except Exception as e: + logger.debug(f"通过inspect获取插件目录失败: {e}") + return None + def _load_plugin_module_file(self, plugin_file: str, plugin_name: str, plugin_dir: str) -> bool: + # sourcery skip: extract-method """加载单个插件模块文件 Args: @@ -327,12 +241,12 @@ class PluginManager: try: # 动态导入插件模块 - spec = importlib.util.spec_from_file_location(module_name, plugin_file) + spec = spec_from_file_location(module_name, plugin_file) if spec is None or spec.loader is None: logger.error(f"无法创建模块规范: {plugin_file}") return False - module = importlib.util.module_from_spec(spec) + module = module_from_spec(spec) spec.loader.exec_module(module) # 记录插件名和目录路径的映射 @@ -347,177 +261,7 @@ class PluginManager: self.failed_plugins[plugin_name] = error_msg return False - def get_loaded_plugins(self) -> List[PluginInfo]: - """获取所有已加载的插件信息""" - return list(component_registry.get_all_plugins().values()) - - def get_enabled_plugins(self) -> List[PluginInfo]: - """获取所有启用的插件信息""" - return list(component_registry.get_enabled_plugins().values()) - - def enable_plugin(self, plugin_name: str) -> bool: - """启用插件""" - plugin_info = component_registry.get_plugin_info(plugin_name) - if plugin_info: - plugin_info.enabled = True - # 启用插件的所有组件 - for component in plugin_info.components: - component_registry.enable_component(component.name) - logger.debug(f"已启用插件: {plugin_name}") - return True - return False - - def disable_plugin(self, plugin_name: str) -> bool: - """禁用插件""" - plugin_info = component_registry.get_plugin_info(plugin_name) - if plugin_info: - plugin_info.enabled = False - # 禁用插件的所有组件 - for component in plugin_info.components: - component_registry.disable_component(component.name) - logger.debug(f"已禁用插件: {plugin_name}") - return True - return False - - def get_plugin_instance(self, plugin_name: str) -> Optional["BasePlugin"]: - """获取插件实例 - - Args: - plugin_name: 插件名称 - - Returns: - Optional[BasePlugin]: 插件实例或None - """ - return self.loaded_plugins.get(plugin_name) - - def get_plugin_stats(self) -> Dict[str, Any]: - """获取插件统计信息""" - all_plugins = component_registry.get_all_plugins() - enabled_plugins = component_registry.get_enabled_plugins() - - action_components = component_registry.get_components_by_type(ComponentType.ACTION) - command_components = component_registry.get_components_by_type(ComponentType.COMMAND) - - return { - "total_plugins": len(all_plugins), - "enabled_plugins": len(enabled_plugins), - "failed_plugins": len(self.failed_plugins), - "total_components": len(action_components) + len(command_components), - "action_components": len(action_components), - "command_components": len(command_components), - "loaded_plugin_files": len(self.loaded_plugins), - "failed_plugin_details": self.failed_plugins.copy(), - } - - def reload_plugin(self, plugin_name: str) -> bool: - """重新加载插件(高级功能,需要谨慎使用)""" - # TODO: 实现插件热重载功能 - logger.warning("插件热重载功能尚未实现") - return False - - def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, any]: - """检查所有插件的Python依赖包 - - Args: - auto_install: 是否自动安装缺失的依赖包 - - Returns: - Dict[str, any]: 检查结果摘要 - """ - logger.info("开始检查所有插件的Python依赖包...") - - all_required_missing = [] - all_optional_missing = [] - plugin_status = {} - - for plugin_name, _plugin_instance in self.loaded_plugins.items(): - plugin_info = component_registry.get_plugin_info(plugin_name) - if not plugin_info or not plugin_info.python_dependencies: - plugin_status[plugin_name] = {"status": "no_dependencies", "missing": []} - continue - - logger.info(f"检查插件 {plugin_name} 的依赖...") - - missing_required, missing_optional = dependency_manager.check_dependencies(plugin_info.python_dependencies) - - if missing_required: - all_required_missing.extend(missing_required) - plugin_status[plugin_name] = { - "status": "missing_required", - "missing": [dep.package_name for dep in missing_required], - "optional_missing": [dep.package_name for dep in missing_optional], - } - logger.error(f"插件 {plugin_name} 缺少必需依赖: {[dep.package_name for dep in missing_required]}") - elif missing_optional: - all_optional_missing.extend(missing_optional) - plugin_status[plugin_name] = { - "status": "missing_optional", - "missing": [], - "optional_missing": [dep.package_name for dep in missing_optional], - } - logger.warning(f"插件 {plugin_name} 缺少可选依赖: {[dep.package_name for dep in missing_optional]}") - else: - plugin_status[plugin_name] = {"status": "ok", "missing": []} - logger.info(f"插件 {plugin_name} 依赖检查通过") - - # 汇总结果 - total_missing = len(set(dep.package_name for dep in all_required_missing)) - total_optional_missing = len(set(dep.package_name for dep in all_optional_missing)) - - logger.info(f"依赖检查完成 - 缺少必需包: {total_missing}个, 缺少可选包: {total_optional_missing}个") - - # 如果需要自动安装 - install_success = True - if auto_install and all_required_missing: - # 去重 - unique_required = {} - for dep in all_required_missing: - unique_required[dep.package_name] = dep - - logger.info(f"开始自动安装 {len(unique_required)} 个必需依赖包...") - install_success = dependency_manager.install_dependencies(list(unique_required.values()), auto_install=True) - - return { - "total_plugins_checked": len(plugin_status), - "plugins_with_missing_required": len( - [p for p in plugin_status.values() if p["status"] == "missing_required"] - ), - "plugins_with_missing_optional": len( - [p for p in plugin_status.values() if p["status"] == "missing_optional"] - ), - "total_missing_required": total_missing, - "total_missing_optional": total_optional_missing, - "plugin_status": plugin_status, - "auto_install_attempted": auto_install and bool(all_required_missing), - "auto_install_success": install_success, - "install_summary": dependency_manager.get_install_summary(), - } - - def generate_plugin_requirements(self, output_path: str = "plugin_requirements.txt") -> bool: - """生成所有插件依赖的requirements文件 - - Args: - output_path: 输出文件路径 - - Returns: - bool: 生成是否成功 - """ - logger.info("开始生成插件依赖requirements文件...") - - all_dependencies = [] - - for plugin_name, _plugin_instance in self.loaded_plugins.items(): - plugin_info = component_registry.get_plugin_info(plugin_name) - if plugin_info and plugin_info.python_dependencies: - all_dependencies.append(plugin_info.python_dependencies) - - if not all_dependencies: - logger.info("没有找到任何插件依赖") - return False - - return dependency_manager.generate_requirements_file(all_dependencies, output_path) - - def check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: + def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: """检查插件版本兼容性 Args: @@ -528,8 +272,7 @@ class PluginManager: Tuple[bool, str]: (是否兼容, 错误信息) """ if "host_application" not in manifest_data: - # 没有版本要求,默认兼容 - return True, "" + return True, "" # 没有版本要求,默认兼容 host_app = manifest_data["host_application"] if not isinstance(host_app, dict): @@ -539,31 +282,122 @@ class PluginManager: max_version = host_app.get("max_version", "") if not min_version and not max_version: - return True, "" + return True, "" # 没有版本要求,默认兼容 try: - from src.plugin_system.utils.manifest_utils import VersionComparator - current_version = VersionComparator.get_current_host_version() is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version) - if not is_compatible: return False, f"版本不兼容: {error_msg}" - else: - logger.debug(f"插件 {plugin_name} 版本兼容性检查通过") - return True, "" + logger.debug(f"插件 {plugin_name} 版本兼容性检查通过") + return True, "" except Exception as e: logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}") - return True, "" # 检查失败时默认允许加载 + return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载 + def _show_stats(self, total_registered: int, total_failed_registration: int): + # sourcery skip: low-code-quality + # 获取组件统计信息 + stats = component_registry.get_registry_stats() + action_count = stats.get("action_components", 0) + command_count = stats.get("command_components", 0) + total_components = stats.get("total_components", 0) -# 全局插件管理器实例 -plugin_manager = PluginManager() + # 📋 显示插件加载总览 + if total_registered > 0: + logger.info("🎉 插件系统加载完成!") + logger.info( + f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count})" + ) -# 注释掉以解决插件目录重复加载的情况 -# 默认插件目录 -# plugin_manager.add_plugin_directory("src/plugins/built_in") -# plugin_manager.add_plugin_directory("src/plugins/examples") -# 用户插件目录 -# plugin_manager.add_plugin_directory("plugins") + # 显示详细的插件列表 + logger.info("📋 已加载插件详情:") + for plugin_name in self.loaded_plugins.keys(): + if plugin_info := component_registry.get_plugin_info(plugin_name): + # 插件基本信息 + version_info = f"v{plugin_info.version}" if plugin_info.version else "" + author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown" + license_info = f"[{plugin_info.license}]" if plugin_info.license else "" + info_parts = [part for part in [version_info, author_info, license_info] if part] + extra_info = f" ({', '.join(info_parts)})" if info_parts else "" + + logger.info(f" 📦 {plugin_name}{extra_info}") + + # Manifest信息 + if plugin_info.manifest_data: + if plugin_info.keywords: + logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}") + if plugin_info.categories: + logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}") + if plugin_info.homepage_url: + logger.info(f" 🌐 主页: {plugin_info.homepage_url}") + + # 组件列表 + if plugin_info.components: + action_components = [c for c in plugin_info.components if c.component_type.name == "ACTION"] + command_components = [c for c in plugin_info.components if c.component_type.name == "COMMAND"] + + if action_components: + action_names = [c.name for c in action_components] + logger.info(f" 🎯 Action组件: {', '.join(action_names)}") + + if command_components: + command_names = [c.name for c in command_components] + logger.info(f" ⚡ Command组件: {', '.join(command_names)}") + + # 依赖信息 + if plugin_info.dependencies: + logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}") + + # 配置文件信息 + if plugin_info.config_file: + config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌" + logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}") + + # 显示目录统计 + logger.info("📂 加载目录统计:") + for directory in self.plugin_directories: + if os.path.exists(directory): + plugins_in_dir = [] + for plugin_name in self.loaded_plugins.keys(): + plugin_path = self.plugin_paths.get(plugin_name, "") + if plugin_path.startswith(directory): + plugins_in_dir.append(plugin_name) + + if plugins_in_dir: + logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})") + else: + logger.info(f" 📁 {directory}: 0个插件") + + # 失败信息 + if total_failed_registration > 0: + logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败") + for failed_plugin, error in self.failed_plugins.items(): + logger.info(f" ❌ {failed_plugin}: {error}") + else: + logger.warning("😕 没有成功加载任何插件") + + def _show_plugin_components(self, plugin_name: str) -> None: + if plugin_info := component_registry.get_plugin_info(plugin_name): + component_types = {} + for comp in plugin_info.components: + comp_type = comp.component_type.name + component_types[comp_type] = component_types.get(comp_type, 0) + 1 + + components_str = ", ".join([f"{count}个{ctype}" for ctype, count in component_types.items()]) + + # 显示manifest信息 + manifest_info = "" + if plugin_info.license: + manifest_info += f" [{plugin_info.license}]" + if plugin_info.keywords: + manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词 + if len(plugin_info.keywords) > 3: + manifest_info += "..." + + logger.info( + f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}" + ) + else: + logger.info(f"✅ 插件加载成功: {plugin_name}") diff --git a/src/plugin_system/core/plugin_manager_bak.py b/src/plugin_system/core/plugin_manager_bak.py new file mode 100644 index 000000000..7bb74b6ee --- /dev/null +++ b/src/plugin_system/core/plugin_manager_bak.py @@ -0,0 +1,570 @@ +from typing import Dict, List, Optional, Any, TYPE_CHECKING, Tuple +import os +import importlib +import importlib.util +from pathlib import Path +import traceback +from src.plugin_system.base.component_types import PythonDependency + +if TYPE_CHECKING: + from src.plugin_system.base.base_plugin import BasePlugin + +from src.common.logger import get_logger +from src.plugin_system.core.component_registry import component_registry +from src.plugin_system.core.dependency_manager import dependency_manager +from src.plugin_system.base.component_types import ComponentType, PluginInfo + +logger = get_logger("plugin_manager") + + +class PluginManager: + """插件管理器 + + 负责加载、初始化和管理所有插件及其组件 + """ + + def __init__(self): + self.plugin_directories: List[str] = [] + self.loaded_plugins: Dict[str, "BasePlugin"] = {} + self.failed_plugins: Dict[str, str] = {} + self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射 + + # 确保插件目录存在 + self._ensure_plugin_directories() + logger.info("插件管理器初始化完成") + + def _ensure_plugin_directories(self): + """确保所有插件目录存在,如果不存在则创建""" + default_directories = ["src/plugins/built_in", "plugins"] + + for directory in default_directories: + if not os.path.exists(directory): + os.makedirs(directory, exist_ok=True) + logger.info(f"创建插件目录: {directory}") + if directory not in self.plugin_directories: + self.plugin_directories.append(directory) + logger.debug(f"已添加插件目录: {directory}") + else: + logger.warning(f"插件不可重复加载: {directory}") + + def add_plugin_directory(self, directory: str): + """添加插件目录""" + if os.path.exists(directory): + if directory not in self.plugin_directories: + self.plugin_directories.append(directory) + logger.debug(f"已添加插件目录: {directory}") + else: + logger.warning(f"插件不可重复加载: {directory}") + else: + logger.warning(f"插件目录不存在: {directory}") + + def load_all_plugins(self) -> tuple[int, int]: + """加载所有插件目录中的插件 + + Returns: + tuple[int, int]: (插件数量, 组件数量) + """ + logger.debug("开始加载所有插件...") + + # 第一阶段:加载所有插件模块(注册插件类) + total_loaded_modules = 0 + total_failed_modules = 0 + + for directory in self.plugin_directories: + loaded, failed = self._load_plugin_modules_from_directory(directory) + total_loaded_modules += loaded + total_failed_modules += failed + + logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}") + + # 第二阶段:实例化所有已注册的插件类 + from src.plugin_system.base.base_plugin import get_registered_plugin_classes + + plugin_classes = get_registered_plugin_classes() + total_registered = 0 + total_failed_registration = 0 + + for plugin_name, plugin_class in plugin_classes.items(): + try: + # 使用记录的插件目录路径 + plugin_dir = self.plugin_paths.get(plugin_name) + + # 如果没有记录,则尝试查找(fallback) + if not plugin_dir: + plugin_dir = self._find_plugin_directory(plugin_class) + if plugin_dir: + self.plugin_paths[plugin_name] = plugin_dir # 实例化插件(可能因为缺少manifest而失败) + plugin_instance = plugin_class(plugin_dir=plugin_dir) + + # 检查插件是否启用 + if not plugin_instance.enable_plugin: + logger.info(f"插件 {plugin_name} 已禁用,跳过加载") + continue + + # 检查版本兼容性 + is_compatible, compatibility_error = self.check_plugin_version_compatibility( + plugin_name, plugin_instance.manifest_data + ) + if not is_compatible: + total_failed_registration += 1 + self.failed_plugins[plugin_name] = compatibility_error + logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}") + continue + + if plugin_instance.register_plugin(): + total_registered += 1 + self.loaded_plugins[plugin_name] = plugin_instance + + # 📊 显示插件详细信息 + plugin_info = component_registry.get_plugin_info(plugin_name) + if plugin_info: + component_types = {} + for comp in plugin_info.components: + comp_type = comp.component_type.name + component_types[comp_type] = component_types.get(comp_type, 0) + 1 + + components_str = ", ".join([f"{count}个{ctype}" for ctype, count in component_types.items()]) + + # 显示manifest信息 + manifest_info = "" + if plugin_info.license: + manifest_info += f" [{plugin_info.license}]" + if plugin_info.keywords: + manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词 + if len(plugin_info.keywords) > 3: + manifest_info += "..." + + logger.info( + f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}" + ) + else: + logger.info(f"✅ 插件加载成功: {plugin_name}") + else: + total_failed_registration += 1 + self.failed_plugins[plugin_name] = "插件注册失败" + logger.error(f"❌ 插件注册失败: {plugin_name}") + + except FileNotFoundError as e: + # manifest文件缺失 + total_failed_registration += 1 + error_msg = f"缺少manifest文件: {str(e)}" + self.failed_plugins[plugin_name] = error_msg + logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") + + except ValueError as e: + # manifest文件格式错误或验证失败 + traceback.print_exc() + total_failed_registration += 1 + error_msg = f"manifest验证失败: {str(e)}" + self.failed_plugins[plugin_name] = error_msg + logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") + + except Exception as e: + # 其他错误 + total_failed_registration += 1 + error_msg = f"未知错误: {str(e)}" + self.failed_plugins[plugin_name] = error_msg + logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") + logger.debug("详细错误信息: ", exc_info=True) + + # 获取组件统计信息 + stats = component_registry.get_registry_stats() + action_count = stats.get("action_components", 0) + command_count = stats.get("command_components", 0) + total_components = stats.get("total_components", 0) + + # 📋 显示插件加载总览 + if total_registered > 0: + logger.info("🎉 插件系统加载完成!") + logger.info( + f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count})" + ) + + # 显示详细的插件列表 logger.info("📋 已加载插件详情:") + for plugin_name, _plugin_class in self.loaded_plugins.items(): + plugin_info = component_registry.get_plugin_info(plugin_name) + if plugin_info: + # 插件基本信息 + version_info = f"v{plugin_info.version}" if plugin_info.version else "" + author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown" + license_info = f"[{plugin_info.license}]" if plugin_info.license else "" + info_parts = [part for part in [version_info, author_info, license_info] if part] + extra_info = f" ({', '.join(info_parts)})" if info_parts else "" + + logger.info(f" 📦 {plugin_name}{extra_info}") + + # Manifest信息 + if plugin_info.manifest_data: + if plugin_info.keywords: + logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}") + if plugin_info.categories: + logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}") + if plugin_info.homepage_url: + logger.info(f" 🌐 主页: {plugin_info.homepage_url}") + + # 组件列表 + if plugin_info.components: + action_components = [c for c in plugin_info.components if c.component_type.name == "ACTION"] + command_components = [c for c in plugin_info.components if c.component_type.name == "COMMAND"] + + if action_components: + action_names = [c.name for c in action_components] + logger.info(f" 🎯 Action组件: {', '.join(action_names)}") + + if command_components: + command_names = [c.name for c in command_components] + logger.info(f" ⚡ Command组件: {', '.join(command_names)}") + + # 版本兼容性信息 + if plugin_info.min_host_version or plugin_info.max_host_version: + version_range = "" + if plugin_info.min_host_version: + version_range += f">={plugin_info.min_host_version}" + if plugin_info.max_host_version: + if version_range: + version_range += f", <={plugin_info.max_host_version}" + else: + version_range += f"<={plugin_info.max_host_version}" + logger.info(f" 📋 兼容版本: {version_range}") + + # 依赖信息 + if plugin_info.dependencies: + logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}") + + # 配置文件信息 + if plugin_info.config_file: + config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌" + logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}") + + # 显示目录统计 + logger.info("📂 加载目录统计:") + for directory in self.plugin_directories: + if os.path.exists(directory): + plugins_in_dir = [] + for plugin_name in self.loaded_plugins.keys(): + plugin_path = self.plugin_paths.get(plugin_name, "") + if plugin_path.startswith(directory): + plugins_in_dir.append(plugin_name) + + if plugins_in_dir: + logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})") + else: + logger.info(f" 📁 {directory}: 0个插件") + + # 失败信息 + if total_failed_registration > 0: + logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败") + for failed_plugin, error in self.failed_plugins.items(): + logger.info(f" ❌ {failed_plugin}: {error}") + else: + logger.warning("😕 没有成功加载任何插件") + + # 返回插件数量和组件数量 + return total_registered, total_components + + def _find_plugin_directory(self, plugin_class) -> Optional[str]: + """查找插件类对应的目录路径""" + try: + import inspect + + module = inspect.getmodule(plugin_class) + if module and hasattr(module, "__file__") and module.__file__: + return os.path.dirname(module.__file__) + except Exception as e: + logger.debug(f"通过inspect获取插件目录失败: {e}") + return None + + def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]: + """从指定目录加载插件模块""" + loaded_count = 0 + failed_count = 0 + + if not os.path.exists(directory): + logger.warning(f"插件目录不存在: {directory}") + return loaded_count, failed_count + + logger.debug(f"正在扫描插件目录: {directory}") + + # 遍历目录中的所有Python文件和包 + for item in os.listdir(directory): + item_path = os.path.join(directory, item) + + if os.path.isfile(item_path) and item.endswith(".py") and item != "__init__.py": + # 单文件插件 + plugin_name = Path(item_path).stem + if self._load_plugin_module_file(item_path, plugin_name, directory): + loaded_count += 1 + else: + failed_count += 1 + + elif os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"): + # 插件包 + plugin_file = os.path.join(item_path, "plugin.py") + if os.path.exists(plugin_file): + plugin_name = item # 使用目录名作为插件名 + if self._load_plugin_module_file(plugin_file, plugin_name, item_path): + loaded_count += 1 + else: + failed_count += 1 + + return loaded_count, failed_count + + def _load_plugin_module_file(self, plugin_file: str, plugin_name: str, plugin_dir: str) -> bool: + """加载单个插件模块文件 + + Args: + plugin_file: 插件文件路径 + plugin_name: 插件名称 + plugin_dir: 插件目录路径 + """ + # 生成模块名 + plugin_path = Path(plugin_file) + if plugin_path.parent.name != "plugins": + # 插件包格式:parent_dir.plugin + module_name = f"plugins.{plugin_path.parent.name}.plugin" + else: + # 单文件格式:plugins.filename + module_name = f"plugins.{plugin_path.stem}" + + try: + # 动态导入插件模块 + spec = importlib.util.spec_from_file_location(module_name, plugin_file) + if spec is None or spec.loader is None: + logger.error(f"无法创建模块规范: {plugin_file}") + return False + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # 记录插件名和目录路径的映射 + self.plugin_paths[plugin_name] = plugin_dir + + logger.debug(f"插件模块加载成功: {plugin_file}") + return True + + except Exception as e: + error_msg = f"加载插件模块 {plugin_file} 失败: {e}" + logger.error(error_msg) + self.failed_plugins[plugin_name] = error_msg + return False + + def get_loaded_plugins(self) -> List[PluginInfo]: + """获取所有已加载的插件信息""" + return list(component_registry.get_all_plugins().values()) + + def get_enabled_plugins(self) -> List[PluginInfo]: + """获取所有启用的插件信息""" + return list(component_registry.get_enabled_plugins().values()) + + def enable_plugin(self, plugin_name: str) -> bool: + """启用插件""" + plugin_info = component_registry.get_plugin_info(plugin_name) + if plugin_info: + plugin_info.enabled = True + # 启用插件的所有组件 + for component in plugin_info.components: + component_registry.enable_component(component.name) + logger.debug(f"已启用插件: {plugin_name}") + return True + return False + + def disable_plugin(self, plugin_name: str) -> bool: + """禁用插件""" + plugin_info = component_registry.get_plugin_info(plugin_name) + if plugin_info: + plugin_info.enabled = False + # 禁用插件的所有组件 + for component in plugin_info.components: + component_registry.disable_component(component.name) + logger.debug(f"已禁用插件: {plugin_name}") + return True + return False + + def get_plugin_instance(self, plugin_name: str) -> Optional["BasePlugin"]: + """获取插件实例 + + Args: + plugin_name: 插件名称 + + Returns: + Optional[BasePlugin]: 插件实例或None + """ + return self.loaded_plugins.get(plugin_name) + + def get_plugin_stats(self) -> Dict[str, Any]: + """获取插件统计信息""" + all_plugins = component_registry.get_all_plugins() + enabled_plugins = component_registry.get_enabled_plugins() + + action_components = component_registry.get_components_by_type(ComponentType.ACTION) + command_components = component_registry.get_components_by_type(ComponentType.COMMAND) + + return { + "total_plugins": len(all_plugins), + "enabled_plugins": len(enabled_plugins), + "failed_plugins": len(self.failed_plugins), + "total_components": len(action_components) + len(command_components), + "action_components": len(action_components), + "command_components": len(command_components), + "loaded_plugin_files": len(self.loaded_plugins), + "failed_plugin_details": self.failed_plugins.copy(), + } + + def reload_plugin(self, plugin_name: str) -> bool: + """重新加载插件(高级功能,需要谨慎使用)""" + # TODO: 实现插件热重载功能 + logger.warning("插件热重载功能尚未实现") + return False + + def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, any]: + """检查所有插件的Python依赖包 + + Args: + auto_install: 是否自动安装缺失的依赖包 + + Returns: + Dict[str, any]: 检查结果摘要 + """ + logger.info("开始检查所有插件的Python依赖包...") + + all_required_missing: List[PythonDependency] = [] + all_optional_missing: List[PythonDependency] = [] + plugin_status = {} + + for plugin_name, _plugin_instance in self.loaded_plugins.items(): + plugin_info = component_registry.get_plugin_info(plugin_name) + if not plugin_info or not plugin_info.python_dependencies: + plugin_status[plugin_name] = {"status": "no_dependencies", "missing": []} + continue + + logger.info(f"检查插件 {plugin_name} 的依赖...") + + missing_required, missing_optional = dependency_manager.check_dependencies(plugin_info.python_dependencies) + + if missing_required: + all_required_missing.extend(missing_required) + plugin_status[plugin_name] = { + "status": "missing_required", + "missing": [dep.package_name for dep in missing_required], + "optional_missing": [dep.package_name for dep in missing_optional], + } + logger.error(f"插件 {plugin_name} 缺少必需依赖: {[dep.package_name for dep in missing_required]}") + elif missing_optional: + all_optional_missing.extend(missing_optional) + plugin_status[plugin_name] = { + "status": "missing_optional", + "missing": [], + "optional_missing": [dep.package_name for dep in missing_optional], + } + logger.warning(f"插件 {plugin_name} 缺少可选依赖: {[dep.package_name for dep in missing_optional]}") + else: + plugin_status[plugin_name] = {"status": "ok", "missing": []} + logger.info(f"插件 {plugin_name} 依赖检查通过") + + # 汇总结果 + total_missing = len({dep.package_name for dep in all_required_missing}) + total_optional_missing = len({dep.package_name for dep in all_optional_missing}) + + logger.info(f"依赖检查完成 - 缺少必需包: {total_missing}个, 缺少可选包: {total_optional_missing}个") + + # 如果需要自动安装 + install_success = True + if auto_install and all_required_missing: + # 去重 + unique_required = {} + for dep in all_required_missing: + unique_required[dep.package_name] = dep + + logger.info(f"开始自动安装 {len(unique_required)} 个必需依赖包...") + install_success = dependency_manager.install_dependencies(list(unique_required.values()), auto_install=True) + + return { + "total_plugins_checked": len(plugin_status), + "plugins_with_missing_required": len( + [p for p in plugin_status.values() if p["status"] == "missing_required"] + ), + "plugins_with_missing_optional": len( + [p for p in plugin_status.values() if p["status"] == "missing_optional"] + ), + "total_missing_required": total_missing, + "total_missing_optional": total_optional_missing, + "plugin_status": plugin_status, + "auto_install_attempted": auto_install and bool(all_required_missing), + "auto_install_success": install_success, + "install_summary": dependency_manager.get_install_summary(), + } + + def generate_plugin_requirements(self, output_path: str = "plugin_requirements.txt") -> bool: + """生成所有插件依赖的requirements文件 + + Args: + output_path: 输出文件路径 + + Returns: + bool: 生成是否成功 + """ + logger.info("开始生成插件依赖requirements文件...") + + all_dependencies = [] + + for plugin_name, _plugin_instance in self.loaded_plugins.items(): + plugin_info = component_registry.get_plugin_info(plugin_name) + if plugin_info and plugin_info.python_dependencies: + all_dependencies.append(plugin_info.python_dependencies) + + if not all_dependencies: + logger.info("没有找到任何插件依赖") + return False + + return dependency_manager.generate_requirements_file(all_dependencies, output_path) + + def check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: + """检查插件版本兼容性 + + Args: + plugin_name: 插件名称 + manifest_data: manifest数据 + + Returns: + Tuple[bool, str]: (是否兼容, 错误信息) + """ + if "host_application" not in manifest_data: + # 没有版本要求,默认兼容 + return True, "" + + host_app = manifest_data["host_application"] + if not isinstance(host_app, dict): + return True, "" + + min_version = host_app.get("min_version", "") + max_version = host_app.get("max_version", "") + + if not min_version and not max_version: + return True, "" + + try: + from src.plugin_system.utils.manifest_utils import VersionComparator + + current_version = VersionComparator.get_current_host_version() + is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version) + + if not is_compatible: + return False, f"版本不兼容: {error_msg}" + else: + logger.debug(f"插件 {plugin_name} 版本兼容性检查通过") + return True, "" + + except Exception as e: + logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}") + return True, "" # 检查失败时默认允许加载 + + +# 全局插件管理器实例 +plugin_manager = PluginManager() + +# 注释掉以解决插件目录重复加载的情况 +# 默认插件目录 +# plugin_manager.add_plugin_directory("src/plugins/built_in") +# plugin_manager.add_plugin_directory("src/plugins/examples") +# 用户插件目录 +# plugin_manager.add_plugin_directory("plugins") diff --git a/src/plugin_system/events/__init__.py b/src/plugin_system/events/__init__.py new file mode 100644 index 000000000..6b49951df --- /dev/null +++ b/src/plugin_system/events/__init__.py @@ -0,0 +1,9 @@ +""" +插件的事件系统模块 +""" + +from .events import EventType + +__all__ = [ + "EventType", +] diff --git a/src/plugin_system/events/events.py b/src/plugin_system/events/events.py new file mode 100644 index 000000000..64d3a7dad --- /dev/null +++ b/src/plugin_system/events/events.py @@ -0,0 +1,14 @@ +from enum import Enum + + +class EventType(Enum): + """ + 事件类型枚举类 + """ + + ON_MESSAGE = "on_message" + ON_PLAN = "on_plan" + POST_LLM = "post_llm" + AFTER_LLM = "after_llm" + POST_SEND = "post_send" + AFTER_SEND = "after_send" From 6633d5e273e22f29492f58b0082476c73dff4ea8 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Mon, 7 Jul 2025 12:23:55 +0800 Subject: [PATCH 02/13] =?UTF-8?q?=E8=A1=A5=E5=85=A8plugin=5Fmanager?= =?UTF-8?q?=E7=9A=84=E5=89=A9=E4=BD=99=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_system/core/plugin_manager.py | 162 +++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index fbd5de8cc..12c59dcf4 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -8,7 +8,9 @@ import traceback from src.common.logger import get_logger from src.plugin_system.events.events import EventType from src.plugin_system.core.component_registry import component_registry +from src.plugin_system.core.dependency_manager import dependency_manager from src.plugin_system.base.base_plugin import BasePlugin +from src.plugin_system.base.component_types import ComponentType, PluginInfo, PythonDependency from src.plugin_system.utils.manifest_utils import VersionComparator logger = get_logger("plugin_manager") @@ -175,6 +177,166 @@ class PluginManager: self._load_plugin_modules_from_directory(directory) else: logger.warning(f"插件根目录不存在: {directory}") + + def get_loaded_plugins(self) -> List[PluginInfo]: + """获取所有已加载的插件信息""" + return list(component_registry.get_all_plugins().values()) + + def get_enabled_plugins(self) -> List[PluginInfo]: + """获取所有启用的插件信息""" + return list(component_registry.get_enabled_plugins().values()) + + def enable_plugin(self, plugin_name: str) -> bool: + # -------------------------------- NEED REFACTORING -------------------------------- + """启用插件""" + if plugin_info := component_registry.get_plugin_info(plugin_name): + plugin_info.enabled = True + # 启用插件的所有组件 + for component in plugin_info.components: + component_registry.enable_component(component.name) + logger.debug(f"已启用插件: {plugin_name}") + return True + return False + + def disable_plugin(self, plugin_name: str) -> bool: + # -------------------------------- NEED REFACTORING -------------------------------- + """禁用插件""" + if plugin_info := component_registry.get_plugin_info(plugin_name): + plugin_info.enabled = False + # 禁用插件的所有组件 + for component in plugin_info.components: + component_registry.disable_component(component.name) + logger.debug(f"已禁用插件: {plugin_name}") + return True + return False + + def get_plugin_instance(self, plugin_name: str) -> Optional["BasePlugin"]: + """获取插件实例 + + Args: + plugin_name: 插件名称 + + Returns: + Optional[BasePlugin]: 插件实例或None + """ + return self.loaded_plugins.get(plugin_name) + + def get_plugin_stats(self) -> Dict[str, Any]: + """获取插件统计信息""" + all_plugins = component_registry.get_all_plugins() + enabled_plugins = component_registry.get_enabled_plugins() + + action_components = component_registry.get_components_by_type(ComponentType.ACTION) + command_components = component_registry.get_components_by_type(ComponentType.COMMAND) + + return { + "total_plugins": len(all_plugins), + "enabled_plugins": len(enabled_plugins), + "failed_plugins": len(self.failed_plugins), + "total_components": len(action_components) + len(command_components), + "action_components": len(action_components), + "command_components": len(command_components), + "loaded_plugin_files": len(self.loaded_plugins), + "failed_plugin_details": self.failed_plugins.copy(), + } + + def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, any]: + """检查所有插件的Python依赖包 + + Args: + auto_install: 是否自动安装缺失的依赖包 + + Returns: + Dict[str, any]: 检查结果摘要 + """ + logger.info("开始检查所有插件的Python依赖包...") + + all_required_missing: List[PythonDependency] = [] + all_optional_missing: List[PythonDependency] = [] + plugin_status = {} + + for plugin_name, _plugin_instance in self.loaded_plugins.items(): + plugin_info = component_registry.get_plugin_info(plugin_name) + if not plugin_info or not plugin_info.python_dependencies: + plugin_status[plugin_name] = {"status": "no_dependencies", "missing": []} + continue + + logger.info(f"检查插件 {plugin_name} 的依赖...") + + missing_required, missing_optional = dependency_manager.check_dependencies(plugin_info.python_dependencies) + + if missing_required: + all_required_missing.extend(missing_required) + plugin_status[plugin_name] = { + "status": "missing_required", + "missing": [dep.package_name for dep in missing_required], + "optional_missing": [dep.package_name for dep in missing_optional], + } + logger.error(f"插件 {plugin_name} 缺少必需依赖: {[dep.package_name for dep in missing_required]}") + elif missing_optional: + all_optional_missing.extend(missing_optional) + plugin_status[plugin_name] = { + "status": "missing_optional", + "missing": [], + "optional_missing": [dep.package_name for dep in missing_optional], + } + logger.warning(f"插件 {plugin_name} 缺少可选依赖: {[dep.package_name for dep in missing_optional]}") + else: + plugin_status[plugin_name] = {"status": "ok", "missing": []} + logger.info(f"插件 {plugin_name} 依赖检查通过") + + # 汇总结果 + total_missing = len({dep.package_name for dep in all_required_missing}) + total_optional_missing = len({dep.package_name for dep in all_optional_missing}) + + logger.info(f"依赖检查完成 - 缺少必需包: {total_missing}个, 缺少可选包: {total_optional_missing}个") + + # 如果需要自动安装 + install_success = True + if auto_install and all_required_missing: + unique_required = {dep.package_name: dep for dep in all_required_missing} + logger.info(f"开始自动安装 {len(unique_required)} 个必需依赖包...") + install_success = dependency_manager.install_dependencies(list(unique_required.values()), auto_install=True) + + return { + "total_plugins_checked": len(plugin_status), + "plugins_with_missing_required": len( + [p for p in plugin_status.values() if p["status"] == "missing_required"] + ), + "plugins_with_missing_optional": len( + [p for p in plugin_status.values() if p["status"] == "missing_optional"] + ), + "total_missing_required": total_missing, + "total_missing_optional": total_optional_missing, + "plugin_status": plugin_status, + "auto_install_attempted": auto_install and bool(all_required_missing), + "auto_install_success": install_success, + "install_summary": dependency_manager.get_install_summary(), + } + + def generate_plugin_requirements(self, output_path: str = "plugin_requirements.txt") -> bool: + """生成所有插件依赖的requirements文件 + + Args: + output_path: 输出文件路径 + + Returns: + bool: 生成是否成功 + """ + logger.info("开始生成插件依赖requirements文件...") + + all_dependencies = [] + + for plugin_name, _plugin_instance in self.loaded_plugins.items(): + plugin_info = component_registry.get_plugin_info(plugin_name) + if plugin_info and plugin_info.python_dependencies: + all_dependencies.append(plugin_info.python_dependencies) + + if not all_dependencies: + logger.info("没有找到任何插件依赖") + return False + + return dependency_manager.generate_requirements_file(all_dependencies, output_path) def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]: """从指定目录加载插件模块""" From 36974197a8337e439e6c79529b6ecb12294a58da Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Tue, 8 Jul 2025 00:10:31 +0800 Subject: [PATCH 03/13] =?UTF-8?q?=E6=9A=B4=E9=9C=B2=E5=85=A8=E9=83=A8api?= =?UTF-8?q?=EF=BC=8C=E8=A7=A3=E5=86=B3=E5=BE=AA=E7=8E=AFimport=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plugins/take_picture_plugin/plugin.py | 3 +- src/plugin_system/__init__.py | 22 +- src/plugin_system/apis/__init__.py | 2 + src/plugin_system/apis/plugin_register_api.py | 29 + src/plugin_system/base/__init__.py | 13 +- src/plugin_system/base/base_plugin.py | 119 +--- src/plugin_system/core/__init__.py | 3 +- src/plugin_system/core/component_registry.py | 4 +- src/plugin_system/core/plugin_manager.py | 9 +- src/plugin_system/core/plugin_manager_bak.py | 570 ------------------ src/plugin_system/utils/__init__.py | 9 +- src/plugin_system/utils/manifest_utils.py | 2 +- src/plugins/built_in/tts_plugin/plugin.py | 3 +- src/plugins/built_in/vtb_plugin/plugin.py | 3 +- 14 files changed, 89 insertions(+), 702 deletions(-) create mode 100644 src/plugin_system/apis/plugin_register_api.py delete mode 100644 src/plugin_system/core/plugin_manager_bak.py diff --git a/plugins/take_picture_plugin/plugin.py b/plugins/take_picture_plugin/plugin.py index 5be4bf438..15406ca16 100644 --- a/plugins/take_picture_plugin/plugin.py +++ b/plugins/take_picture_plugin/plugin.py @@ -36,11 +36,12 @@ import urllib.error import base64 import traceback -from src.plugin_system.base.base_plugin import BasePlugin, register_plugin +from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.component_types import ComponentInfo, ActionActivationType, ChatMode from src.plugin_system.base.config_types import ConfigField +from src.plugin_system import register_plugin from src.common.logger import get_logger logger = get_logger("take_picture_plugin") diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index 01b9a6125..213e86cac 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -5,11 +5,11 @@ MaiBot 插件系统 """ # 导出主要的公共接口 -from src.plugin_system.base.base_plugin import BasePlugin, register_plugin -from src.plugin_system.base.base_action import BaseAction -from src.plugin_system.base.base_command import BaseCommand -from src.plugin_system.base.config_types import ConfigField -from src.plugin_system.base.component_types import ( +from .base import ( + BasePlugin, + BaseAction, + BaseCommand, + ConfigField, ComponentType, ActionActivationType, ChatMode, @@ -19,18 +19,22 @@ from src.plugin_system.base.component_types import ( PluginInfo, PythonDependency, ) -from src.plugin_system.core.plugin_manager import plugin_manager -from src.plugin_system.core.component_registry import component_registry -from src.plugin_system.core.dependency_manager import dependency_manager +from .core.plugin_manager import ( + plugin_manager, + component_registry, + dependency_manager, +) # 导入工具模块 -from src.plugin_system.utils import ( +from .utils import ( ManifestValidator, ManifestGenerator, validate_plugin_manifest, generate_plugin_manifest, ) +from .apis.plugin_register_api import register_plugin + __version__ = "1.0.0" diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py index cfcf9b7e7..15ef547ef 100644 --- a/src/plugin_system/apis/__init__.py +++ b/src/plugin_system/apis/__init__.py @@ -16,6 +16,7 @@ from src.plugin_system.apis import ( person_api, send_api, utils_api, + plugin_register_api, ) # 导出所有API模块,使它们可以通过 apis.xxx 方式访问 @@ -30,4 +31,5 @@ __all__ = [ "person_api", "send_api", "utils_api", + "plugin_register_api", ] diff --git a/src/plugin_system/apis/plugin_register_api.py b/src/plugin_system/apis/plugin_register_api.py new file mode 100644 index 000000000..d6e7f1f53 --- /dev/null +++ b/src/plugin_system/apis/plugin_register_api.py @@ -0,0 +1,29 @@ +from src.common.logger import get_logger + +logger = get_logger("plugin_register") + + +def register_plugin(cls): + from src.plugin_system.core.plugin_manager import plugin_manager + from src.plugin_system.base.base_plugin import BasePlugin + + """插件注册装饰器 + + 用法: + @register_plugin + class MyPlugin(BasePlugin): + plugin_name = "my_plugin" + plugin_description = "我的插件" + ... + """ + if not issubclass(cls, BasePlugin): + logger.error(f"类 {cls.__name__} 不是 BasePlugin 的子类") + return cls + + # 只是注册插件类,不立即实例化 + # 插件管理器会负责实例化和注册 + plugin_name = cls.plugin_name or cls.__name__ + plugin_manager.plugin_classes[plugin_name] = cls + logger.debug(f"插件类已注册: {plugin_name}") + + return cls diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py index f22f5082d..bff325948 100644 --- a/src/plugin_system/base/__init__.py +++ b/src/plugin_system/base/__init__.py @@ -4,10 +4,10 @@ 提供插件开发的基础类和类型定义 """ -from src.plugin_system.base.base_plugin import BasePlugin, register_plugin -from src.plugin_system.base.base_action import BaseAction -from src.plugin_system.base.base_command import BaseCommand -from src.plugin_system.base.component_types import ( +from .base_plugin import BasePlugin +from .base_action import BaseAction +from .base_command import BaseCommand +from .component_types import ( ComponentType, ActionActivationType, ChatMode, @@ -15,13 +15,14 @@ from src.plugin_system.base.component_types import ( ActionInfo, CommandInfo, PluginInfo, + PythonDependency, ) +from .config_types import ConfigField __all__ = [ "BasePlugin", "BaseAction", "BaseCommand", - "register_plugin", "ComponentType", "ActionActivationType", "ChatMode", @@ -29,4 +30,6 @@ __all__ = [ "ActionInfo", "CommandInfo", "PluginInfo", + "PythonDependency", + "ConfigField", ] diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index 4044c12e9..5fdf20d2e 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -4,6 +4,9 @@ import os import inspect import toml import json +import shutil +import datetime + from src.common.logger import get_logger from src.plugin_system.base.component_types import ( PluginInfo, @@ -11,13 +14,10 @@ from src.plugin_system.base.component_types import ( PythonDependency, ) from src.plugin_system.base.config_types import ConfigField -from src.plugin_system.core.component_registry import component_registry +from src.plugin_system.utils.manifest_utils import ManifestValidator logger = get_logger("base_plugin") -# 全局插件类注册表 -_plugin_classes: Dict[str, Type["BasePlugin"]] = {} - class BasePlugin(ABC): """插件基类 @@ -29,7 +29,7 @@ class BasePlugin(ABC): """ # 插件基本信息(子类必须定义) - plugin_name: str = "" # 插件内部标识符(如 "doubao_pic_plugin") + plugin_name: str = "" # 插件内部标识符(如 "hello_world_plugin") enable_plugin: bool = False # 是否启用插件 dependencies: List[str] = [] # 依赖的其他插件 python_dependencies: List[PythonDependency] = [] # Python包依赖 @@ -103,7 +103,7 @@ class BasePlugin(ABC): if not self.get_manifest_info("description"): raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少description字段") - def _load_manifest(self): + def _load_manifest(self): # sourcery skip: raise-from-previous-error """加载manifest文件(强制要求)""" if not self.plugin_dir: raise ValueError(f"{self.log_prefix} 没有插件目录路径,无法加载manifest") @@ -124,9 +124,6 @@ class BasePlugin(ABC): # 验证manifest格式 self._validate_manifest() - # 从manifest覆盖插件基本信息(如果插件类中未定义) - self._apply_manifest_overrides() - except json.JSONDecodeError as e: error_msg = f"{self.log_prefix} manifest文件格式错误: {e}" logger.error(error_msg) @@ -136,15 +133,6 @@ class BasePlugin(ABC): logger.error(error_msg) raise IOError(error_msg) # noqa - def _apply_manifest_overrides(self): - """从manifest文件覆盖插件信息(现在只处理内部标识符的fallback)""" - if not self.manifest_data: - return - - # 只有当插件类中没有定义plugin_name时,才从manifest中获取作为fallback - if not self.plugin_name: - self.plugin_name = self.manifest_data.get("name", "").replace(" ", "_").lower() - def _get_author_name(self) -> str: """从manifest获取作者名称""" author_info = self.get_manifest_info("author", {}) @@ -156,10 +144,7 @@ class BasePlugin(ABC): def _validate_manifest(self): """验证manifest文件格式(使用强化的验证器)""" if not self.manifest_data: - return - - # 导入验证器 - from src.plugin_system.utils.manifest_utils import ManifestValidator + raise ValueError(f"{self.log_prefix} manifest数据为空,验证失败") validator = ManifestValidator() is_valid = validator.validate_manifest(self.manifest_data) @@ -176,36 +161,6 @@ class BasePlugin(ABC): error_msg += f": {'; '.join(validator.validation_errors)}" raise ValueError(error_msg) - def _generate_default_manifest(self, manifest_path: str): - """生成默认的manifest文件""" - if not self.plugin_name: - logger.debug(f"{self.log_prefix} 插件名称未定义,无法生成默认manifest") - return - - # 从plugin_name生成友好的显示名称 - display_name = self.plugin_name.replace("_", " ").title() - - default_manifest = { - "manifest_version": 1, - "name": display_name, - "version": "1.0.0", - "description": "插件描述", - "author": {"name": "Unknown", "url": ""}, - "license": "MIT", - "host_application": {"min_version": "1.0.0", "max_version": "4.0.0"}, - "keywords": [], - "categories": [], - "default_locale": "zh-CN", - "locales_path": "_locales", - } - - try: - with open(manifest_path, "w", encoding="utf-8") as f: - json.dump(default_manifest, f, ensure_ascii=False, indent=2) - logger.info(f"{self.log_prefix} 已生成默认manifest文件: {manifest_path}") - except IOError as e: - logger.error(f"{self.log_prefix} 保存默认manifest文件失败: {e}") - def get_manifest_info(self, key: str, default: Any = None) -> Any: """获取manifest信息 @@ -304,9 +259,6 @@ class BasePlugin(ABC): def _backup_config_file(self, config_file_path: str) -> str: """备份配置文件""" - import shutil - import datetime - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") backup_path = f"{config_file_path}.backup_{timestamp}" @@ -377,13 +329,14 @@ class BasePlugin(ABC): logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值") # 检查旧配置中是否有新配置没有的节 - for section_name in old_config.keys(): + for section_name in old_config: if section_name not in migrated_config: logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除") return migrated_config def _generate_config_from_schema(self) -> Dict[str, Any]: + # sourcery skip: dict-comprehension """根据schema生成配置数据结构(不写入文件)""" if not self.config_schema: return {} @@ -473,7 +426,7 @@ class BasePlugin(ABC): except IOError as e: logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True) - def _load_plugin_config(self): + def _load_plugin_config(self): # sourcery skip: extract-method """加载插件配置文件,支持版本检查和自动迁移""" if not self.config_file_name: logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载") @@ -568,7 +521,7 @@ class BasePlugin(ABC): def register_plugin(self) -> bool: """注册插件及其所有组件""" - + from src.plugin_system.core.component_registry import component_registry components = self.get_plugin_components() # 检查依赖 @@ -598,6 +551,7 @@ class BasePlugin(ABC): def _check_dependencies(self) -> bool: """检查插件依赖""" + from src.plugin_system.core.component_registry import component_registry if not self.dependencies: return True @@ -629,52 +583,3 @@ class BasePlugin(ABC): return default return current - - -def register_plugin(cls): - """插件注册装饰器 - - 用法: - @register_plugin - class MyPlugin(BasePlugin): - plugin_name = "my_plugin" - plugin_description = "我的插件" - ... - """ - if not issubclass(cls, BasePlugin): - logger.error(f"类 {cls.__name__} 不是 BasePlugin 的子类") - return cls - - # 只是注册插件类,不立即实例化 - # 插件管理器会负责实例化和注册 - plugin_name = cls.plugin_name or cls.__name__ - _plugin_classes[plugin_name] = cls - logger.debug(f"插件类已注册: {plugin_name}") - - return cls - - -def get_registered_plugin_classes() -> Dict[str, Type["BasePlugin"]]: - """获取所有已注册的插件类""" - return _plugin_classes.copy() - - -def instantiate_and_register_plugin(plugin_class: Type["BasePlugin"], plugin_dir: str = None) -> bool: - """实例化并注册插件 - - Args: - plugin_class: 插件类 - plugin_dir: 插件目录路径 - - Returns: - bool: 是否成功 - """ - try: - plugin_instance = plugin_class(plugin_dir=plugin_dir) - return plugin_instance.register_plugin() - except Exception as e: - logger.error(f"注册插件 {plugin_class.__name__} 时出错: {e}") - import traceback - - logger.error(traceback.format_exc()) - return False diff --git a/src/plugin_system/core/__init__.py b/src/plugin_system/core/__init__.py index d1377b477..6bd3d3935 100644 --- a/src/plugin_system/core/__init__.py +++ b/src/plugin_system/core/__init__.py @@ -6,8 +6,9 @@ from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.component_registry import component_registry - +from src.plugin_system.core.dependency_manager import dependency_manager __all__ = [ "plugin_manager", "component_registry", + "dependency_manager", ] diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 9d2dea721..809319802 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -9,8 +9,8 @@ from src.plugin_system.base.component_types import ( ComponentType, ) -from ..base.base_command import BaseCommand -from ..base.base_action import BaseAction +from src.plugin_system.base.base_command import BaseCommand +from src.plugin_system.base.base_action import BaseAction logger = get_logger("component_registry") diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 12c59dcf4..0de8f6eb6 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -89,6 +89,8 @@ class PluginManager: total_registered += 1 else: total_failed_registration += 1 + + return total_registered, total_failed_registration def load_registered_plugin_classes(self, plugin_name: str) -> bool: # sourcery skip: extract-duplicate-method, extract-method @@ -255,7 +257,7 @@ class PluginManager: all_optional_missing: List[PythonDependency] = [] plugin_status = {} - for plugin_name, _plugin_instance in self.loaded_plugins.items(): + for plugin_name in self.loaded_plugins: plugin_info = component_registry.get_plugin_info(plugin_name) if not plugin_info or not plugin_info.python_dependencies: plugin_status[plugin_name] = {"status": "no_dependencies", "missing": []} @@ -327,7 +329,7 @@ class PluginManager: all_dependencies = [] - for plugin_name, _plugin_instance in self.loaded_plugins.items(): + for plugin_name in self.loaded_plugins: plugin_info = component_registry.get_plugin_info(plugin_name) if plugin_info and plugin_info.python_dependencies: all_dependencies.append(plugin_info.python_dependencies) @@ -563,3 +565,6 @@ class PluginManager: ) else: logger.info(f"✅ 插件加载成功: {plugin_name}") + +# 全局插件管理器实例 +plugin_manager = PluginManager() \ No newline at end of file diff --git a/src/plugin_system/core/plugin_manager_bak.py b/src/plugin_system/core/plugin_manager_bak.py deleted file mode 100644 index 7bb74b6ee..000000000 --- a/src/plugin_system/core/plugin_manager_bak.py +++ /dev/null @@ -1,570 +0,0 @@ -from typing import Dict, List, Optional, Any, TYPE_CHECKING, Tuple -import os -import importlib -import importlib.util -from pathlib import Path -import traceback -from src.plugin_system.base.component_types import PythonDependency - -if TYPE_CHECKING: - from src.plugin_system.base.base_plugin import BasePlugin - -from src.common.logger import get_logger -from src.plugin_system.core.component_registry import component_registry -from src.plugin_system.core.dependency_manager import dependency_manager -from src.plugin_system.base.component_types import ComponentType, PluginInfo - -logger = get_logger("plugin_manager") - - -class PluginManager: - """插件管理器 - - 负责加载、初始化和管理所有插件及其组件 - """ - - def __init__(self): - self.plugin_directories: List[str] = [] - self.loaded_plugins: Dict[str, "BasePlugin"] = {} - self.failed_plugins: Dict[str, str] = {} - self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射 - - # 确保插件目录存在 - self._ensure_plugin_directories() - logger.info("插件管理器初始化完成") - - def _ensure_plugin_directories(self): - """确保所有插件目录存在,如果不存在则创建""" - default_directories = ["src/plugins/built_in", "plugins"] - - for directory in default_directories: - if not os.path.exists(directory): - os.makedirs(directory, exist_ok=True) - logger.info(f"创建插件目录: {directory}") - if directory not in self.plugin_directories: - self.plugin_directories.append(directory) - logger.debug(f"已添加插件目录: {directory}") - else: - logger.warning(f"插件不可重复加载: {directory}") - - def add_plugin_directory(self, directory: str): - """添加插件目录""" - if os.path.exists(directory): - if directory not in self.plugin_directories: - self.plugin_directories.append(directory) - logger.debug(f"已添加插件目录: {directory}") - else: - logger.warning(f"插件不可重复加载: {directory}") - else: - logger.warning(f"插件目录不存在: {directory}") - - def load_all_plugins(self) -> tuple[int, int]: - """加载所有插件目录中的插件 - - Returns: - tuple[int, int]: (插件数量, 组件数量) - """ - logger.debug("开始加载所有插件...") - - # 第一阶段:加载所有插件模块(注册插件类) - total_loaded_modules = 0 - total_failed_modules = 0 - - for directory in self.plugin_directories: - loaded, failed = self._load_plugin_modules_from_directory(directory) - total_loaded_modules += loaded - total_failed_modules += failed - - logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}") - - # 第二阶段:实例化所有已注册的插件类 - from src.plugin_system.base.base_plugin import get_registered_plugin_classes - - plugin_classes = get_registered_plugin_classes() - total_registered = 0 - total_failed_registration = 0 - - for plugin_name, plugin_class in plugin_classes.items(): - try: - # 使用记录的插件目录路径 - plugin_dir = self.plugin_paths.get(plugin_name) - - # 如果没有记录,则尝试查找(fallback) - if not plugin_dir: - plugin_dir = self._find_plugin_directory(plugin_class) - if plugin_dir: - self.plugin_paths[plugin_name] = plugin_dir # 实例化插件(可能因为缺少manifest而失败) - plugin_instance = plugin_class(plugin_dir=plugin_dir) - - # 检查插件是否启用 - if not plugin_instance.enable_plugin: - logger.info(f"插件 {plugin_name} 已禁用,跳过加载") - continue - - # 检查版本兼容性 - is_compatible, compatibility_error = self.check_plugin_version_compatibility( - plugin_name, plugin_instance.manifest_data - ) - if not is_compatible: - total_failed_registration += 1 - self.failed_plugins[plugin_name] = compatibility_error - logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}") - continue - - if plugin_instance.register_plugin(): - total_registered += 1 - self.loaded_plugins[plugin_name] = plugin_instance - - # 📊 显示插件详细信息 - plugin_info = component_registry.get_plugin_info(plugin_name) - if plugin_info: - component_types = {} - for comp in plugin_info.components: - comp_type = comp.component_type.name - component_types[comp_type] = component_types.get(comp_type, 0) + 1 - - components_str = ", ".join([f"{count}个{ctype}" for ctype, count in component_types.items()]) - - # 显示manifest信息 - manifest_info = "" - if plugin_info.license: - manifest_info += f" [{plugin_info.license}]" - if plugin_info.keywords: - manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词 - if len(plugin_info.keywords) > 3: - manifest_info += "..." - - logger.info( - f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}" - ) - else: - logger.info(f"✅ 插件加载成功: {plugin_name}") - else: - total_failed_registration += 1 - self.failed_plugins[plugin_name] = "插件注册失败" - logger.error(f"❌ 插件注册失败: {plugin_name}") - - except FileNotFoundError as e: - # manifest文件缺失 - total_failed_registration += 1 - error_msg = f"缺少manifest文件: {str(e)}" - self.failed_plugins[plugin_name] = error_msg - logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") - - except ValueError as e: - # manifest文件格式错误或验证失败 - traceback.print_exc() - total_failed_registration += 1 - error_msg = f"manifest验证失败: {str(e)}" - self.failed_plugins[plugin_name] = error_msg - logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") - - except Exception as e: - # 其他错误 - total_failed_registration += 1 - error_msg = f"未知错误: {str(e)}" - self.failed_plugins[plugin_name] = error_msg - logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") - logger.debug("详细错误信息: ", exc_info=True) - - # 获取组件统计信息 - stats = component_registry.get_registry_stats() - action_count = stats.get("action_components", 0) - command_count = stats.get("command_components", 0) - total_components = stats.get("total_components", 0) - - # 📋 显示插件加载总览 - if total_registered > 0: - logger.info("🎉 插件系统加载完成!") - logger.info( - f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count})" - ) - - # 显示详细的插件列表 logger.info("📋 已加载插件详情:") - for plugin_name, _plugin_class in self.loaded_plugins.items(): - plugin_info = component_registry.get_plugin_info(plugin_name) - if plugin_info: - # 插件基本信息 - version_info = f"v{plugin_info.version}" if plugin_info.version else "" - author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown" - license_info = f"[{plugin_info.license}]" if plugin_info.license else "" - info_parts = [part for part in [version_info, author_info, license_info] if part] - extra_info = f" ({', '.join(info_parts)})" if info_parts else "" - - logger.info(f" 📦 {plugin_name}{extra_info}") - - # Manifest信息 - if plugin_info.manifest_data: - if plugin_info.keywords: - logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}") - if plugin_info.categories: - logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}") - if plugin_info.homepage_url: - logger.info(f" 🌐 主页: {plugin_info.homepage_url}") - - # 组件列表 - if plugin_info.components: - action_components = [c for c in plugin_info.components if c.component_type.name == "ACTION"] - command_components = [c for c in plugin_info.components if c.component_type.name == "COMMAND"] - - if action_components: - action_names = [c.name for c in action_components] - logger.info(f" 🎯 Action组件: {', '.join(action_names)}") - - if command_components: - command_names = [c.name for c in command_components] - logger.info(f" ⚡ Command组件: {', '.join(command_names)}") - - # 版本兼容性信息 - if plugin_info.min_host_version or plugin_info.max_host_version: - version_range = "" - if plugin_info.min_host_version: - version_range += f">={plugin_info.min_host_version}" - if plugin_info.max_host_version: - if version_range: - version_range += f", <={plugin_info.max_host_version}" - else: - version_range += f"<={plugin_info.max_host_version}" - logger.info(f" 📋 兼容版本: {version_range}") - - # 依赖信息 - if plugin_info.dependencies: - logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}") - - # 配置文件信息 - if plugin_info.config_file: - config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌" - logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}") - - # 显示目录统计 - logger.info("📂 加载目录统计:") - for directory in self.plugin_directories: - if os.path.exists(directory): - plugins_in_dir = [] - for plugin_name in self.loaded_plugins.keys(): - plugin_path = self.plugin_paths.get(plugin_name, "") - if plugin_path.startswith(directory): - plugins_in_dir.append(plugin_name) - - if plugins_in_dir: - logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})") - else: - logger.info(f" 📁 {directory}: 0个插件") - - # 失败信息 - if total_failed_registration > 0: - logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败") - for failed_plugin, error in self.failed_plugins.items(): - logger.info(f" ❌ {failed_plugin}: {error}") - else: - logger.warning("😕 没有成功加载任何插件") - - # 返回插件数量和组件数量 - return total_registered, total_components - - def _find_plugin_directory(self, plugin_class) -> Optional[str]: - """查找插件类对应的目录路径""" - try: - import inspect - - module = inspect.getmodule(plugin_class) - if module and hasattr(module, "__file__") and module.__file__: - return os.path.dirname(module.__file__) - except Exception as e: - logger.debug(f"通过inspect获取插件目录失败: {e}") - return None - - def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]: - """从指定目录加载插件模块""" - loaded_count = 0 - failed_count = 0 - - if not os.path.exists(directory): - logger.warning(f"插件目录不存在: {directory}") - return loaded_count, failed_count - - logger.debug(f"正在扫描插件目录: {directory}") - - # 遍历目录中的所有Python文件和包 - for item in os.listdir(directory): - item_path = os.path.join(directory, item) - - if os.path.isfile(item_path) and item.endswith(".py") and item != "__init__.py": - # 单文件插件 - plugin_name = Path(item_path).stem - if self._load_plugin_module_file(item_path, plugin_name, directory): - loaded_count += 1 - else: - failed_count += 1 - - elif os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"): - # 插件包 - plugin_file = os.path.join(item_path, "plugin.py") - if os.path.exists(plugin_file): - plugin_name = item # 使用目录名作为插件名 - if self._load_plugin_module_file(plugin_file, plugin_name, item_path): - loaded_count += 1 - else: - failed_count += 1 - - return loaded_count, failed_count - - def _load_plugin_module_file(self, plugin_file: str, plugin_name: str, plugin_dir: str) -> bool: - """加载单个插件模块文件 - - Args: - plugin_file: 插件文件路径 - plugin_name: 插件名称 - plugin_dir: 插件目录路径 - """ - # 生成模块名 - plugin_path = Path(plugin_file) - if plugin_path.parent.name != "plugins": - # 插件包格式:parent_dir.plugin - module_name = f"plugins.{plugin_path.parent.name}.plugin" - else: - # 单文件格式:plugins.filename - module_name = f"plugins.{plugin_path.stem}" - - try: - # 动态导入插件模块 - spec = importlib.util.spec_from_file_location(module_name, plugin_file) - if spec is None or spec.loader is None: - logger.error(f"无法创建模块规范: {plugin_file}") - return False - - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - # 记录插件名和目录路径的映射 - self.plugin_paths[plugin_name] = plugin_dir - - logger.debug(f"插件模块加载成功: {plugin_file}") - return True - - except Exception as e: - error_msg = f"加载插件模块 {plugin_file} 失败: {e}" - logger.error(error_msg) - self.failed_plugins[plugin_name] = error_msg - return False - - def get_loaded_plugins(self) -> List[PluginInfo]: - """获取所有已加载的插件信息""" - return list(component_registry.get_all_plugins().values()) - - def get_enabled_plugins(self) -> List[PluginInfo]: - """获取所有启用的插件信息""" - return list(component_registry.get_enabled_plugins().values()) - - def enable_plugin(self, plugin_name: str) -> bool: - """启用插件""" - plugin_info = component_registry.get_plugin_info(plugin_name) - if plugin_info: - plugin_info.enabled = True - # 启用插件的所有组件 - for component in plugin_info.components: - component_registry.enable_component(component.name) - logger.debug(f"已启用插件: {plugin_name}") - return True - return False - - def disable_plugin(self, plugin_name: str) -> bool: - """禁用插件""" - plugin_info = component_registry.get_plugin_info(plugin_name) - if plugin_info: - plugin_info.enabled = False - # 禁用插件的所有组件 - for component in plugin_info.components: - component_registry.disable_component(component.name) - logger.debug(f"已禁用插件: {plugin_name}") - return True - return False - - def get_plugin_instance(self, plugin_name: str) -> Optional["BasePlugin"]: - """获取插件实例 - - Args: - plugin_name: 插件名称 - - Returns: - Optional[BasePlugin]: 插件实例或None - """ - return self.loaded_plugins.get(plugin_name) - - def get_plugin_stats(self) -> Dict[str, Any]: - """获取插件统计信息""" - all_plugins = component_registry.get_all_plugins() - enabled_plugins = component_registry.get_enabled_plugins() - - action_components = component_registry.get_components_by_type(ComponentType.ACTION) - command_components = component_registry.get_components_by_type(ComponentType.COMMAND) - - return { - "total_plugins": len(all_plugins), - "enabled_plugins": len(enabled_plugins), - "failed_plugins": len(self.failed_plugins), - "total_components": len(action_components) + len(command_components), - "action_components": len(action_components), - "command_components": len(command_components), - "loaded_plugin_files": len(self.loaded_plugins), - "failed_plugin_details": self.failed_plugins.copy(), - } - - def reload_plugin(self, plugin_name: str) -> bool: - """重新加载插件(高级功能,需要谨慎使用)""" - # TODO: 实现插件热重载功能 - logger.warning("插件热重载功能尚未实现") - return False - - def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, any]: - """检查所有插件的Python依赖包 - - Args: - auto_install: 是否自动安装缺失的依赖包 - - Returns: - Dict[str, any]: 检查结果摘要 - """ - logger.info("开始检查所有插件的Python依赖包...") - - all_required_missing: List[PythonDependency] = [] - all_optional_missing: List[PythonDependency] = [] - plugin_status = {} - - for plugin_name, _plugin_instance in self.loaded_plugins.items(): - plugin_info = component_registry.get_plugin_info(plugin_name) - if not plugin_info or not plugin_info.python_dependencies: - plugin_status[plugin_name] = {"status": "no_dependencies", "missing": []} - continue - - logger.info(f"检查插件 {plugin_name} 的依赖...") - - missing_required, missing_optional = dependency_manager.check_dependencies(plugin_info.python_dependencies) - - if missing_required: - all_required_missing.extend(missing_required) - plugin_status[plugin_name] = { - "status": "missing_required", - "missing": [dep.package_name for dep in missing_required], - "optional_missing": [dep.package_name for dep in missing_optional], - } - logger.error(f"插件 {plugin_name} 缺少必需依赖: {[dep.package_name for dep in missing_required]}") - elif missing_optional: - all_optional_missing.extend(missing_optional) - plugin_status[plugin_name] = { - "status": "missing_optional", - "missing": [], - "optional_missing": [dep.package_name for dep in missing_optional], - } - logger.warning(f"插件 {plugin_name} 缺少可选依赖: {[dep.package_name for dep in missing_optional]}") - else: - plugin_status[plugin_name] = {"status": "ok", "missing": []} - logger.info(f"插件 {plugin_name} 依赖检查通过") - - # 汇总结果 - total_missing = len({dep.package_name for dep in all_required_missing}) - total_optional_missing = len({dep.package_name for dep in all_optional_missing}) - - logger.info(f"依赖检查完成 - 缺少必需包: {total_missing}个, 缺少可选包: {total_optional_missing}个") - - # 如果需要自动安装 - install_success = True - if auto_install and all_required_missing: - # 去重 - unique_required = {} - for dep in all_required_missing: - unique_required[dep.package_name] = dep - - logger.info(f"开始自动安装 {len(unique_required)} 个必需依赖包...") - install_success = dependency_manager.install_dependencies(list(unique_required.values()), auto_install=True) - - return { - "total_plugins_checked": len(plugin_status), - "plugins_with_missing_required": len( - [p for p in plugin_status.values() if p["status"] == "missing_required"] - ), - "plugins_with_missing_optional": len( - [p for p in plugin_status.values() if p["status"] == "missing_optional"] - ), - "total_missing_required": total_missing, - "total_missing_optional": total_optional_missing, - "plugin_status": plugin_status, - "auto_install_attempted": auto_install and bool(all_required_missing), - "auto_install_success": install_success, - "install_summary": dependency_manager.get_install_summary(), - } - - def generate_plugin_requirements(self, output_path: str = "plugin_requirements.txt") -> bool: - """生成所有插件依赖的requirements文件 - - Args: - output_path: 输出文件路径 - - Returns: - bool: 生成是否成功 - """ - logger.info("开始生成插件依赖requirements文件...") - - all_dependencies = [] - - for plugin_name, _plugin_instance in self.loaded_plugins.items(): - plugin_info = component_registry.get_plugin_info(plugin_name) - if plugin_info and plugin_info.python_dependencies: - all_dependencies.append(plugin_info.python_dependencies) - - if not all_dependencies: - logger.info("没有找到任何插件依赖") - return False - - return dependency_manager.generate_requirements_file(all_dependencies, output_path) - - def check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: - """检查插件版本兼容性 - - Args: - plugin_name: 插件名称 - manifest_data: manifest数据 - - Returns: - Tuple[bool, str]: (是否兼容, 错误信息) - """ - if "host_application" not in manifest_data: - # 没有版本要求,默认兼容 - return True, "" - - host_app = manifest_data["host_application"] - if not isinstance(host_app, dict): - return True, "" - - min_version = host_app.get("min_version", "") - max_version = host_app.get("max_version", "") - - if not min_version and not max_version: - return True, "" - - try: - from src.plugin_system.utils.manifest_utils import VersionComparator - - current_version = VersionComparator.get_current_host_version() - is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version) - - if not is_compatible: - return False, f"版本不兼容: {error_msg}" - else: - logger.debug(f"插件 {plugin_name} 版本兼容性检查通过") - return True, "" - - except Exception as e: - logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}") - return True, "" # 检查失败时默认允许加载 - - -# 全局插件管理器实例 -plugin_manager = PluginManager() - -# 注释掉以解决插件目录重复加载的情况 -# 默认插件目录 -# plugin_manager.add_plugin_directory("src/plugins/built_in") -# plugin_manager.add_plugin_directory("src/plugins/examples") -# 用户插件目录 -# plugin_manager.add_plugin_directory("plugins") diff --git a/src/plugin_system/utils/__init__.py b/src/plugin_system/utils/__init__.py index 10a4fef34..c64a34660 100644 --- a/src/plugin_system/utils/__init__.py +++ b/src/plugin_system/utils/__init__.py @@ -4,11 +4,16 @@ 提供插件开发和管理的实用工具 """ -from src.plugin_system.utils.manifest_utils import ( +from .manifest_utils import ( ManifestValidator, ManifestGenerator, validate_plugin_manifest, generate_plugin_manifest, ) -__all__ = ["ManifestValidator", "ManifestGenerator", "validate_plugin_manifest", "generate_plugin_manifest"] +__all__ = [ + "ManifestValidator", + "ManifestGenerator", + "validate_plugin_manifest", + "generate_plugin_manifest", +] diff --git a/src/plugin_system/utils/manifest_utils.py b/src/plugin_system/utils/manifest_utils.py index 7be7ba900..b6e5a1f30 100644 --- a/src/plugin_system/utils/manifest_utils.py +++ b/src/plugin_system/utils/manifest_utils.py @@ -305,7 +305,7 @@ class ManifestValidator: # 检查URL格式(可选字段) for url_field in ["homepage_url", "repository_url"]: if url_field in manifest_data and manifest_data[url_field]: - url = manifest_data[url_field] + url: str = manifest_data[url_field] if not (url.startswith("http://") or url.startswith("https://")): self.validation_warnings.append(f"{url_field}建议使用完整的URL格式") diff --git a/src/plugins/built_in/tts_plugin/plugin.py b/src/plugins/built_in/tts_plugin/plugin.py index d60186a13..b72106b0b 100644 --- a/src/plugins/built_in/tts_plugin/plugin.py +++ b/src/plugins/built_in/tts_plugin/plugin.py @@ -1,4 +1,5 @@ -from src.plugin_system.base.base_plugin import BasePlugin, register_plugin +from src.plugin_system.apis.plugin_register_api import register_plugin +from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system.base.component_types import ComponentInfo from src.common.logger import get_logger from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode diff --git a/src/plugins/built_in/vtb_plugin/plugin.py b/src/plugins/built_in/vtb_plugin/plugin.py index a87071e63..2932205b5 100644 --- a/src/plugins/built_in/vtb_plugin/plugin.py +++ b/src/plugins/built_in/vtb_plugin/plugin.py @@ -1,4 +1,5 @@ -from src.plugin_system.base.base_plugin import BasePlugin, register_plugin +from src.plugin_system.apis.plugin_register_api import register_plugin +from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system.base.component_types import ComponentInfo from src.common.logger import get_logger from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode From 023e524b3b12cde9b95f150dd08b2214c0b0924b Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Tue, 8 Jul 2025 10:43:28 +0800 Subject: [PATCH 04/13] =?UTF-8?q?=E5=BF=98=E4=BA=86=E5=B1=95=E7=A4=BA?= =?UTF-8?q?=E7=BB=9F=E8=AE=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_system/base/base_plugin.py | 2 ++ src/plugin_system/core/__init__.py | 1 + src/plugin_system/core/plugin_manager.py | 16 ++++++++++------ 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index 5fdf20d2e..a9aae4347 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -522,6 +522,7 @@ class BasePlugin(ABC): def register_plugin(self) -> bool: """注册插件及其所有组件""" from src.plugin_system.core.component_registry import component_registry + components = self.get_plugin_components() # 检查依赖 @@ -552,6 +553,7 @@ class BasePlugin(ABC): def _check_dependencies(self) -> bool: """检查插件依赖""" from src.plugin_system.core.component_registry import component_registry + if not self.dependencies: return True diff --git a/src/plugin_system/core/__init__.py b/src/plugin_system/core/__init__.py index 6bd3d3935..50537b903 100644 --- a/src/plugin_system/core/__init__.py +++ b/src/plugin_system/core/__init__.py @@ -7,6 +7,7 @@ from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.dependency_manager import dependency_manager + __all__ = [ "plugin_manager", "component_registry", diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 0de8f6eb6..a30a3028b 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -89,7 +89,9 @@ class PluginManager: total_registered += 1 else: total_failed_registration += 1 - + + self._show_stats(total_registered, total_failed_registration) + return total_registered, total_failed_registration def load_registered_plugin_classes(self, plugin_name: str) -> bool: @@ -173,13 +175,14 @@ class PluginManager: """ 重新扫描插件根目录 """ + # --------------------------------------- NEED REFACTORING --------------------------------------- for directory in self.plugin_directories: if os.path.exists(directory): logger.debug(f"重新扫描插件根目录: {directory}") self._load_plugin_modules_from_directory(directory) else: logger.warning(f"插件根目录不存在: {directory}") - + def get_loaded_plugins(self) -> List[PluginInfo]: """获取所有已加载的插件信息""" return list(component_registry.get_all_plugins().values()) @@ -187,7 +190,7 @@ class PluginManager: def get_enabled_plugins(self) -> List[PluginInfo]: """获取所有启用的插件信息""" return list(component_registry.get_enabled_plugins().values()) - + def enable_plugin(self, plugin_name: str) -> bool: # -------------------------------- NEED REFACTORING -------------------------------- """启用插件""" @@ -222,7 +225,7 @@ class PluginManager: Optional[BasePlugin]: 插件实例或None """ return self.loaded_plugins.get(plugin_name) - + def get_plugin_stats(self) -> Dict[str, Any]: """获取插件统计信息""" all_plugins = component_registry.get_all_plugins() @@ -241,7 +244,7 @@ class PluginManager: "loaded_plugin_files": len(self.loaded_plugins), "failed_plugin_details": self.failed_plugins.copy(), } - + def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, any]: """检查所有插件的Python依赖包 @@ -566,5 +569,6 @@ class PluginManager: else: logger.info(f"✅ 插件加载成功: {plugin_name}") + # 全局插件管理器实例 -plugin_manager = PluginManager() \ No newline at end of file +plugin_manager = PluginManager() From 855211e861e4d38cdf0eedbb1ccc428dfa61f555 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Tue, 8 Jul 2025 23:23:18 +0800 Subject: [PATCH 05/13] =?UTF-8?q?fix=20ruff,=20=E5=88=A0=E9=99=A4=E4=B8=80?= =?UTF-8?q?=E4=BA=9B=E5=86=97=E4=BD=99=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../heart_flow/heartflow_message_processor.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index 66ddf362e..5499a1f4c 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -3,7 +3,6 @@ from src.config.config import global_config from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.storage import MessageStorage from src.chat.heart_flow.heartflow import heartflow -from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.timer_calculator import Timer from src.common.logger import get_logger @@ -95,26 +94,18 @@ class HeartFCMessageReceiver: """ try: # 1. 消息解析与初始化 - groupinfo = message.message_info.group_info userinfo = message.message_info.user_info - messageinfo = message.message_info - - chat = await get_chat_manager().get_or_create_stream( - platform=messageinfo.platform, - user_info=userinfo, - group_info=groupinfo, - ) + chat = message.chat_stream await self.storage.store_message(message, chat) subheartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) - message.update_chat_stream(chat) - # 6. 兴趣度计算与更新 + # 2. 兴趣度计算与更新 interested_rate, is_mentioned = await _calculate_interest(message) subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) - # 7. 日志记录 + # 3. 日志记录 mes_name = chat.group_info.group_name if chat.group_info else "私聊" # current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time)) current_talk_frequency = global_config.chat.get_current_talk_frequency(chat.stream_id) From 2bbf5e1c59abe50e81e596891b6fb332a1250237 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Tue, 8 Jul 2025 23:43:15 +0800 Subject: [PATCH 06/13] fix ruff again --- src/chat/normal_chat/normal_chat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 51642a700..ec73be5f1 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -304,7 +304,7 @@ class NormalChat: semaphore = asyncio.Semaphore(5) - async def process_and_acquire(msg_id, message, interest_value, is_mentioned): + async def process_and_acquire(msg_id, message, interest_value, is_mentioned, semaphore): """处理单个兴趣消息并管理信号量""" async with semaphore: try: @@ -334,7 +334,7 @@ class NormalChat: self.interest_dict.pop(msg_id, None) tasks = [ - process_and_acquire(msg_id, message, interest_value, is_mentioned) + process_and_acquire(msg_id, message, interest_value, is_mentioned, semaphore) for msg_id, (message, interest_value, is_mentioned) in items_to_process ] From d5cd0e8538bbd432ca8b188665a1f6a2e8a5b9ff Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Wed, 9 Jul 2025 21:54:43 +0800 Subject: [PATCH 07/13] =?UTF-8?q?=E4=BF=AE=E6=94=B9import=E9=A1=BA?= =?UTF-8?q?=E5=BA=8F=EF=BC=8C=E6=8A=8A=E9=AD=94=E6=B3=95=E5=AD=97=E5=8F=98?= =?UTF-8?q?=E4=B8=BA=E6=9E=9A=E4=B8=BE=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/focus_chat/heartFC_chat.py | 18 ++-- src/chat/heart_flow/chat_state_info.py | 3 + src/chat/heart_flow/heartflow.py | 7 +- .../heart_flow/heartflow_message_processor.py | 2 +- src/chat/heart_flow/sub_heartflow.py | 45 +++++----- src/chat/normal_chat/normal_chat.py | 39 ++++---- src/chat/normal_chat/priority_manager.py | 3 +- src/chat/planner_actions/action_manager.py | 31 +++---- src/chat/planner_actions/action_modifier.py | 18 ++-- src/chat/planner_actions/planner.py | 23 ++--- src/plugin_system/base/component_types.py | 6 ++ start_lpmm.bat | 88 ------------------- 12 files changed, 98 insertions(+), 185 deletions(-) delete mode 100644 start_lpmm.bat diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index 08008bfe9..70cda57c6 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -4,20 +4,21 @@ import time import traceback from collections import deque from typing import List, Optional, Dict, Any, Deque, Callable, Awaitable -from src.chat.message_receive.chat_stream import get_chat_manager from rich.traceback import install -from src.chat.utils.prompt_builder import global_prompt_manager + +from src.config.config import global_config from src.common.logger import get_logger +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.timer_calculator import Timer -from src.chat.focus_chat.focus_loop_info import FocusLoopInfo from src.chat.planner_actions.planner import ActionPlanner from src.chat.planner_actions.action_modifier import ActionModifier from src.chat.planner_actions.action_manager import ActionManager -from src.config.config import global_config +from src.chat.focus_chat.focus_loop_info import FocusLoopInfo from src.chat.focus_chat.hfc_performance_logger import HFCPerformanceLogger -from src.person_info.relationship_builder_manager import relationship_builder_manager from src.chat.focus_chat.hfc_utils import CycleDetail - +from src.person_info.relationship_builder_manager import relationship_builder_manager +from src.plugin_system.base.component_types import ChatMode install(extra_lines=3) @@ -134,8 +135,7 @@ class HeartFChatting: def _handle_loop_completion(self, task: asyncio.Task): """当 _hfc_loop 任务完成时执行的回调。""" try: - exception = task.exception() - if exception: + if exception := task.exception(): logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}") logger.error(traceback.format_exc()) # Log full traceback for exceptions else: @@ -342,7 +342,7 @@ class HeartFChatting: # 调用完整的动作修改流程 await self.action_modifier.modify_actions( loop_info=self.loop_info, - mode="focus", + mode=ChatMode.FOCUS, ) except Exception as e: logger.error(f"{self.log_prefix} 动作修改失败: {e}") diff --git a/src/chat/heart_flow/chat_state_info.py b/src/chat/heart_flow/chat_state_info.py index 33936186b..871516d49 100644 --- a/src/chat/heart_flow/chat_state_info.py +++ b/src/chat/heart_flow/chat_state_info.py @@ -5,6 +5,9 @@ class ChatState(enum.Enum): ABSENT = "没在看群" NORMAL = "随便水群" FOCUSED = "认真水群" + + def __str__(self): + return self.name class ChatStateInfo: diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py index ca6e8be7b..fdcfba6a3 100644 --- a/src/chat/heart_flow/heartflow.py +++ b/src/chat/heart_flow/heartflow.py @@ -16,14 +16,11 @@ class Heartflow: async def get_or_create_subheartflow(self, subheartflow_id: Any) -> Optional["SubHeartflow"]: """获取或创建一个新的SubHeartflow实例""" if subheartflow_id in self.subheartflows: - subflow = self.subheartflows.get(subheartflow_id) - if subflow: + if subflow := self.subheartflows.get(subheartflow_id): return subflow try: - new_subflow = SubHeartflow( - subheartflow_id, - ) + new_subflow = SubHeartflow(subheartflow_id) await new_subflow.initialize() diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index fc4337887..d01775168 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -123,7 +123,7 @@ class HeartFCMessageReceiver: logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]") - # 8. 关系处理 + # 4. 关系处理 if global_config.relationship.enable_relationship: await _process_relationship(message) diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py index 9ef357379..9f6a49895 100644 --- a/src/chat/heart_flow/sub_heartflow.py +++ b/src/chat/heart_flow/sub_heartflow.py @@ -72,27 +72,28 @@ class SubHeartflow: 停止 NormalChat 实例 切出 CHAT 状态时使用 """ - if self.normal_chat_instance: - logger.info(f"{self.log_prefix} 离开normal模式") - try: - logger.debug(f"{self.log_prefix} 开始调用 stop_chat()") - # 使用更短的超时时间,强制快速停止 - await asyncio.wait_for(self.normal_chat_instance.stop_chat(), timeout=3.0) - logger.debug(f"{self.log_prefix} stop_chat() 调用完成") - except asyncio.TimeoutError: - logger.warning(f"{self.log_prefix} 停止 NormalChat 超时,强制清理") - # 超时时强制清理实例 + if not self.normal_chat_instance: + return + logger.info(f"{self.log_prefix} 离开normal模式") + try: + logger.debug(f"{self.log_prefix} 开始调用 stop_chat()") + # 使用更短的超时时间,强制快速停止 + await asyncio.wait_for(self.normal_chat_instance.stop_chat(), timeout=3.0) + logger.debug(f"{self.log_prefix} stop_chat() 调用完成") + except asyncio.TimeoutError: + logger.warning(f"{self.log_prefix} 停止 NormalChat 超时,强制清理") + # 超时时强制清理实例 + self.normal_chat_instance = None + except Exception as e: + logger.error(f"{self.log_prefix} 停止 NormalChat 监控任务时出错: {e}") + # 出错时也要清理实例,避免状态不一致 + self.normal_chat_instance = None + finally: + # 确保实例被清理 + if self.normal_chat_instance: + logger.warning(f"{self.log_prefix} 强制清理 NormalChat 实例") self.normal_chat_instance = None - except Exception as e: - logger.error(f"{self.log_prefix} 停止 NormalChat 监控任务时出错: {e}") - # 出错时也要清理实例,避免状态不一致 - self.normal_chat_instance = None - finally: - # 确保实例被清理 - if self.normal_chat_instance: - logger.warning(f"{self.log_prefix} 强制清理 NormalChat 实例") - self.normal_chat_instance = None - logger.debug(f"{self.log_prefix} _stop_normal_chat 完成") + logger.debug(f"{self.log_prefix} _stop_normal_chat 完成") async def _start_normal_chat(self, rewind=False) -> bool: """ @@ -348,6 +349,4 @@ class SubHeartflow: if elapsed_since_exit >= cooldown_duration: return 1.0 # 冷却完成 - # 计算进度:0表示刚开始冷却,1表示冷却完成 - progress = elapsed_since_exit / cooldown_duration - return progress + return elapsed_since_exit / cooldown_duration diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 4704cb238..b5e9890eb 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -1,28 +1,30 @@ import asyncio import time +import traceback from random import random from typing import List, Optional +from maim_message import UserInfo, Seg + from src.config.config import global_config from src.common.logger import get_logger -from src.person_info.person_info import get_person_info_manager -from src.plugin_system.apis import generator_api -from maim_message import UserInfo, Seg -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.chat.utils.timer_calculator import Timer from src.common.message_repository import count_messages -from src.chat.utils.prompt_builder import global_prompt_manager -from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet +from src.plugin_system.apis import generator_api +from src.plugin_system.base.component_types import ChatMode +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet from src.chat.message_receive.normal_message_sender import message_manager from src.chat.normal_chat.willing.willing_manager import get_willing_manager from src.chat.planner_actions.action_manager import ActionManager -from src.person_info.relationship_builder_manager import relationship_builder_manager -from .priority_manager import PriorityManager -import traceback from src.chat.planner_actions.planner import ActionPlanner from src.chat.planner_actions.action_modifier import ActionModifier - from src.chat.utils.utils import get_chat_type_and_target_info +from src.chat.utils.prompt_builder import global_prompt_manager +from src.chat.utils.timer_calculator import Timer from src.mood.mood_manager import mood_manager +from src.person_info.person_info import get_person_info_manager +from src.person_info.relationship_builder_manager import relationship_builder_manager +from .priority_manager import PriorityManager + willing_manager = get_willing_manager() @@ -70,7 +72,7 @@ class NormalChat: # Planner相关初始化 self.action_manager = ActionManager() - self.planner = ActionPlanner(self.stream_id, self.action_manager, mode="normal") + self.planner = ActionPlanner(self.stream_id, self.action_manager, mode=ChatMode.NORMAL) self.action_modifier = ActionModifier(self.action_manager, self.stream_id) self.enable_planner = global_config.normal_chat.enable_planner # 从配置中读取是否启用planner @@ -126,13 +128,8 @@ class NormalChat: continue # 条目已被其他任务处理 message, interest_value, _ = value - if not self._disabled: - # 更新消息段信息 - # self._update_user_message_segments(message) - - # 添加消息到优先级管理器 - if self.priority_manager: - self.priority_manager.add_message(message, interest_value) + if not self._disabled and self.priority_manager: + self.priority_manager.add_message(message, interest_value) except Exception: logger.error( @@ -564,8 +561,8 @@ class NormalChat: available_actions = None if self.enable_planner: try: - await self.action_modifier.modify_actions(mode="normal", message_content=message.processed_plain_text) - available_actions = self.action_manager.get_using_actions_for_mode("normal") + await self.action_modifier.modify_actions(mode=ChatMode.NORMAL, message_content=message.processed_plain_text) + available_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL) except Exception as e: logger.warning(f"[{self.stream_name}] 获取available_actions失败: {e}") available_actions = None diff --git a/src/chat/normal_chat/priority_manager.py b/src/chat/normal_chat/priority_manager.py index 9e1ef76c2..0296017ff 100644 --- a/src/chat/normal_chat/priority_manager.py +++ b/src/chat/normal_chat/priority_manager.py @@ -2,7 +2,8 @@ import time import heapq import math from typing import List, Dict, Optional -from ..message_receive.message import MessageRecv + +from src.chat.message_receive.message import MessageRecv from src.common.logger import get_logger logger = get_logger("normal_chat") diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 3918831ca..3937d1d14 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -3,14 +3,10 @@ from src.plugin_system.base.base_action import BaseAction from src.chat.message_receive.chat_stream import ChatStream from src.common.logger import get_logger from src.plugin_system.core.component_registry import component_registry -from src.plugin_system.base.component_types import ComponentType +from src.plugin_system.base.component_types import ComponentType, ActionActivationType, ChatMode, ActionInfo logger = get_logger("action_manager") -# 定义动作信息类型 -ActionInfo = Dict[str, Any] - - class ActionManager: """ 动作管理器,用于管理各种类型的动作 @@ -20,8 +16,8 @@ class ActionManager: # 类常量 DEFAULT_RANDOM_PROBABILITY = 0.3 - DEFAULT_MODE = "all" - DEFAULT_ACTIVATION_TYPE = "always" + DEFAULT_MODE = ChatMode.ALL + DEFAULT_ACTIVATION_TYPE = ActionActivationType.ALWAYS def __init__(self): """初始化动作管理器""" @@ -54,11 +50,8 @@ class ActionManager: def _load_plugin_system_actions(self) -> None: """从插件系统的component_registry加载Action组件""" try: - from src.plugin_system.core.component_registry import component_registry - from src.plugin_system.base.component_types import ComponentType - # 获取所有Action组件 - action_components = component_registry.get_components_by_type(ComponentType.ACTION) + action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) for action_name, action_info in action_components.items(): if action_name in self._registered_actions: @@ -181,28 +174,28 @@ class ActionManager: """获取当前正在使用的动作集合""" return self._using_actions.copy() - def get_using_actions_for_mode(self, mode: str) -> Dict[str, ActionInfo]: + def get_using_actions_for_mode(self, mode: ChatMode) -> Dict[str, ActionInfo]: """ 根据聊天模式获取可用的动作集合 Args: - mode: 聊天模式 ("focus", "normal", "all") + mode: 聊天模式 (ChatMode.FOCUS, ChatMode.NORMAL, ChatMode.ALL) Returns: Dict[str, ActionInfo]: 在指定模式下可用的动作集合 """ - filtered_actions = {} + enabled_actions = {} for action_name, action_info in self._using_actions.items(): - action_mode = action_info.get("mode_enable", "all") + action_mode = action_info.mode_enable # 检查动作是否在当前模式下启用 - if action_mode == "all" or action_mode == mode: - filtered_actions[action_name] = action_info + if action_mode in [ChatMode.ALL, mode]: + enabled_actions[action_name] = action_info logger.debug(f"动作 {action_name} 在模式 {mode} 下可用 (mode_enable: {action_mode})") - logger.debug(f"模式 {mode} 下可用动作: {list(filtered_actions.keys())}") - return filtered_actions + logger.debug(f"模式 {mode} 下可用动作: {list(enabled_actions.keys())}") + return enabled_actions def add_action_to_using(self, action_name: str) -> bool: """ diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index a2e0066cf..4b15cbdb0 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -1,15 +1,17 @@ -from typing import List, Optional, Any, Dict -from src.common.logger import get_logger -from src.chat.focus_chat.focus_loop_info import FocusLoopInfo -from src.chat.message_receive.chat_stream import get_chat_manager -from src.config.config import global_config -from src.llm_models.utils_model import LLMRequest import random import asyncio import hashlib import time +from typing import List, Optional, Any, Dict + +from src.common.logger import get_logger +from src.config.config import global_config +from src.llm_models.utils_model import LLMRequest +from src.chat.focus_chat.focus_loop_info import FocusLoopInfo +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.planner_actions.action_manager import ActionManager from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages +from src.plugin_system.base.component_types import ChatMode logger = get_logger("action_manager") @@ -44,7 +46,7 @@ class ActionModifier: async def modify_actions( self, loop_info=None, - mode: str = "focus", + mode: ChatMode = ChatMode.FOCUS, message_content: str = "", ): """ @@ -528,7 +530,7 @@ class ActionModifier: def get_available_actions_count(self) -> int: """获取当前可用动作数量(排除默认的no_action)""" - current_actions = self.action_manager.get_using_actions_for_mode("normal") + current_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL) # 排除no_action(如果存在) filtered_actions = {k: v for k, v in current_actions.items() if k != "no_action"} return len(filtered_actions) diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index edd5d010d..db7001b14 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -1,18 +1,21 @@ -import json # <--- 确保导入 json +import json +import time import traceback from typing import Dict, Any, Optional from rich.traceback import install +from datetime import datetime +from json_repair import repair_json + from src.llm_models.utils_model import LLMRequest from src.config.config import global_config from src.common.logger import get_logger from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.chat.planner_actions.action_manager import ActionManager -from json_repair import repair_json -from src.chat.utils.utils import get_chat_type_and_target_info -from datetime import datetime -from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat -import time +from src.chat.utils.utils import get_chat_type_and_target_info +from src.chat.planner_actions.action_manager import ActionManager +from src.chat.message_receive.chat_stream import get_chat_manager +from src.plugin_system.base.component_types import ChatMode + logger = get_logger("planner") @@ -54,7 +57,7 @@ def init_prompt(): class ActionPlanner: - def __init__(self, chat_id: str, action_manager: ActionManager, mode: str = "focus"): + def __init__(self, chat_id: str, action_manager: ActionManager, mode: ChatMode = ChatMode.FOCUS): self.chat_id = chat_id self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]" self.mode = mode @@ -62,7 +65,7 @@ class ActionPlanner: # LLM规划器配置 self.planner_llm = LLMRequest( model=global_config.model.planner, - request_type=f"{self.mode}.planner", # 用于动作规划 + request_type=f"{self.mode.value}.planner", # 用于动作规划 ) self.last_obs_time_mark = 0.0 @@ -224,7 +227,7 @@ class ActionPlanner: self.last_obs_time_mark = time.time() - if self.mode == "focus": + if self.mode == ChatMode.FOCUS: by_what = "聊天内容" no_action_block = "" else: diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index b69aaac2a..f720823c6 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -23,6 +23,9 @@ class ActionActivationType(Enum): RANDOM = "random" # 随机启用action到planner KEYWORD = "keyword" # 关键词触发启用action到planner + def __str__(self): + return self.value + # 聊天模式枚举 class ChatMode(Enum): @@ -32,6 +35,9 @@ class ChatMode(Enum): NORMAL = "normal" # Normal聊天模式 ALL = "all" # 所有聊天模式 + def __str__(self): + return self.value + @dataclass class PythonDependency: diff --git a/start_lpmm.bat b/start_lpmm.bat deleted file mode 100644 index eacaa2eb1..000000000 --- a/start_lpmm.bat +++ /dev/null @@ -1,88 +0,0 @@ -@echo off -CHCP 65001 > nul -setlocal enabledelayedexpansion - -echo 你需要选择启动方式,输入字母来选择: -echo V = 不知道什么意思就输入 V -echo C = 输入 C 使用 Conda 环境 -echo. -choice /C CV /N /M "不知道什么意思就输入 V (C/V)?" /T 10 /D V - -set "ENV_TYPE=" -if %ERRORLEVEL% == 1 set "ENV_TYPE=CONDA" -if %ERRORLEVEL% == 2 set "ENV_TYPE=VENV" - -if "%ENV_TYPE%" == "CONDA" goto activate_conda -if "%ENV_TYPE%" == "VENV" goto activate_venv - -REM 如果 choice 超时或返回意外值,默认使用 venv -echo WARN: Invalid selection or timeout from choice. Defaulting to VENV. -set "ENV_TYPE=VENV" -goto activate_venv - -:activate_conda - set /p CONDA_ENV_NAME="请输入要使用的 Conda 环境名称: " - if not defined CONDA_ENV_NAME ( - echo 错误: 未输入 Conda 环境名称. - pause - exit /b 1 - ) - echo 选择: Conda '!CONDA_ENV_NAME!' - REM 激活Conda环境 - call conda activate !CONDA_ENV_NAME! - if !ERRORLEVEL! neq 0 ( - echo 错误: Conda环境 '!CONDA_ENV_NAME!' 激活失败. 请确保Conda已安装并正确配置, 且 '!CONDA_ENV_NAME!' 环境存在. - pause - exit /b 1 - ) - goto env_activated - -:activate_venv - echo Selected: venv (default or selected) - REM 查找venv虚拟环境 - set "venv_path=%~dp0venv\Scripts\activate.bat" - if not exist "%venv_path%" ( - echo Error: venv not found. Ensure the venv directory exists alongside the script. - pause - exit /b 1 - ) - REM 激活虚拟环境 - call "%venv_path%" - if %ERRORLEVEL% neq 0 ( - echo Error: Failed to activate venv virtual environment. - pause - exit /b 1 - ) - goto env_activated - -:env_activated -echo Environment activated successfully! - -REM --- 后续脚本执行 --- - -REM 运行预处理脚本 -python "%~dp0scripts\raw_data_preprocessor.py" -if %ERRORLEVEL% neq 0 ( - echo Error: raw_data_preprocessor.py execution failed. - pause - exit /b 1 -) - -REM 运行信息提取脚本 -python "%~dp0scripts\info_extraction.py" -if %ERRORLEVEL% neq 0 ( - echo Error: info_extraction.py execution failed. - pause - exit /b 1 -) - -REM 运行OpenIE导入脚本 -python "%~dp0scripts\import_openie.py" -if %ERRORLEVEL% neq 0 ( - echo Error: import_openie.py execution failed. - pause - exit /b 1 -) - -echo All processing steps completed! -pause \ No newline at end of file From ab61b1bb22ca2827bbca8f2a852ec07f590272cd Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Thu, 10 Jul 2025 16:46:37 +0800 Subject: [PATCH 08/13] =?UTF-8?q?=E6=8F=92=E4=BB=B6=E7=B3=BB=E7=BB=9Finfo?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=EF=BC=8C=E8=A7=81changes.md?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- changes.md | 18 ++++++++ plugins/hello_world_plugin/plugin.py | 2 + plugins/take_picture_plugin/plugin.py | 2 + src/chat/heart_flow/chat_state_info.py | 2 +- src/chat/message_receive/chat_stream.py | 1 + src/chat/normal_chat/normal_chat.py | 4 +- src/chat/planner_actions/action_manager.py | 7 +-- src/chat/planner_actions/action_modifier.py | 50 +++++---------------- src/plugin_system/base/base_action.py | 3 +- src/plugin_system/base/base_plugin.py | 40 +++++++++++++---- src/plugin_system/base/component_types.py | 1 + src/plugin_system/core/plugin_manager.py | 31 +++++++------ src/plugins/built_in/core_actions/plugin.py | 2 + src/plugins/built_in/tts_plugin/plugin.py | 2 + src/plugins/built_in/vtb_plugin/plugin.py | 2 + 15 files changed, 99 insertions(+), 68 deletions(-) create mode 100644 changes.md diff --git a/changes.md b/changes.md new file mode 100644 index 000000000..85760965f --- /dev/null +++ b/changes.md @@ -0,0 +1,18 @@ +# 插件API与规范修改 + +1. 现在`plugin_system`的`__init__.py`文件中包含了所有插件API的导入,用户可以直接使用`from plugin_system import *`来导入所有API。 + +2. register_plugin函数现在转移到了`plugin_system.apis.plugin_register_api`模块中,用户可以通过`from plugin_system.apis.plugin_register_api import register_plugin`来导入。 + +3. 现在强制要求的property如下: + - `plugin_name`: 插件名称,必须是唯一的。(与文件夹相同) + - `enable_plugin`: 是否启用插件,默认为`True`。 + - `dependencies`: 插件依赖的其他插件列表,默认为空。**现在并不检查(也许)** + - `python_dependencies`: 插件依赖的Python包列表,默认为空。**现在并不检查** + - `config_file_name`: 插件配置文件名,默认为`config.toml`。 + - `config_schema`: 插件配置文件的schema,用于自动生成配置文件。 + +# 插件系统修改 +1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)** +2. 修复了一下显示插件信息不显示的问题。同时精简了一下显示内容 +3. 修复了插件系统混用了`plugin_name`和`display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。**(可能有遗漏)** \ No newline at end of file diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index eaca35489..dc9b8571c 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -103,6 +103,8 @@ class HelloWorldPlugin(BasePlugin): # 插件基本信息 plugin_name = "hello_world_plugin" # 内部标识符 enable_plugin = True + dependencies = [] # 插件依赖列表 + python_dependencies = [] # Python包依赖列表 config_file_name = "config.toml" # 配置文件名 # 配置节描述 diff --git a/plugins/take_picture_plugin/plugin.py b/plugins/take_picture_plugin/plugin.py index 15406ca16..bbe189526 100644 --- a/plugins/take_picture_plugin/plugin.py +++ b/plugins/take_picture_plugin/plugin.py @@ -443,6 +443,8 @@ class TakePicturePlugin(BasePlugin): plugin_name = "take_picture_plugin" # 内部标识符 enable_plugin = True + dependencies = [] # 插件依赖列表 + python_dependencies = [] # Python包依赖列表 config_file_name = "config.toml" # 配置节描述 diff --git a/src/chat/heart_flow/chat_state_info.py b/src/chat/heart_flow/chat_state_info.py index 871516d49..9f137a953 100644 --- a/src/chat/heart_flow/chat_state_info.py +++ b/src/chat/heart_flow/chat_state_info.py @@ -5,7 +5,7 @@ class ChatState(enum.Enum): ABSENT = "没在看群" NORMAL = "随便水群" FOCUSED = "认真水群" - + def __str__(self): return self.name diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index a82acc413..355cca1e6 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -39,6 +39,7 @@ class ChatMessageContext: return self.message def check_types(self, types: list) -> bool: + # sourcery skip: invert-any-all, use-any, use-next """检查消息类型""" if not self.message.message_info.format_info.accept_format: return False diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index b5e9890eb..4d28c5d88 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -561,7 +561,9 @@ class NormalChat: available_actions = None if self.enable_planner: try: - await self.action_modifier.modify_actions(mode=ChatMode.NORMAL, message_content=message.processed_plain_text) + await self.action_modifier.modify_actions( + mode=ChatMode.NORMAL, message_content=message.processed_plain_text + ) available_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL) except Exception as e: logger.warning(f"[{self.stream_name}] 获取available_actions失败: {e}") diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 3937d1d14..e4dabd22f 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Type, Any +from typing import Dict, List, Optional, Type from src.plugin_system.base.base_action import BaseAction from src.chat.message_receive.chat_stream import ChatStream from src.common.logger import get_logger @@ -7,6 +7,7 @@ from src.plugin_system.base.component_types import ComponentType, ActionActivati logger = get_logger("action_manager") + class ActionManager: """ 动作管理器,用于管理各种类型的动作 @@ -73,7 +74,7 @@ class ActionManager: "activation_keywords": action_info.activation_keywords, "keyword_case_sensitive": action_info.keyword_case_sensitive, # 模式和并行设置 - "mode_enable": action_info.mode_enable.value, + "mode_enable": action_info.mode_enable, "parallel_action": action_info.parallel_action, # 插件信息 "_plugin_name": getattr(action_info, "plugin_name", ""), @@ -187,7 +188,7 @@ class ActionManager: enabled_actions = {} for action_name, action_info in self._using_actions.items(): - action_mode = action_info.mode_enable + action_mode = action_info["mode_enable"] # 检查动作是否在当前模式下启用 if action_mode in [ChatMode.ALL, mode]: diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 4b15cbdb0..6b0e6a633 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -2,16 +2,16 @@ import random import asyncio import hashlib import time -from typing import List, Optional, Any, Dict +from typing import List, Any, Dict from src.common.logger import get_logger from src.config.config import global_config from src.llm_models.utils_model import LLMRequest from src.chat.focus_chat.focus_loop_info import FocusLoopInfo -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext from src.chat.planner_actions.action_manager import ActionManager from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages -from src.plugin_system.base.component_types import ChatMode +from src.plugin_system.base.component_types import ChatMode, ActionInfo logger = get_logger("action_manager") @@ -48,7 +48,7 @@ class ActionModifier: loop_info=None, mode: ChatMode = ChatMode.FOCUS, message_content: str = "", - ): + ): # sourcery skip: use-named-expression """ 动作修改流程,整合传统观察处理和新的激活类型判定 @@ -129,15 +129,14 @@ class ActionModifier: f"{self.log_prefix}{mode}模式动作修改流程结束,最终可用动作: {list(self.action_manager.get_using_actions_for_mode(mode).keys())}||移除记录: {removals_summary}" ) - def _check_action_associated_types(self, all_actions, chat_context): + def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext): type_mismatched_actions = [] for action_name, data in all_actions.items(): - if data.get("associated_types"): - if not chat_context.check_types(data["associated_types"]): - associated_types_str = ", ".join(data["associated_types"]) - reason = f"适配器不支持(需要: {associated_types_str})" - type_mismatched_actions.append((action_name, reason)) - logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}") + if data["associated_types"] and not chat_context.check_types(data["associated_types"]): + associated_types_str = ", ".join(data["associated_types"]) + reason = f"适配器不支持(需要: {associated_types_str})" + type_mismatched_actions.append((action_name, reason)) + logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}") return type_mismatched_actions async def _get_deactivated_actions_by_type( @@ -205,35 +204,6 @@ class ActionModifier: return deactivated_actions - async def process_actions_for_planner( - self, observed_messages_str: str = "", chat_context: Optional[str] = None, extra_context: Optional[str] = None - ) -> Dict[str, Any]: - """ - [已废弃] 此方法现在已被整合到 modify_actions() 中 - - 为了保持向后兼容性而保留,但建议直接使用 ActionManager.get_using_actions() - 规划器应该直接从 ActionManager 获取最终的可用动作集,而不是调用此方法 - - 新的架构: - 1. 主循环调用 modify_actions() 处理完整的动作管理流程 - 2. 规划器直接使用 ActionManager.get_using_actions() 获取最终动作集 - """ - logger.warning( - f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()" - ) - - # 为了向后兼容,仍然返回当前使用的动作集 - current_using_actions = self.action_manager.get_using_actions() - all_registered_actions = self.action_manager.get_registered_actions() - - # 构建完整的动作信息 - result = {} - for action_name in current_using_actions.keys(): - if action_name in all_registered_actions: - result[action_name] = all_registered_actions[action_name] - - return result - def _generate_context_hash(self, chat_content: str) -> str: """生成上下文的哈希值用于缓存""" context_content = f"{chat_content}" diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index cc5cbc261..42e36b64d 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Tuple, Optional from src.common.logger import get_logger +from src.chat.message_receive.chat_stream import ChatStream from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType from src.plugin_system.apis import send_api, database_api, message_api import time @@ -31,7 +32,7 @@ class BaseAction(ABC): reasoning: str, cycle_timers: dict, thinking_id: str, - chat_stream=None, + chat_stream: ChatStream = None, log_prefix: str = "", shutting_down: bool = False, plugin_config: dict = None, diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index a9aae4347..b8112a490 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Type, Optional, Any, Union +from typing import Dict, List, Type, Any, Union import os import inspect import toml @@ -29,18 +29,41 @@ class BasePlugin(ABC): """ # 插件基本信息(子类必须定义) - plugin_name: str = "" # 插件内部标识符(如 "hello_world_plugin") - enable_plugin: bool = False # 是否启用插件 - dependencies: List[str] = [] # 依赖的其他插件 - python_dependencies: List[PythonDependency] = [] # Python包依赖 - config_file_name: Optional[str] = None # 配置文件名 + @property + @abstractmethod + def plugin_name(self) -> str: + return "" # 插件内部标识符(如 "hello_world_plugin") + + @property + @abstractmethod + def enable_plugin(self) -> bool: + return True # 是否启用插件 + + @property + @abstractmethod + def dependencies(self) -> List[str]: + return [] # 依赖的其他插件 + + @property + @abstractmethod + def python_dependencies(self) -> List[PythonDependency]: + return [] # Python包依赖 + + @property + @abstractmethod + def config_file_name(self) -> str: + return "" # 配置文件名 # manifest文件相关 manifest_file_name: str = "_manifest.json" # manifest文件名 manifest_data: Dict[str, Any] = {} # manifest数据 # 配置定义 - config_schema: Dict[str, Union[Dict[str, ConfigField], str]] = {} + @property + @abstractmethod + def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]: + return {} + config_section_descriptions: Dict[str, str] = {} def __init__(self, plugin_dir: str = None): @@ -70,7 +93,8 @@ class BasePlugin(ABC): # 创建插件信息对象 self.plugin_info = PluginInfo( - name=self.display_name, # 使用显示名称 + name=self.plugin_name, + display_name=self.display_name, description=self.plugin_description, version=self.plugin_version, author=self.plugin_author, diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index f720823c6..771fba422 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -126,6 +126,7 @@ class CommandInfo(ComponentInfo): class PluginInfo: """插件信息""" + display_name: str # 插件显示名称 name: str # 插件名称 description: str # 插件描述 version: str = "1.0.0" # 插件版本 diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index a30a3028b..9d6bd805c 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -85,16 +85,17 @@ class PluginManager: total_failed_registration = 0 for plugin_name in self.plugin_classes.keys(): - if self.load_registered_plugin_classes(plugin_name): + load_status, count = self.load_registered_plugin_classes(plugin_name) + if load_status: total_registered += 1 else: - total_failed_registration += 1 + total_failed_registration += count self._show_stats(total_registered, total_failed_registration) return total_registered, total_failed_registration - def load_registered_plugin_classes(self, plugin_name: str) -> bool: + def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]: # sourcery skip: extract-duplicate-method, extract-method """ 加载已经注册的插件类 @@ -102,7 +103,7 @@ class PluginManager: plugin_class: Type[BasePlugin] = self.plugin_classes.get(plugin_name) if not plugin_class: logger.error(f"插件 {plugin_name} 的插件类未注册或不存在") - return False + return False, 1 try: # 使用记录的插件目录路径 plugin_dir = self.plugin_paths.get(plugin_name) @@ -116,7 +117,7 @@ class PluginManager: # 检查插件是否启用 if not plugin_instance.enable_plugin: logger.info(f"插件 {plugin_name} 已禁用,跳过加载") - return False + return False, 0 # 检查版本兼容性 is_compatible, compatibility_error = self._check_plugin_version_compatibility( @@ -125,22 +126,22 @@ class PluginManager: if not is_compatible: self.failed_plugins[plugin_name] = compatibility_error logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}") - return False + return False, 1 if plugin_instance.register_plugin(): self.loaded_plugins[plugin_name] = plugin_instance self._show_plugin_components(plugin_name) - return True + return True, 1 else: self.failed_plugins[plugin_name] = "插件注册失败" logger.error(f"❌ 插件注册失败: {plugin_name}") - return False + return False, 1 except FileNotFoundError as e: # manifest文件缺失 error_msg = f"缺少manifest文件: {str(e)}" self.failed_plugins[plugin_name] = error_msg logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") - return False + return False, 1 except ValueError as e: # manifest文件格式错误或验证失败 @@ -148,7 +149,7 @@ class PluginManager: error_msg = f"manifest验证失败: {str(e)}" self.failed_plugins[plugin_name] = error_msg logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") - return False + return False, 1 except Exception as e: # 其他错误 @@ -156,7 +157,7 @@ class PluginManager: self.failed_plugins[plugin_name] = error_msg logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}") logger.debug("详细错误信息: ", exc_info=True) - return False + return False, 1 def unload_registered_plugin_module(self, plugin_name: str) -> None: """ @@ -489,14 +490,16 @@ class PluginManager: info_parts = [part for part in [version_info, author_info, license_info] if part] extra_info = f" ({', '.join(info_parts)})" if info_parts else "" - logger.info(f" 📦 {plugin_name}{extra_info}") + logger.info(f" 📦 {plugin_info.display_name}{extra_info}") # Manifest信息 if plugin_info.manifest_data: + """ if plugin_info.keywords: logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}") if plugin_info.categories: logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}") + """ if plugin_info.homepage_url: logger.info(f" 🌐 主页: {plugin_info.homepage_url}") @@ -533,9 +536,9 @@ class PluginManager: plugins_in_dir.append(plugin_name) if plugins_in_dir: - logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})") + logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})") else: - logger.info(f" 📁 {directory}: 0个插件") + logger.info(f" 📁 {directory}: 0个插件") # 失败信息 if total_failed_registration > 0: diff --git a/src/plugins/built_in/core_actions/plugin.py b/src/plugins/built_in/core_actions/plugin.py index 2b7194063..b15e72522 100644 --- a/src/plugins/built_in/core_actions/plugin.py +++ b/src/plugins/built_in/core_actions/plugin.py @@ -136,6 +136,8 @@ class CoreActionsPlugin(BasePlugin): # 插件基本信息 plugin_name = "core_actions" # 内部标识符 enable_plugin = True + dependencies = [] # 插件依赖列表 + python_dependencies = [] # Python包依赖列表 config_file_name = "config.toml" # 配置节描述 diff --git a/src/plugins/built_in/tts_plugin/plugin.py b/src/plugins/built_in/tts_plugin/plugin.py index 5f563e966..7d45f4d30 100644 --- a/src/plugins/built_in/tts_plugin/plugin.py +++ b/src/plugins/built_in/tts_plugin/plugin.py @@ -109,6 +109,8 @@ class TTSPlugin(BasePlugin): # 插件基本信息 plugin_name = "tts_plugin" # 内部标识符 enable_plugin = True + dependencies = [] # 插件依赖列表 + python_dependencies = [] # Python包依赖列表 config_file_name = "config.toml" # 配置节描述 diff --git a/src/plugins/built_in/vtb_plugin/plugin.py b/src/plugins/built_in/vtb_plugin/plugin.py index 2932205b5..e18841f03 100644 --- a/src/plugins/built_in/vtb_plugin/plugin.py +++ b/src/plugins/built_in/vtb_plugin/plugin.py @@ -110,6 +110,8 @@ class VTBPlugin(BasePlugin): # 插件基本信息 plugin_name = "vtb_plugin" # 内部标识符 enable_plugin = True + dependencies = [] # 插件依赖列表 + python_dependencies = [] # Python包依赖列表 config_file_name = "config.toml" # 配置节描述 From 968eb921073a071f6925236b81f1f8c49f06af3a Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Fri, 11 Jul 2025 00:59:49 +0800 Subject: [PATCH 09/13] =?UTF-8?q?=E4=B8=8D=E5=86=8D=E8=BF=9B=E8=A1=8Cactio?= =?UTF-8?q?n=5Finfo=E8=BD=AC=E6=8D=A2=E4=BA=86=EF=BC=8C=E4=BF=9D=E6=8C=81?= =?UTF-8?q?=E4=B8=80=E8=87=B4=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- src/chat/normal_chat/normal_chat.py | 10 +-- src/chat/planner_actions/action_manager.py | 27 +----- src/chat/planner_actions/action_modifier.py | 29 +++--- src/chat/planner_actions/planner.py | 42 ++++----- src/chat/replyer/default_generator.py | 93 ++++++++++---------- src/mais4u/mais4u_chat/s4u_msg_processor.py | 4 +- src/mood/mood_manager.py | 12 +-- src/plugin_system/apis/generator_api.py | 3 +- src/plugin_system/base/component_types.py | 10 ++- src/plugin_system/core/component_registry.py | 27 +++--- src/plugin_system/core/dependency_manager.py | 16 ++-- src/plugin_system/core/plugin_manager.py | 12 +-- 13 files changed, 137 insertions(+), 151 deletions(-) diff --git a/.gitignore b/.gitignore index 326b85948..2b6f89dcc 100644 --- a/.gitignore +++ b/.gitignore @@ -316,4 +316,5 @@ run_pet.bat !/plugins/hello_world_plugin !/plugins/take_picture_plugin -config.toml \ No newline at end of file +config.toml +备忘录.txt \ No newline at end of file diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 4d28c5d88..63e394c7c 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -2,14 +2,14 @@ import asyncio import time import traceback from random import random -from typing import List, Optional +from typing import List, Optional, Dict from maim_message import UserInfo, Seg from src.config.config import global_config from src.common.logger import get_logger from src.common.message_repository import count_messages from src.plugin_system.apis import generator_api -from src.plugin_system.base.component_types import ChatMode +from src.plugin_system.base.component_types import ChatMode, ActionInfo from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet from src.chat.message_receive.normal_message_sender import message_manager @@ -175,12 +175,12 @@ class NormalChat: # 改为实例方法 async def _create_thinking_message(self, message: MessageRecv, timestamp: Optional[float] = None) -> str: """创建思考消息""" - messageinfo = message.message_info + message_info = message.message_info bot_user_info = UserInfo( user_id=global_config.bot.qq_account, user_nickname=global_config.bot.nickname, - platform=messageinfo.platform, + platform=message_info.platform, ) thinking_time_point = round(time.time(), 2) @@ -456,7 +456,7 @@ class NormalChat: willing_manager.delete(message.message_info.message_id) async def _generate_normal_response( - self, message: MessageRecv, available_actions: Optional[list] + self, message: MessageRecv, available_actions: Optional[Dict[str, ActionInfo]] ) -> Optional[list]: """生成普通回复""" try: diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index e4dabd22f..45bdfd72d 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -59,32 +59,11 @@ class ActionManager: logger.debug(f"Action组件 {action_name} 已存在,跳过") continue - # 将插件系统的ActionInfo转换为ActionManager格式 - converted_action_info = { - "description": action_info.description, - "parameters": getattr(action_info, "action_parameters", {}), - "require": getattr(action_info, "action_require", []), - "associated_types": getattr(action_info, "associated_types", []), - "enable_plugin": action_info.enabled, - # 激活类型相关 - "focus_activation_type": action_info.focus_activation_type.value, - "normal_activation_type": action_info.normal_activation_type.value, - "random_activation_probability": action_info.random_activation_probability, - "llm_judge_prompt": action_info.llm_judge_prompt, - "activation_keywords": action_info.activation_keywords, - "keyword_case_sensitive": action_info.keyword_case_sensitive, - # 模式和并行设置 - "mode_enable": action_info.mode_enable, - "parallel_action": action_info.parallel_action, - # 插件信息 - "_plugin_name": getattr(action_info, "plugin_name", ""), - } - - self._registered_actions[action_name] = converted_action_info + self._registered_actions[action_name] = action_info # 如果启用,也添加到默认动作集 if action_info.enabled: - self._default_actions[action_name] = converted_action_info + self._default_actions[action_name] = action_info logger.debug( f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})" @@ -188,7 +167,7 @@ class ActionManager: enabled_actions = {} for action_name, action_info in self._using_actions.items(): - action_mode = action_info["mode_enable"] + action_mode = action_info.mode_enable # 检查动作是否在当前模式下启用 if action_mode in [ChatMode.ALL, mode]: diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 6b0e6a633..8aaafc201 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -11,7 +11,7 @@ from src.chat.focus_chat.focus_loop_info import FocusLoopInfo from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext from src.chat.planner_actions.action_manager import ActionManager from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages -from src.plugin_system.base.component_types import ChatMode, ActionInfo +from src.plugin_system.base.component_types import ChatMode, ActionInfo, ActionActivationType logger = get_logger("action_manager") @@ -131,9 +131,9 @@ class ActionModifier: def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext): type_mismatched_actions = [] - for action_name, data in all_actions.items(): - if data["associated_types"] and not chat_context.check_types(data["associated_types"]): - associated_types_str = ", ".join(data["associated_types"]) + for action_name, action_info in all_actions.items(): + if action_info.associated_types and not chat_context.check_types(action_info.associated_types): + associated_types_str = ", ".join(action_info.associated_types) reason = f"适配器不支持(需要: {associated_types_str})" type_mismatched_actions.append((action_name, reason)) logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}") @@ -141,7 +141,7 @@ class ActionModifier: async def _get_deactivated_actions_by_type( self, - actions_with_info: Dict[str, Any], + actions_with_info: Dict[str, ActionInfo], mode: str = "focus", chat_content: str = "", ) -> List[tuple[str, str]]: @@ -164,27 +164,26 @@ class ActionModifier: random.shuffle(actions_to_check) for action_name, action_info in actions_to_check: - activation_type = f"{mode}_activation_type" - activation_type = action_info.get(activation_type, "always") - - if activation_type == "always": + mode_activation_type = f"{mode}_activation_type" + activation_type = getattr(action_info, mode_activation_type, ActionActivationType.ALWAYS) + if activation_type == ActionActivationType.ALWAYS: continue # 总是激活,无需处理 - elif activation_type == "random": - probability = action_info.get("random_activation_probability", ActionManager.DEFAULT_RANDOM_PROBABILITY) - if not (random.random() < probability): + elif activation_type == ActionActivationType.RANDOM: + probability = action_info.random_activation_probability or ActionManager.DEFAULT_RANDOM_PROBABILITY + if random.random() >= probability: reason = f"RANDOM类型未触发(概率{probability})" deactivated_actions.append((action_name, reason)) logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}") - elif activation_type == "keyword": + elif activation_type == ActionActivationType.KEYWORD: if not self._check_keyword_activation(action_name, action_info, chat_content): - keywords = action_info.get("activation_keywords", []) + keywords = action_info.activation_keywords reason = f"关键词未匹配(关键词: {keywords})" deactivated_actions.append((action_name, reason)) logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}") - elif activation_type == "llm_judge": + elif activation_type == ActionActivationType.LLM_JUDGE: llm_judge_actions[action_name] = action_info else: diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index db7001b14..f4c8a9a4a 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -14,7 +14,7 @@ from src.chat.utils.chat_message_builder import build_readable_messages, get_raw from src.chat.utils.utils import get_chat_type_and_target_info from src.chat.planner_actions.action_manager import ActionManager from src.chat.message_receive.chat_stream import get_chat_manager -from src.plugin_system.base.component_types import ChatMode +from src.plugin_system.base.component_types import ChatMode, ActionInfo logger = get_logger("planner") @@ -26,7 +26,7 @@ def init_prompt(): Prompt( """ {time_block} -{indentify_block} +{identity_block} 你现在需要根据聊天内容,选择的合适的action来参与聊天。 {chat_context_description},以下是具体的聊天内容: {chat_content_block} @@ -78,6 +78,7 @@ class ActionPlanner: action = "no_reply" # 默认动作 reasoning = "规划器初始化默认" action_data = {} + current_available_actions: Dict[str, ActionInfo] = {} try: is_group_chat = True @@ -89,7 +90,7 @@ class ActionPlanner: # 获取完整的动作信息 all_registered_actions = self.action_manager.get_registered_actions() - current_available_actions = {} + for action_name in current_available_actions_dict.keys(): if action_name in all_registered_actions: current_available_actions[action_name] = all_registered_actions[action_name] @@ -101,13 +102,17 @@ class ActionPlanner: len(current_available_actions) == 1 and "no_reply" in current_available_actions ): action = "no_reply" - reasoning = "没有可用的动作" if not current_available_actions else "只有no_reply动作可用,跳过规划" + reasoning = "只有no_reply动作可用,跳过规划" if current_available_actions else "没有可用的动作" logger.info(f"{self.log_prefix}{reasoning}") logger.debug( f"{self.log_prefix}[focus]沉默后恢复到默认动作集, 当前可用: {list(self.action_manager.get_using_actions().keys())}" ) return { - "action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning}, + "action_result": { + "action_type": action, + "action_data": action_data, + "reasoning": reasoning, + }, } # --- 构建提示词 (调用修改后的 PromptBuilder 方法) --- @@ -135,7 +140,7 @@ class ActionPlanner: except Exception as req_e: logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") - reasoning = f"LLM 请求失败,你的模型出现问题: {req_e}" + reasoning = f"LLM 请求失败,模型出现问题: {req_e}" action = "no_reply" if llm_content: @@ -168,8 +173,8 @@ class ActionPlanner: logger.warning( f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_reply'" ) - action = "no_reply" reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {list(current_available_actions.keys())})。原始理由: {reasoning}" + action = "no_reply" except Exception as json_e: logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'") @@ -185,8 +190,7 @@ class ActionPlanner: is_parallel = False if action in current_available_actions: - action_info = current_available_actions[action] - is_parallel = action_info.get("parallel_action", False) + is_parallel = current_available_actions[action].parallel_action action_result = { "action_type": action, @@ -196,19 +200,17 @@ class ActionPlanner: "is_parallel": is_parallel, } - plan_result = { + return { "action_result": action_result, "action_prompt": prompt, } - return plan_result - async def build_planner_prompt( self, is_group_chat: bool, # Now passed as argument chat_target_info: Optional[dict], # Now passed as argument - current_available_actions, - ) -> str: + current_available_actions: Dict[str, ActionInfo], + ) -> str: # sourcery skip: use-join """构建 Planner LLM 的提示词 (获取模板并填充数据)""" try: message_list_before_now = get_raw_msg_before_timestamp_with_chat( @@ -247,23 +249,23 @@ class ActionPlanner: action_options_block = "" for using_actions_name, using_actions_info in current_available_actions.items(): - if using_actions_info["parameters"]: + if using_actions_info.action_parameters: param_text = "\n" - for param_name, param_description in using_actions_info["parameters"].items(): + for param_name, param_description in using_actions_info.action_parameters.items(): param_text += f' "{param_name}":"{param_description}"\n' param_text = param_text.rstrip("\n") else: param_text = "" require_text = "" - for require_item in using_actions_info["require"]: + for require_item in using_actions_info.action_require: require_text += f"- {require_item}\n" require_text = require_text.rstrip("\n") using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt") using_action_prompt = using_action_prompt.format( action_name=using_actions_name, - action_description=using_actions_info["description"], + action_description=using_actions_info.description, action_parameters=param_text, action_require=require_text, ) @@ -280,7 +282,7 @@ class ActionPlanner: else: bot_nickname = "" bot_core_personality = global_config.personality.personality_core - indentify_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:" + identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:" planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") prompt = planner_prompt_template.format( @@ -291,7 +293,7 @@ class ActionPlanner: no_action_block=no_action_block, action_options_text=action_options_block, moderation_prompt=moderation_prompt_block, - indentify_block=indentify_block, + identity_block=identity_block, ) return prompt diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 846112305..6cb526d11 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -1,33 +1,31 @@ import traceback -from typing import List, Optional, Dict, Any, Tuple - -from src.chat.message_receive.message import MessageRecv, MessageThinking, MessageSending -from src.chat.message_receive.message import Seg # Local import needed after move -from src.chat.message_receive.message import UserInfo -from src.chat.message_receive.chat_stream import get_chat_manager -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config -from src.chat.utils.timer_calculator import Timer # <--- Import Timer -from src.chat.message_receive.uni_message_sender import HeartFCSender -from src.chat.utils.utils import get_chat_type_and_target_info -from src.chat.message_receive.chat_stream import ChatStream -from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat import time import asyncio -from src.chat.express.expression_selector import expression_selector -from src.mood.mood_manager import mood_manager -from src.person_info.relationship_fetcher import relationship_fetcher_manager import random import ast -from src.person_info.person_info import get_person_info_manager -from datetime import datetime import re +from typing import List, Optional, Dict, Any, Tuple +from datetime import datetime + +from src.common.logger import get_logger +from src.config.config import global_config +from src.llm_models.utils_model import LLMRequest +from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageThinking, MessageSending +from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream +from src.chat.message_receive.uni_message_sender import HeartFCSender +from src.chat.utils.timer_calculator import Timer # <--- Import Timer +from src.chat.utils.utils import get_chat_type_and_target_info +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat +from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp +from src.chat.express.expression_selector import expression_selector from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.memory_system.memory_activator import MemoryActivator +from src.mood.mood_manager import mood_manager +from src.person_info.relationship_fetcher import relationship_fetcher_manager +from src.person_info.person_info import get_person_info_manager from src.tools.tool_executor import ToolExecutor +from src.plugin_system.base.component_types import ActionInfo logger = get_logger("replyer") @@ -143,12 +141,12 @@ class DefaultReplyer: return None chat = anchor_message.chat_stream - messageinfo = anchor_message.message_info + message_info = anchor_message.message_info thinking_time_point = parse_thinking_id_to_timestamp(thinking_id) bot_user_info = UserInfo( user_id=global_config.bot.qq_account, user_nickname=global_config.bot.nickname, - platform=messageinfo.platform, + platform=message_info.platform, ) thinking_message = MessageThinking( @@ -168,7 +166,7 @@ class DefaultReplyer: reply_data: Dict[str, Any] = None, reply_to: str = "", extra_info: str = "", - available_actions: List[str] = None, + available_actions: Optional[Dict[str, ActionInfo]] = None, enable_tool: bool = True, enable_timeout: bool = False, ) -> Tuple[bool, Optional[str]]: @@ -177,7 +175,7 @@ class DefaultReplyer: (已整合原 HeartFCGenerator 的功能) """ if available_actions is None: - available_actions = [] + available_actions = {} if reply_data is None: reply_data = {} try: @@ -323,8 +321,8 @@ class DefaultReplyer: if not global_config.expression.enable_expression: return "" - style_habbits = [] - grammar_habbits = [] + style_habits = [] + grammar_habits = [] # 使用从处理器传来的选中表达方式 # LLM模式:调用LLM选择5-10个,然后随机选5个 @@ -338,22 +336,22 @@ class DefaultReplyer: if isinstance(expr, dict) and "situation" in expr and "style" in expr: expr_type = expr.get("type", "style") if expr_type == "grammar": - grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") + grammar_habits.append(f"当{expr['situation']}时,使用 {expr['style']}") else: - style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") + style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}") else: logger.debug(f"{self.log_prefix} 没有从处理器获得表达方式,将使用空的表达方式") # 不再在replyer中进行随机选择,全部交给处理器处理 - style_habbits_str = "\n".join(style_habbits) - grammar_habbits_str = "\n".join(grammar_habbits) + style_habits_str = "\n".join(style_habits) + grammar_habits_str = "\n".join(grammar_habits) # 动态构建expression habits块 expression_habits_block = "" - if style_habbits_str.strip(): - expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habbits_str}\n\n" - if grammar_habbits_str.strip(): - expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habbits_str}\n" + if style_habits_str.strip(): + expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n" + if grammar_habits_str.strip(): + expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habits_str}\n" return expression_habits_block @@ -361,13 +359,13 @@ class DefaultReplyer: if not global_config.memory.enable_memory: return "" - running_memorys = await self.memory_activator.activate_memory_with_chat_history( + running_memories = await self.memory_activator.activate_memory_with_chat_history( target_message=target, chat_history_prompt=chat_history ) - if running_memorys: + if running_memories: memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" - for running_memory in running_memorys: + for running_memory in running_memories: memory_str += f"- {running_memory['content']}\n" memory_block = memory_str else: @@ -465,10 +463,10 @@ class DefaultReplyer: return keywords_reaction_prompt - async def _time_and_run_task(self, coro, name: str): + async def _time_and_run_task(self, coroutine, name: str): """一个简单的帮助函数,用于计时和运行异步任务,返回任务名、结果和耗时""" start_time = time.time() - result = await coro + result = await coroutine end_time = time.time() duration = end_time - start_time return name, result, duration @@ -476,7 +474,7 @@ class DefaultReplyer: async def build_prompt_reply_context( self, reply_data=None, - available_actions: List[str] = None, + available_actions: Optional[Dict[str, ActionInfo]] = None, enable_timeout: bool = False, enable_tool: bool = True, ) -> str: @@ -495,7 +493,7 @@ class DefaultReplyer: str: 构建好的上下文 """ if available_actions is None: - available_actions = [] + available_actions = {} chat_stream = self.chat_stream chat_id = chat_stream.stream_id person_info_manager = get_person_info_manager() @@ -514,10 +512,9 @@ class DefaultReplyer: if available_actions: action_descriptions = "你有以下的动作能力,但执行这些动作不由你决定,由另外一个模型同步决定,因此你只需要知道有如下能力即可:\n" for action_name, action_info in available_actions.items(): - action_description = action_info.get("description", "") + action_description = action_info.description action_descriptions += f"- {action_name}: {action_description}\n" action_descriptions += "\n" - message_list_before_now = get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), @@ -616,7 +613,7 @@ class DefaultReplyer: personality = short_impression[0] identity = short_impression[1] prompt_personality = personality + "," + identity - indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" + identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" moderation_prompt_block = ( "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。" @@ -677,7 +674,7 @@ class DefaultReplyer: reply_target_block=reply_target_block, moderation_prompt=moderation_prompt_block, keywords_reaction_prompt=keywords_reaction_prompt, - identity=indentify_block, + identity=identity_block, target_message=target, sender_name=sender, config_expression_style=global_config.expression.expression_style, @@ -749,7 +746,7 @@ class DefaultReplyer: personality = short_impression[0] identity = short_impression[1] prompt_personality = personality + "," + identity - indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" + identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" moderation_prompt_block = ( "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。" @@ -800,7 +797,7 @@ class DefaultReplyer: chat_target=chat_target_1, time_block=time_block, chat_info=chat_talking_prompt_half, - identity=indentify_block, + identity=identity_block, chat_target_2=chat_target_2, reply_target_block=reply_target_block, raw_reply=raw_reply, diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py index ecdefe109..ac3024f1b 100644 --- a/src/mais4u/mais4u_chat/s4u_msg_processor.py +++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py @@ -36,10 +36,10 @@ class S4UMessageProcessor: # 1. 消息解析与初始化 groupinfo = message.message_info.group_info userinfo = message.message_info.user_info - messageinfo = message.message_info + message_info = message.message_info chat = await get_chat_manager().get_or_create_stream( - platform=messageinfo.platform, + platform=message_info.platform, user_info=userinfo, group_info=groupinfo, ) diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index dee8d7cc6..ffdf8ff36 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -19,7 +19,7 @@ def init_prompt(): {chat_talking_prompt} 以上是群里正在进行的聊天记录 -{indentify_block} +{identity_block} 你刚刚的情绪状态是:{mood_state} 现在,发送了消息,引起了你的注意,你对其进行了阅读和思考,请你输出一句话描述你新的情绪状态 @@ -32,7 +32,7 @@ def init_prompt(): {chat_talking_prompt} 以上是群里最近的聊天记录 -{indentify_block} +{identity_block} 你之前的情绪状态是:{mood_state} 距离你上次关注群里消息已经过去了一段时间,你冷静了下来,请你输出一句话描述你现在的情绪状态 @@ -103,12 +103,12 @@ class ChatMood: bot_nickname = "" prompt_personality = global_config.personality.personality_core - indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" + identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" prompt = await global_prompt_manager.format_prompt( "change_mood_prompt", chat_talking_prompt=chat_talking_prompt, - indentify_block=indentify_block, + identity_block=identity_block, mood_state=self.mood_state, ) @@ -147,12 +147,12 @@ class ChatMood: bot_nickname = "" prompt_personality = global_config.personality.personality_core - indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" + identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" prompt = await global_prompt_manager.format_prompt( "regress_mood_prompt", chat_talking_prompt=chat_talking_prompt, - indentify_block=indentify_block, + identity_block=identity_block, mood_state=self.mood_state, ) diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index d4ed0f51b..c341e5214 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -15,6 +15,7 @@ from src.chat.replyer.default_generator import DefaultReplyer from src.chat.message_receive.chat_stream import ChatStream from src.chat.utils.utils import process_llm_response from src.chat.replyer.replyer_manager import replyer_manager +from src.plugin_system.base.component_types import ActionInfo logger = get_logger("generator_api") @@ -69,7 +70,7 @@ async def generate_reply( action_data: Dict[str, Any] = None, reply_to: str = "", extra_info: str = "", - available_actions: List[str] = None, + available_actions: Optional[Dict[str, ActionInfo]] = None, enable_tool: bool = False, enable_splitter: bool = True, enable_chinese_typo: bool = True, diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 771fba422..bc66100d9 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -66,7 +66,7 @@ class ComponentInfo: name: str # 组件名称 component_type: ComponentType # 组件类型 - description: str # 组件描述 + description: str = "" # 组件描述 enabled: bool = True # 是否启用 plugin_name: str = "" # 所属插件名称 is_built_in: bool = False # 是否为内置组件 @@ -81,17 +81,19 @@ class ComponentInfo: class ActionInfo(ComponentInfo): """动作组件信息""" + action_parameters: Dict[str, str] = field(default_factory=dict) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"} + action_require: List[str] = field(default_factory=list) # 动作需求说明 + associated_types: List[str] = field(default_factory=list) # 关联的消息类型 + # 激活类型相关 focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS random_activation_probability: float = 0.0 llm_judge_prompt: str = "" activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表 keyword_case_sensitive: bool = False + # 模式和并行设置 mode_enable: ChatMode = ChatMode.ALL parallel_action: bool = False - action_parameters: Dict[str, Any] = field(default_factory=dict) # 动作参数 - action_require: List[str] = field(default_factory=list) # 动作需求说明 - associated_types: List[str] = field(default_factory=list) # 关联的消息类型 def __post_init__(self): super().__post_init__() diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 809319802..2ec77c7b7 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -35,7 +35,7 @@ class ComponentRegistry: # Action特定注册表 self._action_registry: Dict[str, BaseAction] = {} # action名 -> action类 - self._default_actions: Dict[str, str] = {} # 启用的action名 -> 描述 + # self._action_descriptions: Dict[str, str] = {} # 启用的action名 -> 描述 # Command特定注册表 self._command_registry: Dict[str, BaseCommand] = {} # command名 -> command类 @@ -99,13 +99,16 @@ class ComponentRegistry: return True def _register_action_component(self, action_info: ActionInfo, action_class: BaseAction): + # -------------------------------- NEED REFACTORING -------------------------------- + # -------------------------------- LOGIC ERROR ------------------------------------- """注册Action组件到Action特定注册表""" action_name = action_info.name self._action_registry[action_name] = action_class # 如果启用,添加到默认动作集 - if action_info.enabled: - self._default_actions[action_name] = action_info.description + # ---- HERE ---- + # if action_info.enabled: + # self._action_descriptions[action_name] = action_info.description def _register_command_component(self, command_info: CommandInfo, command_class: BaseCommand): """注册Command组件到Command特定注册表""" @@ -231,10 +234,6 @@ class ComponentRegistry: """获取Action注册表(用于兼容现有系统)""" return self._action_registry.copy() - def get_default_actions(self) -> Dict[str, str]: - """获取默认启用的Action列表(用于兼容现有系统)""" - return self._default_actions.copy() - def get_action_info(self, action_name: str) -> Optional[ActionInfo]: """获取Action信息""" info = self.get_component_info(action_name, ComponentType.ACTION) @@ -343,6 +342,8 @@ class ComponentRegistry: # === 状态管理方法 === def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool: + # -------------------------------- NEED REFACTORING -------------------------------- + # -------------------------------- LOGIC ERROR ------------------------------------- """启用组件,支持命名空间解析""" # 首先尝试找到正确的命名空间化名称 component_info = self.get_component_info(component_name, component_type) @@ -364,13 +365,16 @@ class ComponentRegistry: if namespaced_name in self._components: self._components[namespaced_name].enabled = True # 如果是Action,更新默认动作集 - if isinstance(component_info, ActionInfo): - self._default_actions[component_name] = component_info.description + # ---- HERE ---- + # if isinstance(component_info, ActionInfo): + # self._action_descriptions[component_name] = component_info.description logger.debug(f"已启用组件: {component_name} -> {namespaced_name}") return True return False def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool: + # -------------------------------- NEED REFACTORING -------------------------------- + # -------------------------------- LOGIC ERROR ------------------------------------- """禁用组件,支持命名空间解析""" # 首先尝试找到正确的命名空间化名称 component_info = self.get_component_info(component_name, component_type) @@ -392,8 +396,9 @@ class ComponentRegistry: if namespaced_name in self._components: self._components[namespaced_name].enabled = False # 如果是Action,从默认动作集中移除 - if component_name in self._default_actions: - del self._default_actions[component_name] + # ---- HERE ---- + # if component_name in self._action_descriptions: + # del self._action_descriptions[component_name] logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}") return True return False diff --git a/src/plugin_system/core/dependency_manager.py b/src/plugin_system/core/dependency_manager.py index dcba27c73..4a995e028 100644 --- a/src/plugin_system/core/dependency_manager.py +++ b/src/plugin_system/core/dependency_manager.py @@ -37,16 +37,14 @@ class DependencyManager: missing_optional = [] for dep in dependencies: - if not self._is_package_available(dep.package_name): - if dep.optional: - missing_optional.append(dep) - logger.warning(f"可选依赖包缺失: {dep.package_name} - {dep.description}") - else: - missing_required.append(dep) - logger.error(f"必需依赖包缺失: {dep.package_name} - {dep.description}") - else: + if self._is_package_available(dep.package_name): logger.debug(f"依赖包已存在: {dep.package_name}") - + elif dep.optional: + missing_optional.append(dep) + logger.warning(f"可选依赖包缺失: {dep.package_name} - {dep.description}") + else: + missing_required.append(dep) + logger.error(f"必需依赖包缺失: {dep.package_name} - {dep.description}") return missing_required, missing_optional def _is_package_available(self, package_name: str) -> bool: diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 9d6bd805c..fd75d8c9d 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -24,12 +24,14 @@ class PluginManager: """ def __init__(self): - self.plugin_directories: List[str] = [] - self.loaded_plugins: Dict[str, "BasePlugin"] = {} - self.failed_plugins: Dict[str, str] = {} - self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射 + self.plugin_directories: List[str] = [] # 插件根目录列表 + self.plugin_classes: Dict[str, Type[BasePlugin]] = {} # 全局插件类注册表,插件名 -> 插件类 + self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径 + + self.loaded_plugins: Dict[str, BasePlugin] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例 + self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件类及其错误信息,插件名 -> 错误信息 + self.events_subscriptions: Dict[EventType, List[Callable]] = {} - self.plugin_classes: Dict[str, Type[BasePlugin]] = {} # 全局插件类注册表 # 确保插件目录存在 self._ensure_plugin_directories() From b303a95f61fbf6aaf4110a85881c821bb32f3dc5 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sat, 12 Jul 2025 00:34:49 +0800 Subject: [PATCH 10/13] =?UTF-8?q?=E9=83=A8=E5=88=86=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E6=B3=A8=E8=A7=A3=E4=BF=AE=E5=A4=8D=EF=BC=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?import=E9=A1=BA=E5=BA=8F=EF=BC=8C=E5=88=A0=E9=99=A4=E6=97=A0?= =?UTF-8?q?=E7=94=A8API=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/__init__.py | 0 src/api/apiforgui.py | 26 -- src/api/basic_info_api.py | 169 ---------- src/api/config_api.py | 317 ------------------ src/api/maigraphql/__init__.py | 22 -- src/api/maigraphql/schema.py | 1 - src/api/main.py | 112 ------- src/api/reload_config.py | 24 -- src/chat/emoji_system/emoji_manager.py | 45 +-- src/chat/express/expression_selector.py | 17 +- src/chat/express/exprssion_learner.py | 20 +- src/chat/focus_chat/focus_loop_info.py | 4 +- src/chat/focus_chat/heartFC_chat.py | 36 +- src/chat/focus_chat/hfc_performance_logger.py | 1 + src/chat/focus_chat/hfc_utils.py | 9 +- src/chat/heart_flow/heartflow.py | 8 +- .../heart_flow/heartflow_message_processor.py | 30 +- src/chat/heart_flow/sub_heartflow.py | 12 +- src/chat/memory_system/Hippocampus.py | 120 +++---- src/chat/memory_system/memory_activator.py | 13 +- src/chat/memory_system/sample_distribution.py | 42 --- src/chat/message_receive/bot.py | 26 +- src/chat/message_receive/chat_stream.py | 32 +- src/chat/message_receive/message.py | 55 ++- .../message_receive/normal_message_sender.py | 35 +- src/chat/message_receive/storage.py | 31 +- .../message_receive/uni_message_sender.py | 23 +- src/chat/normal_chat/normal_chat.py | 8 +- src/chat/normal_chat/priority_manager.py | 2 +- .../normal_chat/willing/mode_classical.py | 4 +- .../normal_chat/willing/willing_manager.py | 18 +- src/chat/planner_actions/action_manager.py | 68 ++-- src/chat/planner_actions/action_modifier.py | 28 +- src/chat/planner_actions/planner.py | 7 +- src/chat/replyer/default_generator.py | 58 ++-- src/chat/replyer/replyer_manager.py | 12 +- src/chat/utils/chat_message_builder.py | 61 ++-- src/chat/utils/utils_image.py | 8 +- src/person_info/person_info.py | 4 +- src/plugin_system/base/base_action.py | 4 +- src/plugin_system/base/base_command.py | 2 +- src/plugin_system/base/base_plugin.py | 4 +- src/plugin_system/base/component_types.py | 4 +- src/plugin_system/core/component_registry.py | 49 +-- 44 files changed, 405 insertions(+), 1166 deletions(-) delete mode 100644 src/api/__init__.py delete mode 100644 src/api/apiforgui.py delete mode 100644 src/api/basic_info_api.py delete mode 100644 src/api/config_api.py delete mode 100644 src/api/maigraphql/__init__.py delete mode 100644 src/api/maigraphql/schema.py delete mode 100644 src/api/main.py delete mode 100644 src/api/reload_config.py diff --git a/src/api/__init__.py b/src/api/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/api/apiforgui.py b/src/api/apiforgui.py deleted file mode 100644 index 058c6fc96..000000000 --- a/src/api/apiforgui.py +++ /dev/null @@ -1,26 +0,0 @@ -from src.chat.heart_flow.heartflow import heartflow -from src.chat.heart_flow.sub_heartflow import ChatState -from src.common.logger import get_logger - -logger = get_logger("api") - - -async def get_all_subheartflow_ids() -> list: - """获取所有子心流的ID列表""" - all_subheartflows = heartflow.subheartflow_manager.get_all_subheartflows() - return [subheartflow.subheartflow_id for subheartflow in all_subheartflows] - - -async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatState) -> bool: - """强制改变子心流的状态""" - subheartflow = await heartflow.get_or_create_subheartflow(subheartflow_id) - if subheartflow: - return await heartflow.force_change_subheartflow_status(subheartflow_id, status) - return False - - -async def get_all_states(): - """获取所有状态""" - all_states = await heartflow.api_get_all_states() - logger.debug(f"所有状态: {all_states}") - return all_states diff --git a/src/api/basic_info_api.py b/src/api/basic_info_api.py deleted file mode 100644 index 4e5fa4c7d..000000000 --- a/src/api/basic_info_api.py +++ /dev/null @@ -1,169 +0,0 @@ -import platform -import psutil -import sys -import os - - -def get_system_info(): - """获取操作系统信息""" - return { - "system": platform.system(), - "release": platform.release(), - "version": platform.version(), - "machine": platform.machine(), - "processor": platform.processor(), - } - - -def get_python_version(): - """获取 Python 版本信息""" - return sys.version - - -def get_cpu_usage(): - """获取系统总CPU使用率""" - return psutil.cpu_percent(interval=1) - - -def get_process_cpu_usage(): - """获取当前进程CPU使用率""" - process = psutil.Process(os.getpid()) - return process.cpu_percent(interval=1) - - -def get_memory_usage(): - """获取系统内存使用情况 (单位 MB)""" - mem = psutil.virtual_memory() - bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa - return { - "total_mb": bytes_to_mb(mem.total), - "available_mb": bytes_to_mb(mem.available), - "percent": mem.percent, - "used_mb": bytes_to_mb(mem.used), - "free_mb": bytes_to_mb(mem.free), - } - - -def get_process_memory_usage(): - """获取当前进程内存使用情况 (单位 MB)""" - process = psutil.Process(os.getpid()) - mem_info = process.memory_info() - bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa - return { - "rss_mb": bytes_to_mb(mem_info.rss), # Resident Set Size: 实际使用物理内存 - "vms_mb": bytes_to_mb(mem_info.vms), # Virtual Memory Size: 虚拟内存大小 - "percent": process.memory_percent(), # 进程内存使用百分比 - } - - -def get_disk_usage(path="/"): - """获取指定路径磁盘使用情况 (单位 GB)""" - disk = psutil.disk_usage(path) - bytes_to_gb = lambda x: round(x / (1024 * 1024 * 1024), 2) # noqa - return { - "total_gb": bytes_to_gb(disk.total), - "used_gb": bytes_to_gb(disk.used), - "free_gb": bytes_to_gb(disk.free), - "percent": disk.percent, - } - - -def get_all_basic_info(): - """获取所有基本信息并封装返回""" - # 对于进程CPU使用率,需要先初始化 - process = psutil.Process(os.getpid()) - process.cpu_percent(interval=None) # 初始化调用 - process_cpu = process.cpu_percent(interval=0.1) # 短暂间隔获取 - - return { - "system_info": get_system_info(), - "python_version": get_python_version(), - "cpu_usage_percent": get_cpu_usage(), - "process_cpu_usage_percent": process_cpu, - "memory_usage": get_memory_usage(), - "process_memory_usage": get_process_memory_usage(), - "disk_usage_root": get_disk_usage("/"), - } - - -def get_all_basic_info_string() -> str: - """获取所有基本信息并以带解释的字符串形式返回""" - info = get_all_basic_info() - - sys_info = info["system_info"] - mem_usage = info["memory_usage"] - proc_mem_usage = info["process_memory_usage"] - disk_usage = info["disk_usage_root"] - - # 对进程内存使用百分比进行格式化,保留两位小数 - proc_mem_percent = round(proc_mem_usage["percent"], 2) - - output_string = f"""[系统信息] - - 操作系统: {sys_info["system"]} (例如: Windows, Linux) - - 发行版本: {sys_info["release"]} (例如: 11, Ubuntu 20.04) - - 详细版本: {sys_info["version"]} - - 硬件架构: {sys_info["machine"]} (例如: AMD64) - - 处理器信息: {sys_info["processor"]} - -[Python 环境] - - Python 版本: {info["python_version"]} - -[CPU 状态] - - 系统总 CPU 使用率: {info["cpu_usage_percent"]}% - - 当前进程 CPU 使用率: {info["process_cpu_usage_percent"]}% - -[系统内存使用情况] - - 总物理内存: {mem_usage["total_mb"]} MB - - 可用物理内存: {mem_usage["available_mb"]} MB - - 物理内存使用率: {mem_usage["percent"]}% - - 已用物理内存: {mem_usage["used_mb"]} MB - - 空闲物理内存: {mem_usage["free_mb"]} MB - -[当前进程内存使用情况] - - 实际使用物理内存 (RSS): {proc_mem_usage["rss_mb"]} MB - - 占用虚拟内存 (VMS): {proc_mem_usage["vms_mb"]} MB - - 进程内存使用率: {proc_mem_percent}% - -[磁盘使用情况 (根目录)] - - 总空间: {disk_usage["total_gb"]} GB - - 已用空间: {disk_usage["used_gb"]} GB - - 可用空间: {disk_usage["free_gb"]} GB - - 磁盘使用率: {disk_usage["percent"]}% -""" - return output_string - - -if __name__ == "__main__": - print(f"System Info: {get_system_info()}") - print(f"Python Version: {get_python_version()}") - print(f"CPU Usage: {get_cpu_usage()}%") - # 第一次调用 process.cpu_percent() 会返回0.0或一个无意义的值,需要间隔一段时间再调用 - # 或者在初始化Process对象后,先调用一次cpu_percent(interval=None),然后再调用cpu_percent(interval=1) - current_process = psutil.Process(os.getpid()) - current_process.cpu_percent(interval=None) # 初始化 - print(f"Process CPU Usage: {current_process.cpu_percent(interval=1)}%") # 实际获取 - - memory_usage_info = get_memory_usage() - print( - f"Memory Usage: Total={memory_usage_info['total_mb']}MB, Used={memory_usage_info['used_mb']}MB, Percent={memory_usage_info['percent']}%" - ) - - process_memory_info = get_process_memory_usage() - print( - f"Process Memory Usage: RSS={process_memory_info['rss_mb']}MB, VMS={process_memory_info['vms_mb']}MB, Percent={process_memory_info['percent']}%" - ) - - disk_usage_info = get_disk_usage("/") - print( - f"Disk Usage (Root): Total={disk_usage_info['total_gb']}GB, Used={disk_usage_info['used_gb']}GB, Percent={disk_usage_info['percent']}%" - ) - - print("\n--- All Basic Info (JSON) ---") - all_info = get_all_basic_info() - import json - - print(json.dumps(all_info, indent=4, ensure_ascii=False)) - - print("\n--- All Basic Info (String with Explanations) ---") - info_string = get_all_basic_info_string() - print(info_string) diff --git a/src/api/config_api.py b/src/api/config_api.py deleted file mode 100644 index 07f36a9d8..000000000 --- a/src/api/config_api.py +++ /dev/null @@ -1,317 +0,0 @@ -from typing import List, Optional, Dict, Any -import strawberry - -# from packaging.version import Version -import os - -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) - - -@strawberry.type -class APIBotConfig: - """机器人配置类""" - - INNER_VERSION: str # 配置文件内部版本号(toml为字符串) - MAI_VERSION: str # 硬编码的版本信息 - - # bot - BOT_QQ: Optional[int] # 机器人QQ号 - BOT_NICKNAME: Optional[str] # 机器人昵称 - BOT_ALIAS_NAMES: List[str] # 机器人别名列表 - - # group - talk_allowed_groups: List[int] # 允许回复消息的群号列表 - talk_frequency_down_groups: List[int] # 降低回复频率的群号列表 - ban_user_id: List[int] # 禁止回复和读取消息的QQ号列表 - - # personality - personality_core: str # 人格核心特点描述 - personality_sides: List[str] # 人格细节描述列表 - - # identity - identity_detail: List[str] # 身份特点列表 - age: int # 年龄(岁) - gender: str # 性别 - appearance: str # 外貌特征描述 - - # platforms - platforms: Dict[str, str] # 平台信息 - - # chat - allow_focus_mode: bool # 是否允许专注聊天状态 - base_normal_chat_num: int # 最多允许多少个群进行普通聊天 - base_focused_chat_num: int # 最多允许多少个群进行专注聊天 - observation_context_size: int # 观察到的最长上下文大小 - message_buffer: bool # 是否启用消息缓冲 - ban_words: List[str] # 禁止词列表 - ban_msgs_regex: List[str] # 禁止消息的正则表达式列表 - - # normal_chat - model_reasoning_probability: float # 推理模型概率 - model_normal_probability: float # 普通模型概率 - emoji_chance: float # 表情符号出现概率 - thinking_timeout: int # 思考超时时间 - willing_mode: str # 意愿模式 - response_interested_rate_amplifier: float # 回复兴趣率放大器 - emoji_response_penalty: float # 表情回复惩罚 - mentioned_bot_inevitable_reply: bool # 提及 bot 必然回复 - at_bot_inevitable_reply: bool # @bot 必然回复 - - # focus_chat - reply_trigger_threshold: float # 回复触发阈值 - default_decay_rate_per_second: float # 默认每秒衰减率 - - # compressed - compressed_length: int # 压缩长度 - compress_length_limit: int # 压缩长度限制 - - # emoji - max_emoji_num: int # 最大表情符号数量 - max_reach_deletion: bool # 达到最大数量时是否删除 - check_interval: int # 检查表情包的时间间隔(分钟) - save_emoji: bool # 是否保存表情包 - steal_emoji: bool # 是否偷取表情包 - enable_check: bool # 是否启用表情包过滤 - check_prompt: str # 表情包过滤要求 - - # memory - build_memory_interval: int # 记忆构建间隔 - build_memory_distribution: List[float] # 记忆构建分布 - build_memory_sample_num: int # 采样数量 - build_memory_sample_length: int # 采样长度 - memory_compress_rate: float # 记忆压缩率 - forget_memory_interval: int # 记忆遗忘间隔 - memory_forget_time: int # 记忆遗忘时间(小时) - memory_forget_percentage: float # 记忆遗忘比例 - consolidate_memory_interval: int # 记忆整合间隔 - consolidation_similarity_threshold: float # 相似度阈值 - consolidation_check_percentage: float # 检查节点比例 - memory_ban_words: List[str] # 记忆禁止词列表 - - # mood - mood_update_interval: float # 情绪更新间隔 - mood_decay_rate: float # 情绪衰减率 - mood_intensity_factor: float # 情绪强度因子 - - # keywords_reaction - keywords_reaction_enable: bool # 是否启用关键词反应 - keywords_reaction_rules: List[Dict[str, Any]] # 关键词反应规则 - - # chinese_typo - chinese_typo_enable: bool # 是否启用中文错别字 - chinese_typo_error_rate: float # 中文错别字错误率 - chinese_typo_min_freq: int # 中文错别字最小频率 - chinese_typo_tone_error_rate: float # 中文错别字声调错误率 - chinese_typo_word_replace_rate: float # 中文错别字单词替换率 - - # response_splitter - enable_response_splitter: bool # 是否启用回复分割器 - response_max_length: int # 回复最大长度 - response_max_sentence_num: int # 回复最大句子数 - enable_kaomoji_protection: bool # 是否启用颜文字保护 - - model_max_output_length: int # 模型最大输出长度 - - # remote - remote_enable: bool # 是否启用远程功能 - - # experimental - enable_friend_chat: bool # 是否启用好友聊天 - talk_allowed_private: List[int] # 允许私聊的QQ号列表 - pfc_chatting: bool # 是否启用PFC聊天 - - # 模型配置 - llm_reasoning: Dict[str, Any] # 推理模型配置 - llm_normal: Dict[str, Any] # 普通模型配置 - llm_topic_judge: Dict[str, Any] # 主题判断模型配置 - summary: Dict[str, Any] # 总结模型配置 - vlm: Dict[str, Any] # VLM模型配置 - llm_heartflow: Dict[str, Any] # 心流模型配置 - llm_observation: Dict[str, Any] # 观察模型配置 - llm_sub_heartflow: Dict[str, Any] # 子心流模型配置 - llm_plan: Optional[Dict[str, Any]] # 计划模型配置 - embedding: Dict[str, Any] # 嵌入模型配置 - llm_PFC_action_planner: Optional[Dict[str, Any]] # PFC行动计划模型配置 - llm_PFC_chat: Optional[Dict[str, Any]] # PFC聊天模型配置 - llm_PFC_reply_checker: Optional[Dict[str, Any]] # PFC回复检查模型配置 - llm_tool_use: Optional[Dict[str, Any]] # 工具使用模型配置 - - api_urls: Optional[Dict[str, str]] # API地址配置 - - @staticmethod - def validate_config(config: dict): - """ - 校验传入的 toml 配置字典是否合法。 - :param config: toml库load后的配置字典 - :raises: ValueError, KeyError, TypeError - """ - # 检查主层级 - required_sections = [ - "inner", - "bot", - "groups", - "personality", - "identity", - "platforms", - "chat", - "normal_chat", - "focus_chat", - "emoji", - "memory", - "mood", - "keywords_reaction", - "chinese_typo", - "response_splitter", - "remote", - "experimental", - "model", - ] - for section in required_sections: - if section not in config: - raise KeyError(f"缺少配置段: [{section}]") - - # 检查部分关键字段 - if "version" not in config["inner"]: - raise KeyError("缺少 inner.version 字段") - if not isinstance(config["inner"]["version"], str): - raise TypeError("inner.version 必须为字符串") - - if "qq" not in config["bot"]: - raise KeyError("缺少 bot.qq 字段") - if not isinstance(config["bot"]["qq"], int): - raise TypeError("bot.qq 必须为整数") - - if "personality_core" not in config["personality"]: - raise KeyError("缺少 personality.personality_core 字段") - if not isinstance(config["personality"]["personality_core"], str): - raise TypeError("personality.personality_core 必须为字符串") - - if "identity_detail" not in config["identity"]: - raise KeyError("缺少 identity.identity_detail 字段") - if not isinstance(config["identity"]["identity_detail"], list): - raise TypeError("identity.identity_detail 必须为列表") - - # 可继续添加更多字段的类型和值检查 - # ... - - # 检查模型配置 - model_keys = [ - "llm_reasoning", - "llm_normal", - "llm_topic_judge", - "summary", - "vlm", - "llm_heartflow", - "llm_observation", - "llm_sub_heartflow", - "embedding", - ] - if "model" not in config: - raise KeyError("缺少 [model] 配置段") - for key in model_keys: - if key not in config["model"]: - raise KeyError(f"缺少 model.{key} 配置") - - # 检查通过 - return True - - -@strawberry.type -class APIEnvConfig: - """环境变量配置""" - - HOST: str # 服务主机地址 - PORT: int # 服务端口 - - PLUGINS: List[str] # 插件列表 - - MONGODB_HOST: str # MongoDB 主机地址 - MONGODB_PORT: int # MongoDB 端口 - DATABASE_NAME: str # 数据库名称 - - CHAT_ANY_WHERE_BASE_URL: str # ChatAnywhere 基础URL - SILICONFLOW_BASE_URL: str # SiliconFlow 基础URL - DEEP_SEEK_BASE_URL: str # DeepSeek 基础URL - - DEEP_SEEK_KEY: Optional[str] # DeepSeek API Key - CHAT_ANY_WHERE_KEY: Optional[str] # ChatAnywhere API Key - SILICONFLOW_KEY: Optional[str] # SiliconFlow API Key - - SIMPLE_OUTPUT: Optional[bool] # 是否简化输出 - CONSOLE_LOG_LEVEL: Optional[str] # 控制台日志等级 - FILE_LOG_LEVEL: Optional[str] # 文件日志等级 - DEFAULT_CONSOLE_LOG_LEVEL: Optional[str] # 默认控制台日志等级 - DEFAULT_FILE_LOG_LEVEL: Optional[str] # 默认文件日志等级 - - @strawberry.field - def get_env(self) -> str: - return "env" - - @staticmethod - def validate_config(config: dict): - """ - 校验环境变量配置字典是否合法。 - :param config: 环境变量配置字典 - :raises: KeyError, TypeError - """ - required_fields = [ - "HOST", - "PORT", - "PLUGINS", - "MONGODB_HOST", - "MONGODB_PORT", - "DATABASE_NAME", - "CHAT_ANY_WHERE_BASE_URL", - "SILICONFLOW_BASE_URL", - "DEEP_SEEK_BASE_URL", - ] - for field in required_fields: - if field not in config: - raise KeyError(f"缺少环境变量配置字段: {field}") - - if not isinstance(config["HOST"], str): - raise TypeError("HOST 必须为字符串") - if not isinstance(config["PORT"], int): - raise TypeError("PORT 必须为整数") - if not isinstance(config["PLUGINS"], list): - raise TypeError("PLUGINS 必须为列表") - if not isinstance(config["MONGODB_HOST"], str): - raise TypeError("MONGODB_HOST 必须为字符串") - if not isinstance(config["MONGODB_PORT"], int): - raise TypeError("MONGODB_PORT 必须为整数") - if not isinstance(config["DATABASE_NAME"], str): - raise TypeError("DATABASE_NAME 必须为字符串") - if not isinstance(config["CHAT_ANY_WHERE_BASE_URL"], str): - raise TypeError("CHAT_ANY_WHERE_BASE_URL 必须为字符串") - if not isinstance(config["SILICONFLOW_BASE_URL"], str): - raise TypeError("SILICONFLOW_BASE_URL 必须为字符串") - if not isinstance(config["DEEP_SEEK_BASE_URL"], str): - raise TypeError("DEEP_SEEK_BASE_URL 必须为字符串") - - # 可选字段类型检查 - optional_str_fields = [ - "DEEP_SEEK_KEY", - "CHAT_ANY_WHERE_KEY", - "SILICONFLOW_KEY", - "CONSOLE_LOG_LEVEL", - "FILE_LOG_LEVEL", - "DEFAULT_CONSOLE_LOG_LEVEL", - "DEFAULT_FILE_LOG_LEVEL", - ] - for field in optional_str_fields: - if field in config and config[field] is not None and not isinstance(config[field], str): - raise TypeError(f"{field} 必须为字符串或None") - - if ( - "SIMPLE_OUTPUT" in config - and config["SIMPLE_OUTPUT"] is not None - and not isinstance(config["SIMPLE_OUTPUT"], bool) - ): - raise TypeError("SIMPLE_OUTPUT 必须为布尔值或None") - - # 检查通过 - return True - - -print("当前路径:") -print(ROOT_PATH) diff --git a/src/api/maigraphql/__init__.py b/src/api/maigraphql/__init__.py deleted file mode 100644 index c414911de..000000000 --- a/src/api/maigraphql/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -import strawberry - -from fastapi import FastAPI -from strawberry.fastapi import GraphQLRouter - -from src.common.server import get_global_server - - -@strawberry.type -class Query: - @strawberry.field - def hello(self) -> str: - return "Hello World" - - -schema = strawberry.Schema(Query) - -graphql_app = GraphQLRouter(schema) - -fast_api_app: FastAPI = get_global_server().get_app() - -fast_api_app.include_router(graphql_app, prefix="/graphql") diff --git a/src/api/maigraphql/schema.py b/src/api/maigraphql/schema.py deleted file mode 100644 index 2ae28399f..000000000 --- a/src/api/maigraphql/schema.py +++ /dev/null @@ -1 +0,0 @@ -pass diff --git a/src/api/main.py b/src/api/main.py deleted file mode 100644 index 598b8aec5..000000000 --- a/src/api/main.py +++ /dev/null @@ -1,112 +0,0 @@ -from fastapi import APIRouter -from strawberry.fastapi import GraphQLRouter -import os -import sys - -# from src.chat.heart_flow.heartflow import heartflow -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) -# from src.config.config import BotConfig -from src.common.logger import get_logger -from src.api.reload_config import reload_config as reload_config_func -from src.common.server import get_global_server -from src.api.apiforgui import ( - get_all_subheartflow_ids, - forced_change_subheartflow_status, - get_subheartflow_cycle_info, - get_all_states, -) -from src.chat.heart_flow.sub_heartflow import ChatState -from src.api.basic_info_api import get_all_basic_info # 新增导入 - - -router = APIRouter() - - -logger = get_logger("api") - -logger.info("麦麦API服务器已启动") -graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema - -router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"]) - - -@router.post("/config/reload") -async def reload_config(): - return await reload_config_func() - - -@router.get("/gui/subheartflow/get/all") -async def get_subheartflow_ids(): - """获取所有子心流的ID列表""" - return await get_all_subheartflow_ids() - - -@router.post("/gui/subheartflow/forced_change_status") -async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): # noqa - """强制改变子心流的状态""" - # 参数检查 - if not isinstance(status, ChatState): - logger.warning(f"无效的状态参数: {status}") - return {"status": "failed", "reason": "invalid status"} - logger.info(f"尝试将子心流 {subheartflow_id} 状态更改为 {status.value}") - success = await forced_change_subheartflow_status(subheartflow_id, status) - if success: - logger.info(f"子心流 {subheartflow_id} 状态更改为 {status.value} 成功") - return {"status": "success"} - else: - logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败") - return {"status": "failed"} - - -@router.get("/stop") -async def force_stop_maibot(): - """强制停止MAI Bot""" - from bot import request_shutdown - - success = await request_shutdown() - if success: - logger.info("MAI Bot已强制停止") - return {"status": "success"} - else: - logger.error("MAI Bot强制停止失败") - return {"status": "failed"} - - -@router.get("/gui/subheartflow/cycleinfo") -async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int): - """获取子心流的循环信息""" - cycle_info = await get_subheartflow_cycle_info(subheartflow_id, history_len) - if cycle_info: - return {"status": "success", "data": cycle_info} - else: - logger.warning(f"子心流 {subheartflow_id} 循环信息未找到") - return {"status": "failed", "reason": "subheartflow not found"} - - -@router.get("/gui/get_all_states") -async def get_all_states_api(): - """获取所有状态""" - all_states = await get_all_states() - if all_states: - return {"status": "success", "data": all_states} - else: - logger.warning("获取所有状态失败") - return {"status": "failed", "reason": "failed to get all states"} - - -@router.get("/info") -async def get_system_basic_info(): - """获取系统基本信息""" - logger.info("请求系统基本信息") - try: - info = get_all_basic_info() - return {"status": "success", "data": info} - except Exception as e: - logger.error(f"获取系统基本信息失败: {e}") - return {"status": "failed", "reason": str(e)} - - -def start_api_server(): - """启动API服务器""" - get_global_server().register_router(router, prefix="/api/v1") - # pass diff --git a/src/api/reload_config.py b/src/api/reload_config.py deleted file mode 100644 index 087c47e4f..000000000 --- a/src/api/reload_config.py +++ /dev/null @@ -1,24 +0,0 @@ -from fastapi import HTTPException -from rich.traceback import install -from src.config.config import get_config_dir, load_config -from src.common.logger import get_logger -import os - -install(extra_lines=3) - -logger = get_logger("api") - - -async def reload_config(): - try: - from src.config import config as config_module - - logger.debug("正在重载配置文件...") - bot_config_path = os.path.join(get_config_dir(), "bot_config.toml") - config_module.global_config = load_config(config_path=bot_config_path) - logger.debug("配置文件重载成功") - return {"status": "reloaded"} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e - except Exception as e: - raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 3511d938b..11fb0f62d 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -5,20 +5,19 @@ import os import random import time import traceback -from typing import Optional, Tuple, List, Any -from PIL import Image import io import re - -# from gradio_client import file +import binascii +from typing import Optional, Tuple, List, Any +from PIL import Image +from rich.traceback import install from src.common.database.database_model import Emoji from src.common.database.database import db as peewee_db +from src.common.logger import get_logger from src.config.config import global_config from src.chat.utils.utils_image import image_path_to_base64, get_image_manager from src.llm_models.utils_model import LLMRequest -from src.common.logger import get_logger -from rich.traceback import install install(extra_lines=3) @@ -26,7 +25,7 @@ logger = get_logger("emoji") BASE_DIR = os.path.join("data") EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录 -EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录 +EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录 MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中 """ @@ -85,7 +84,7 @@ class MaiEmoji: logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}") try: with Image.open(io.BytesIO(image_bytes)) as img: - self.format = img.format.lower() + self.format = img.format.lower() # type: ignore logger.debug(f"[初始化] 格式获取成功: {self.format}") except Exception as pil_error: logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}") @@ -100,7 +99,7 @@ class MaiEmoji: logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}") self.is_deleted = True return None - except base64.binascii.Error as b64_error: + except (binascii.Error, ValueError) as b64_error: logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}") self.is_deleted = True return None @@ -113,7 +112,7 @@ class MaiEmoji: async def register_to_db(self) -> bool: """ 注册表情包 - 将表情包对应的文件,从当前路径移动到EMOJI_REGISTED_DIR目录下 + 将表情包对应的文件,从当前路径移动到EMOJI_REGISTERED_DIR目录下 并修改对应的实例属性,然后将表情包信息保存到数据库中 """ try: @@ -122,7 +121,7 @@ class MaiEmoji: # 源路径是当前实例的完整路径 self.full_path source_full_path = self.full_path # 目标完整路径 - destination_full_path = os.path.join(EMOJI_REGISTED_DIR, self.filename) + destination_full_path = os.path.join(EMOJI_REGISTERED_DIR, self.filename) # 检查源文件是否存在 if not os.path.exists(source_full_path): @@ -139,7 +138,7 @@ class MaiEmoji: logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}") # 更新实例的路径属性为新路径 self.full_path = destination_full_path - self.path = EMOJI_REGISTED_DIR + self.path = EMOJI_REGISTERED_DIR # self.filename 保持不变 except Exception as move_error: logger.error(f"[错误] 移动文件失败: {str(move_error)}") @@ -202,7 +201,7 @@ class MaiEmoji: try: will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash) result = will_delete_emoji.delete_instance() # Returns the number of rows deleted. - except Emoji.DoesNotExist: + except Emoji.DoesNotExist: # type: ignore logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") result = 0 # Indicate no DB record was deleted @@ -298,7 +297,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]: def _ensure_emoji_dir() -> None: """确保表情存储目录存在""" os.makedirs(EMOJI_DIR, exist_ok=True) - os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True) + os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True) async def clear_temp_emoji() -> None: @@ -331,10 +330,10 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") return removed_count + cleaned_count = 0 try: # 获取内存中所有有效表情包的完整路径集合 tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted} - cleaned_count = 0 # 遍历指定目录中的所有文件 for file_name in os.listdir(emoji_dir): @@ -358,11 +357,11 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r else: logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。") - return removed_count + cleaned_count - except Exception as e: logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}") + return removed_count + cleaned_count + class EmojiManager: _instance = None @@ -414,7 +413,7 @@ class EmojiManager: emoji_update.usage_count += 1 emoji_update.last_used_time = time.time() # Update last used time emoji_update.save() # Persist changes to DB - except Emoji.DoesNotExist: + except Emoji.DoesNotExist: # type: ignore logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") @@ -570,8 +569,8 @@ class EmojiManager: if objects_to_remove: self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove] - # 清理 EMOJI_REGISTED_DIR 目录中未被追踪的文件 - removed_count = await clean_unused_emojis(EMOJI_REGISTED_DIR, self.emoji_objects, removed_count) + # 清理 EMOJI_REGISTERED_DIR 目录中未被追踪的文件 + removed_count = await clean_unused_emojis(EMOJI_REGISTERED_DIR, self.emoji_objects, removed_count) # 输出清理结果 if removed_count > 0: @@ -850,11 +849,13 @@ class EmojiManager: if isinstance(image_base64, str): image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_bytes = base64.b64decode(image_base64) - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore # 调用AI获取描述 if image_format == "gif" or image_format == "GIF": - image_base64 = get_image_manager().transform_gif(image_base64) + image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore + if not image_base64: + raise RuntimeError("GIF表情包转换失败") prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg") else: diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index b85f53b79..0b1eaef7a 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -1,14 +1,16 @@ -from .exprssion_learner import get_expression_learner -import random -from typing import List, Dict, Tuple -from json_repair import repair_json import json import os import time +import random + +from typing import List, Dict, Tuple, Optional +from json_repair import repair_json + from src.llm_models.utils_model import LLMRequest from src.config.config import global_config from src.common.logger import get_logger from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from .exprssion_learner import get_expression_learner logger = get_logger("expression_selector") @@ -165,7 +167,12 @@ class ExpressionSelector: logger.error(f"批量更新表达方式count失败 for {file_path}: {e}") async def select_suitable_expressions_llm( - self, chat_id: str, chat_info: str, max_num: int = 10, min_num: int = 5, target_message: str = None + self, + chat_id: str, + chat_info: str, + max_num: int = 10, + min_num: int = 5, + target_message: Optional[str] = None, ) -> List[Dict[str, str]]: """使用LLM选择适合的表达方式""" diff --git a/src/chat/express/exprssion_learner.py b/src/chat/express/exprssion_learner.py index 9b170d9a3..738a88b95 100644 --- a/src/chat/express/exprssion_learner.py +++ b/src/chat/express/exprssion_learner.py @@ -1,14 +1,16 @@ import time import random +import json +import os + from typing import List, Dict, Optional, Any, Tuple + from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest from src.config.config import global_config from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -import os from src.chat.message_receive.chat_stream import get_chat_manager -import json MAX_EXPRESSION_COUNT = 300 @@ -74,7 +76,8 @@ class ExpressionLearner: ) self.llm_model = None - def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: + def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]: + # sourcery skip: extract-duplicate-method, remove-unnecessary-cast """ 获取指定chat_id的style和grammar表达方式 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作 @@ -119,10 +122,10 @@ class ExpressionLearner: min_len = min(len(s1), len(s2)) if min_len < 5: return False - same = sum(1 for a, b in zip(s1, s2) if a == b) + same = sum(a == b for a, b in zip(s1, s2)) return same / min_len > 0.8 - async def learn_and_store_expression(self) -> List[Tuple[str, str, str]]: + async def learn_and_store_expression(self) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]: """ 学习并存储表达方式,分别学习语言风格和句法特点 同时对所有已存储的表达方式进行全局衰减 @@ -158,12 +161,12 @@ class ExpressionLearner: for _ in range(3): learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25) if not learnt_style: - return [] + return [], [] for _ in range(1): learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10) if not learnt_grammar: - return [] + return [], [] return learnt_style, learnt_grammar @@ -214,6 +217,7 @@ class ExpressionLearner: return result async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]: + # sourcery skip: use-join """ 选择从当前到最近1小时内的随机num条消息,然后学习这些消息的表达方式 type: "style" or "grammar" @@ -249,7 +253,7 @@ class ExpressionLearner: return [] # 按chat_id分组 - chat_dict: Dict[str, List[Dict[str, str]]] = {} + chat_dict: Dict[str, List[Dict[str, Any]]] = {} for chat_id, situation, style in learnt_expressions: if chat_id not in chat_dict: chat_dict[chat_id] = [] diff --git a/src/chat/focus_chat/focus_loop_info.py b/src/chat/focus_chat/focus_loop_info.py index 342368df7..827c544a2 100644 --- a/src/chat/focus_chat/focus_loop_info.py +++ b/src/chat/focus_chat/focus_loop_info.py @@ -1,10 +1,10 @@ # 定义了来自外部世界的信息 # 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体 from datetime import datetime +from typing import List + from src.common.logger import get_logger from src.chat.focus_chat.hfc_utils import CycleDetail -from typing import List -# Import the new utility function logger = get_logger("loop_info") diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index 70cda57c6..05600c256 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -8,7 +8,7 @@ from rich.traceback import install from src.config.config import global_config from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.timer_calculator import Timer from src.chat.planner_actions.planner import ActionPlanner @@ -49,7 +49,9 @@ class HeartFChatting: """ # 基础属性 self.stream_id: str = chat_id # 聊天流ID - self.chat_stream = get_chat_manager().get_stream(self.stream_id) + self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore + if not self.chat_stream: + raise ValueError(f"无法找到聊天流: {self.stream_id}") self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]" self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id) @@ -171,7 +173,7 @@ class HeartFChatting: # 执行规划和处理阶段 try: async with self._get_cycle_context(): - thinking_id = "tid" + str(round(time.time(), 2)) + thinking_id = f"tid{str(round(time.time(), 2))}" self._current_cycle_detail.set_thinking_id(thinking_id) # 使用异步上下文管理器处理消息 @@ -245,7 +247,7 @@ class HeartFChatting: logger.info( f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考," - f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " + f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "") ) @@ -256,7 +258,7 @@ class HeartFChatting: cycle_performance_data = { "cycle_id": self._current_cycle_detail.cycle_id, "action_type": action_result.get("action_type", "unknown"), - "total_time": self._current_cycle_detail.end_time - self._current_cycle_detail.start_time, + "total_time": self._current_cycle_detail.end_time - self._current_cycle_detail.start_time, # type: ignore "step_times": cycle_timers.copy(), "reasoning": action_result.get("reasoning", ""), "success": self._current_cycle_detail.loop_action_info.get("action_taken", False), @@ -447,11 +449,8 @@ class HeartFChatting: # 处理动作并获取结果 result = await action_handler.handle_action() - if len(result) == 3: - success, reply_text, command = result - else: - success, reply_text = result - command = "" + success, reply_text = result + command = "" # 检查action_data中是否有系统命令,优先使用系统命令 if "_system_command" in action_data: @@ -478,15 +477,14 @@ class HeartFChatting: ) # 设置系统命令,在下次循环检查时触发退出 command = "stop_focus_chat" - else: - if reply_text == "timeout": - self.reply_timeout_count += 1 - if self.reply_timeout_count > 5: - logger.warning( - f"[{self.log_prefix} ] 连续回复超时次数过多,{global_config.chat.thinking_timeout}秒 内大模型没有返回有效内容,请检查你的api是否速度过慢或配置错误。建议不要使用推理模型,推理模型生成速度过慢。或者尝试拉高thinking_timeout参数,这可能导致回复时间过长。" - ) - logger.warning(f"{self.log_prefix} 回复生成超时{global_config.chat.thinking_timeout}s,已跳过") - return False, "", "" + elif reply_text == "timeout": + self.reply_timeout_count += 1 + if self.reply_timeout_count > 5: + logger.warning( + f"[{self.log_prefix} ] 连续回复超时次数过多,{global_config.chat.thinking_timeout}秒 内大模型没有返回有效内容,请检查你的api是否速度过慢或配置错误。建议不要使用推理模型,推理模型生成速度过慢。或者尝试拉高thinking_timeout参数,这可能导致回复时间过长。" + ) + logger.warning(f"{self.log_prefix} 回复生成超时{global_config.chat.thinking_timeout}s,已跳过") + return False, "", "" return success, reply_text, command diff --git a/src/chat/focus_chat/hfc_performance_logger.py b/src/chat/focus_chat/hfc_performance_logger.py index 64e65ff85..702a8445f 100644 --- a/src/chat/focus_chat/hfc_performance_logger.py +++ b/src/chat/focus_chat/hfc_performance_logger.py @@ -2,6 +2,7 @@ import json from datetime import datetime from typing import Dict, Any from pathlib import Path + from src.common.logger import get_logger logger = get_logger("hfc_performance") diff --git a/src/chat/focus_chat/hfc_utils.py b/src/chat/focus_chat/hfc_utils.py index 11b04c801..0393c2175 100644 --- a/src/chat/focus_chat/hfc_utils.py +++ b/src/chat/focus_chat/hfc_utils.py @@ -1,11 +1,12 @@ import time -from typing import Optional +import json + +from typing import Optional, Dict, Any + from src.chat.message_receive.message import MessageRecv, BaseMessageInfo from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.message import UserInfo from src.common.logger import get_logger -import json -from typing import Dict, Any logger = get_logger(__name__) @@ -117,7 +118,7 @@ async def create_empty_anchor_message( placeholder_msg_info = BaseMessageInfo( message_id=placeholder_id, platform=platform, - group_info=group_info, + group_info=group_info, # type: ignore user_info=placeholder_user, time=time.time(), ) diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py index fdcfba6a3..4c5285259 100644 --- a/src/chat/heart_flow/heartflow.py +++ b/src/chat/heart_flow/heartflow.py @@ -1,7 +1,7 @@ -from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState +from typing import Any, Optional, Dict + from src.common.logger import get_logger -from typing import Any, Optional -from typing import Dict +from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState from src.chat.message_receive.chat_stream import get_chat_manager logger = get_logger("heartflow") @@ -34,7 +34,7 @@ class Heartflow: logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True) return None - async def force_change_subheartflow_status(self, subheartflow_id: str, status: ChatState) -> None: + async def force_change_subheartflow_status(self, subheartflow_id: str, status: ChatState) -> bool: """强制改变子心流的状态""" # 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据 return await self.force_change_state(subheartflow_id, status) diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index d01775168..aa8bfdbf0 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -1,21 +1,21 @@ -from src.chat.memory_system.Hippocampus import hippocampus_manager -from src.config.config import global_config import asyncio +import re +import math +import traceback + +from typing import Tuple + +from src.config.config import global_config +from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.storage import MessageStorage from src.chat.heart_flow.heartflow import heartflow from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.timer_calculator import Timer from src.common.logger import get_logger -import re -import math -import traceback -from typing import Tuple - from src.person_info.relationship_manager import get_relationship_manager from src.mood.mood_manager import mood_manager - logger = get_logger("chat") @@ -26,16 +26,16 @@ async def _process_relationship(message: MessageRecv) -> None: message: 消息对象,包含用户信息 """ platform = message.message_info.platform - user_id = message.message_info.user_info.user_id - nickname = message.message_info.user_info.user_nickname - cardname = message.message_info.user_info.user_cardname or nickname + user_id = message.message_info.user_info.user_id # type: ignore + nickname = message.message_info.user_info.user_nickname # type: ignore + cardname = message.message_info.user_info.user_cardname or nickname # type: ignore relationship_manager = get_relationship_manager() is_known = await relationship_manager.is_known_some_one(platform, user_id) if not is_known: logger.info(f"首次认识用户: {nickname}") - await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) + await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: @@ -105,9 +105,9 @@ class HeartFCMessageReceiver: # 2. 兴趣度计算与更新 interested_rate, is_mentioned = await _calculate_interest(message) - subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) + subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) # type: ignore - chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) + chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) # type: ignore asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate)) # 3. 日志记录 @@ -119,7 +119,7 @@ class HeartFCMessageReceiver: picid_pattern = r"\[picid:([^\]]+)\]" processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text) - logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") + logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]") diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py index 9f6a49895..fc230e255 100644 --- a/src/chat/heart_flow/sub_heartflow.py +++ b/src/chat/heart_flow/sub_heartflow.py @@ -1,16 +1,18 @@ import asyncio import time -from typing import Optional, List, Dict, Tuple import traceback + +from typing import Optional, List, Dict, Tuple +from rich.traceback import install + from src.common.logger import get_logger +from src.config.config import global_config from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.focus_chat.heartFC_chat import HeartFChatting from src.chat.normal_chat.normal_chat import NormalChat from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo from src.chat.utils.utils import get_chat_type_and_target_info -from src.config.config import global_config -from rich.traceback import install logger = get_logger("sub_heartflow") @@ -40,7 +42,7 @@ class SubHeartflow: self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id) self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id # 兴趣消息集合 - self.interest_dict: Dict[str, tuple[MessageRecv, float, bool]] = {} + self.interest_dict: Dict[str, Tuple[MessageRecv, float, bool]] = {} # focus模式退出冷却时间管理 self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间 @@ -297,7 +299,7 @@ class SubHeartflow: ) def add_message_to_normal_chat_cache(self, message: MessageRecv, interest_value: float, is_mentioned: bool): - self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned) + self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned) # type: ignore # 如果字典长度超过10,删除最旧的消息 if len(self.interest_dict) > 30: oldest_key = next(iter(self.interest_dict)) diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 29a26f64c..a3ee46a7a 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -42,7 +42,7 @@ def calculate_information_content(text): return entropy -def cosine_similarity(v1, v2): +def cosine_similarity(v1, v2): # sourcery skip: assign-if-exp, reintroduce-else """计算余弦相似度""" dot_product = np.dot(v1, v2) norm1 = np.linalg.norm(v1) @@ -89,14 +89,13 @@ class MemoryGraph: if not isinstance(self.G.nodes[concept]["memory_items"], list): self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]] self.G.nodes[concept]["memory_items"].append(memory) - # 更新最后修改时间 - self.G.nodes[concept]["last_modified"] = current_time else: self.G.nodes[concept]["memory_items"] = [memory] # 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time if "created_time" not in self.G.nodes[concept]: self.G.nodes[concept]["created_time"] = current_time - self.G.nodes[concept]["last_modified"] = current_time + # 更新最后修改时间 + self.G.nodes[concept]["last_modified"] = current_time else: # 如果是新节点,创建新的记忆列表 self.G.add_node( @@ -108,11 +107,7 @@ class MemoryGraph: def get_dot(self, concept): # 检查节点是否存在于图中 - if concept in self.G: - # 从图中获取节点数据 - node_data = self.G.nodes[concept] - return concept, node_data - return None + return (concept, self.G.nodes[concept]) if concept in self.G else None def get_related_item(self, topic, depth=1): if topic not in self.G: @@ -139,8 +134,7 @@ class MemoryGraph: if depth >= 2: # 获取相邻节点的记忆项 for neighbor in neighbors: - node_data = self.get_dot(neighbor) - if node_data: + if node_data := self.get_dot(neighbor): concept, data = node_data if "memory_items" in data: memory_items = data["memory_items"] @@ -194,9 +188,9 @@ class MemoryGraph: class Hippocampus: def __init__(self): self.memory_graph = MemoryGraph() - self.model_summary = None - self.entorhinal_cortex = None - self.parahippocampal_gyrus = None + self.model_summary: LLMRequest = None # type: ignore + self.entorhinal_cortex: EntorhinalCortex = None # type: ignore + self.parahippocampal_gyrus: ParahippocampalGyrus = None # type: ignore def initialize(self): # 初始化子组件 @@ -218,7 +212,7 @@ class Hippocampus: memory_items = [memory_items] if memory_items else [] # 使用集合来去重,避免排序 - unique_items = set(str(item) for item in memory_items) + unique_items = {str(item) for item in memory_items} # 使用frozenset来保证顺序一致性 content = f"{concept}:{frozenset(unique_items)}" return hash(content) @@ -231,6 +225,7 @@ class Hippocampus: @staticmethod def find_topic_llm(text, topic_num): + # sourcery skip: inline-immediately-returned-variable prompt = ( f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来," f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。" @@ -240,6 +235,7 @@ class Hippocampus: @staticmethod def topic_what(text, topic): + # sourcery skip: inline-immediately-returned-variable # 不再需要 time_info 参数 prompt = ( f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,' @@ -480,9 +476,7 @@ class Hippocampus: top_memories = memory_similarities[:max_memory_length] # 添加到结果中 - for memory, similarity in top_memories: - all_memories.append((node, [memory], similarity)) - # logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})") + all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories) else: logger.info("节点没有记忆") @@ -646,9 +640,7 @@ class Hippocampus: top_memories = memory_similarities[:max_memory_length] # 添加到结果中 - for memory, similarity in top_memories: - all_memories.append((node, [memory], similarity)) - # logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})") + all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories) else: logger.info("节点没有记忆") @@ -823,11 +815,11 @@ class EntorhinalCortex: logger.debug(f"回忆往事: {readable_timestamp}") chat_samples = [] for timestamp in timestamps: - # 调用修改后的 random_get_msg_snippet - messages = self.random_get_msg_snippet( - timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg - ) - if messages: + if messages := self.random_get_msg_snippet( + timestamp, + global_config.memory.memory_build_sample_length, + max_memorized_time_per_msg, + ): time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600 logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条") chat_samples.append(messages) @@ -838,6 +830,7 @@ class EntorhinalCortex: @staticmethod def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None: + # sourcery skip: invert-any-all, use-any, use-named-expression, use-next """从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)""" try_count = 0 time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟 @@ -847,22 +840,21 @@ class EntorhinalCortex: timestamp_start = target_timestamp timestamp_end = target_timestamp + time_window_seconds - chosen_message = get_raw_msg_by_timestamp( - timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest" - ) + if chosen_message := get_raw_msg_by_timestamp( + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + limit=1, + limit_mode="earliest", + ): + chat_id: str = chosen_message[0].get("chat_id") # type: ignore - if chosen_message: - chat_id = chosen_message[0].get("chat_id") - - messages = get_raw_msg_by_timestamp_with_chat( + if messages := get_raw_msg_by_timestamp_with_chat( timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest", chat_id=chat_id, - ) - - if messages: + ): # 检查获取到的所有消息是否都未达到最大记忆次数 all_valid = True for message in messages: @@ -975,7 +967,7 @@ class EntorhinalCortex: ).execute() if nodes_to_delete: - GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() + GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore # 处理边的信息 db_edges = list(GraphEdges.select()) @@ -1114,7 +1106,7 @@ class EntorhinalCortex: node_start = time.time() if nodes_data: batch_size = 500 # 增加批量大小 - with GraphNodes._meta.database.atomic(): + with GraphNodes._meta.database.atomic(): # type: ignore for i in range(0, len(nodes_data), batch_size): batch = nodes_data[i : i + batch_size] GraphNodes.insert_many(batch).execute() @@ -1125,7 +1117,7 @@ class EntorhinalCortex: edge_start = time.time() if edges_data: batch_size = 500 # 增加批量大小 - with GraphEdges._meta.database.atomic(): + with GraphEdges._meta.database.atomic(): # type: ignore for i in range(0, len(edges_data), batch_size): batch = edges_data[i : i + batch_size] GraphEdges.insert_many(batch).execute() @@ -1489,32 +1481,30 @@ class ParahippocampalGyrus: # --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 --- last_modified = node_data.get("last_modified", current_time) # 条件1:检查是否长时间未修改 (超过24小时) - if current_time - last_modified > 3600 * 24: - # 条件2:再次确认节点包含记忆项(理论上已确认,但作为保险) - if memory_items: - current_count = len(memory_items) - # 如果列表非空,才进行随机选择 - if current_count > 0: - removed_item = random.choice(memory_items) - try: - memory_items.remove(removed_item) + if current_time - last_modified > 3600 * 24 and memory_items: + current_count = len(memory_items) + # 如果列表非空,才进行随机选择 + if current_count > 0: + removed_item = random.choice(memory_items) + try: + memory_items.remove(removed_item) - # 条件3:检查移除后 memory_items 是否变空 - if memory_items: # 如果移除后列表不为空 - # self.memory_graph.G.nodes[node]["memory_items"] = memory_items # 直接修改列表即可 - self.memory_graph.G.nodes[node]["last_modified"] = current_time # 更新修改时间 - node_changes["reduced"].append(f"{node} (数量: {current_count} -> {len(memory_items)})") - else: # 如果移除后列表为空 - # 尝试移除节点,处理可能的错误 - try: - self.memory_graph.G.remove_node(node) - node_changes["removed"].append(f"{node}(遗忘清空)") # 标记为遗忘清空 - logger.debug(f"[遗忘] 节点 {node} 因移除最后一项而被清空。") - except nx.NetworkXError as e: - logger.warning(f"[遗忘] 尝试移除节点 {node} 时发生错误(可能已被移除):{e}") - except ValueError: - # 这个错误理论上不应发生,因为 removed_item 来自 memory_items - logger.warning(f"[遗忘] 尝试从节点 '{node}' 移除不存在的项目 '{removed_item[:30]}...'") + # 条件3:检查移除后 memory_items 是否变空 + if memory_items: # 如果移除后列表不为空 + # self.memory_graph.G.nodes[node]["memory_items"] = memory_items # 直接修改列表即可 + self.memory_graph.G.nodes[node]["last_modified"] = current_time # 更新修改时间 + node_changes["reduced"].append(f"{node} (数量: {current_count} -> {len(memory_items)})") + else: # 如果移除后列表为空 + # 尝试移除节点,处理可能的错误 + try: + self.memory_graph.G.remove_node(node) + node_changes["removed"].append(f"{node}(遗忘清空)") # 标记为遗忘清空 + logger.debug(f"[遗忘] 节点 {node} 因移除最后一项而被清空。") + except nx.NetworkXError as e: + logger.warning(f"[遗忘] 尝试移除节点 {node} 时发生错误(可能已被移除):{e}") + except ValueError: + # 这个错误理论上不应发生,因为 removed_item 来自 memory_items + logger.warning(f"[遗忘] 尝试从节点 '{node}' 移除不存在的项目 '{removed_item[:30]}...'") node_check_end = time.time() logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}秒") @@ -1669,7 +1659,7 @@ class ParahippocampalGyrus: class HippocampusManager: def __init__(self): - self._hippocampus = None + self._hippocampus: Hippocampus = None # type: ignore self._initialized = False def initialize(self): diff --git a/src/chat/memory_system/memory_activator.py b/src/chat/memory_system/memory_activator.py index 560fe01a6..66ff89755 100644 --- a/src/chat/memory_system/memory_activator.py +++ b/src/chat/memory_system/memory_activator.py @@ -13,7 +13,7 @@ from json_repair import repair_json logger = get_logger("memory_activator") -def get_keywords_from_json(json_str): +def get_keywords_from_json(json_str) -> List: """ 从JSON字符串中提取关键词列表 @@ -28,15 +28,8 @@ def get_keywords_from_json(json_str): fixed_json = repair_json(json_str) # 如果repair_json返回的是字符串,需要解析为Python对象 - if isinstance(fixed_json, str): - result = json.loads(fixed_json) - else: - # 如果repair_json直接返回了字典对象,直接使用 - result = fixed_json - - # 提取关键词 - keywords = result.get("keywords", []) - return keywords + result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json + return result.get("keywords", []) except Exception as e: logger.error(f"解析关键词JSON失败: {e}") return [] diff --git a/src/chat/memory_system/sample_distribution.py b/src/chat/memory_system/sample_distribution.py index b3b84eb4c..69f23a770 100644 --- a/src/chat/memory_system/sample_distribution.py +++ b/src/chat/memory_system/sample_distribution.py @@ -1,52 +1,10 @@ import numpy as np -from scipy import stats from datetime import datetime, timedelta from rich.traceback import install install(extra_lines=3) -class DistributionVisualizer: - def __init__(self, mean=0, std=1, skewness=0, sample_size=10): - """ - 初始化分布可视化器 - - 参数: - mean (float): 期望均值 - std (float): 标准差 - skewness (float): 偏度 - sample_size (int): 样本大小 - """ - self.mean = mean - self.std = std - self.skewness = skewness - self.sample_size = sample_size - self.samples = None - - def generate_samples(self): - """生成具有指定参数的样本""" - if self.skewness == 0: - # 对于无偏度的情况,直接使用正态分布 - self.samples = np.random.normal(loc=self.mean, scale=self.std, size=self.sample_size) - else: - # 使用 scipy.stats 生成具有偏度的分布 - self.samples = stats.skewnorm.rvs(a=self.skewness, loc=self.mean, scale=self.std, size=self.sample_size) - - def get_weighted_samples(self): - """获取加权后的样本数列""" - if self.samples is None: - self.generate_samples() - # 将样本值乘以样本大小 - return self.samples * self.sample_size - - def get_statistics(self): - """获取分布的统计信息""" - if self.samples is None: - self.generate_samples() - - return {"均值": np.mean(self.samples), "标准差": np.std(self.samples), "实际偏度": stats.skew(self.samples)} - - class MemoryBuildScheduler: def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50): """ diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index b460ad99b..3d1f1e341 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -1,23 +1,25 @@ import traceback import os +import re + from typing import Dict, Any +from maim_message import UserInfo from src.common.logger import get_logger +from src.config.config import global_config from src.mood.mood_manager import mood_manager # 导入情绪管理器 -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream from src.chat.message_receive.message import MessageRecv -from src.experimental.only_message_process import MessageProcessor from src.chat.message_receive.storage import MessageStorage -from src.experimental.PFC.pfc_manager import PFCManager from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.config.config import global_config +from src.experimental.only_message_process import MessageProcessor +from src.experimental.PFC.pfc_manager import PFCManager from src.plugin_system.core.component_registry import component_registry # 导入新插件系统 from src.plugin_system.base.base_command import BaseCommand from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor -from maim_message import UserInfo -from src.chat.message_receive.chat_stream import ChatStream -import re + + # 定义日志配置 # 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录) @@ -184,8 +186,8 @@ class ChatBot: get_chat_manager().register_message(message) chat = await get_chat_manager().get_or_create_stream( - platform=message.message_info.platform, - user_info=user_info, + platform=message.message_info.platform, # type: ignore + user_info=user_info, # type: ignore group_info=group_info, ) @@ -195,8 +197,10 @@ class ChatBot: await message.process() # 过滤检查 - if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( - message.raw_message, chat, user_info + if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore + message.raw_message, # type: ignore + chat, + user_info, # type: ignore ): return diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 355cca1e6..8b71314a6 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -3,18 +3,17 @@ import hashlib import time import copy from typing import Dict, Optional, TYPE_CHECKING - - -from ...common.database.database import db -from ...common.database.database_model import ChatStreams # 新增导入 +from rich.traceback import install from maim_message import GroupInfo, UserInfo +from src.common.logger import get_logger +from src.common.database.database import db +from src.common.database.database_model import ChatStreams # 新增导入 + # 避免循环导入,使用TYPE_CHECKING进行类型提示 if TYPE_CHECKING: from .message import MessageRecv -from src.common.logger import get_logger -from rich.traceback import install install(extra_lines=3) @@ -28,7 +27,7 @@ class ChatMessageContext: def __init__(self, message: "MessageRecv"): self.message = message - def get_template_name(self) -> str: + def get_template_name(self) -> Optional[str]: """获取模板名称""" if self.message.message_info.template_info and not self.message.message_info.template_info.template_default: return self.message.message_info.template_info.template_name @@ -41,10 +40,10 @@ class ChatMessageContext: def check_types(self, types: list) -> bool: # sourcery skip: invert-any-all, use-any, use-next """检查消息类型""" - if not self.message.message_info.format_info.accept_format: + if not self.message.message_info.format_info.accept_format: # type: ignore return False for t in types: - if t not in self.message.message_info.format_info.accept_format: + if t not in self.message.message_info.format_info.accept_format: # type: ignore return False return True @@ -68,7 +67,7 @@ class ChatStream: platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None, - data: dict = None, + data: Optional[dict] = None, ): self.stream_id = stream_id self.platform = platform @@ -77,7 +76,7 @@ class ChatStream: self.create_time = data.get("create_time", time.time()) if data else time.time() self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time self.saved = False - self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息 + self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息 def to_dict(self) -> dict: """转换为字典格式""" @@ -99,7 +98,7 @@ class ChatStream: return cls( stream_id=data["stream_id"], platform=data["platform"], - user_info=user_info, + user_info=user_info, # type: ignore group_info=group_info, data=data, ) @@ -163,8 +162,8 @@ class ChatManager: def register_message(self, message: "MessageRecv"): """注册消息到聊天流""" stream_id = self._generate_stream_id( - message.message_info.platform, - message.message_info.user_info, + message.message_info.platform, # type: ignore + message.message_info.user_info, # type: ignore message.message_info.group_info, ) self.last_messages[stream_id] = message @@ -185,10 +184,7 @@ class ChatManager: def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str: """获取聊天流ID""" - if is_group: - components = [platform, str(id)] - else: - components = [platform, str(id), "private"] + components = [platform, id] if is_group else [platform, id, "private"] key = "_".join(components) return hashlib.md5(key.encode()).hexdigest() diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 7575e0e53..f8d917574 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,17 +1,15 @@ import time -from abc import abstractmethod -from dataclasses import dataclass -from typing import Optional, Any, TYPE_CHECKING - import urllib3 -from src.common.logger import get_logger - -if TYPE_CHECKING: - from .chat_stream import ChatStream -from ..utils.utils_image import get_image_manager -from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase +from abc import abstractmethod +from dataclasses import dataclass from rich.traceback import install +from typing import Optional, Any +from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase + +from src.common.logger import get_logger +from src.chat.utils.utils_image import get_image_manager +from .chat_stream import ChatStream install(extra_lines=3) @@ -27,7 +25,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @dataclass class Message(MessageBase): - chat_stream: "ChatStream" = None + chat_stream: "ChatStream" = None # type: ignore reply: Optional["Message"] = None processed_plain_text: str = "" memorized_times: int = 0 @@ -55,7 +53,7 @@ class Message(MessageBase): ) # 调用父类初始化 - super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) + super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore self.chat_stream = chat_stream # 文本处理相关属性 @@ -66,6 +64,7 @@ class Message(MessageBase): self.reply = reply async def _process_message_segments(self, segment: Seg) -> str: + # sourcery skip: remove-unnecessary-else, swap-if-else-branches """递归处理消息段,转换为文字描述 Args: @@ -78,13 +77,13 @@ class Message(MessageBase): # 处理消息段列表 segments_text = [] for seg in segment.data: - processed = await self._process_message_segments(seg) + processed = await self._process_message_segments(seg) # type: ignore if processed: segments_text.append(processed) return " ".join(segments_text) else: # 处理单个消息段 - return await self._process_single_segment(segment) + return await self._process_single_segment(segment) # type: ignore @abstractmethod async def _process_single_segment(self, segment): @@ -138,7 +137,7 @@ class MessageRecv(Message): if segment.type == "text": self.is_picid = False self.is_emoji = False - return segment.data + return segment.data # type: ignore elif segment.type == "image": # 如果是base64图片数据 if isinstance(segment.data, str): @@ -160,7 +159,7 @@ class MessageRecv(Message): elif segment.type == "mention_bot": self.is_picid = False self.is_emoji = False - self.is_mentioned = float(segment.data) + self.is_mentioned = float(segment.data) # type: ignore return "" elif segment.type == "priority_info": self.is_picid = False @@ -186,7 +185,7 @@ class MessageRecv(Message): """生成详细文本,包含时间和用户信息""" timestamp = self.message_info.time user_info = self.message_info.user_info - name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" + name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore return f"[{timestamp}] {name}: {self.processed_plain_text}\n" @@ -234,7 +233,7 @@ class MessageProcessBase(Message): """ try: if seg.type == "text": - return seg.data + return seg.data # type: ignore elif seg.type == "image": # 如果是base64图片数据 if isinstance(seg.data, str): @@ -250,7 +249,7 @@ class MessageProcessBase(Message): if self.reply and hasattr(self.reply, "processed_plain_text"): # print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}") # print(f"reply: {self.reply}") - return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" + return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore return None else: return f"[{seg.type}:{str(seg.data)}]" @@ -264,7 +263,7 @@ class MessageProcessBase(Message): timestamp = self.message_info.time user_info = self.message_info.user_info - name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" + name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n" @@ -313,7 +312,7 @@ class MessageSending(MessageProcessBase): is_emoji: bool = False, thinking_start_time: float = 0, apply_set_reply_logic: bool = False, - reply_to: str = None, + reply_to: str = None, # type: ignore ): # 调用父类初始化 super().__init__( @@ -344,7 +343,7 @@ class MessageSending(MessageProcessBase): self.message_segment = Seg( type="seglist", data=[ - Seg(type="reply", data=self.reply.message_info.message_id), + Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore self.message_segment, ], ) @@ -364,10 +363,10 @@ class MessageSending(MessageProcessBase): ) -> "MessageSending": """从思考状态消息创建发送状态消息""" return cls( - message_id=thinking.message_info.message_id, + message_id=thinking.message_info.message_id, # type: ignore chat_stream=thinking.chat_stream, message_segment=message_segment, - bot_user_info=thinking.message_info.user_info, + bot_user_info=thinking.message_info.user_info, # type: ignore reply=thinking.reply, is_head=is_head, is_emoji=is_emoji, @@ -399,13 +398,11 @@ class MessageSet: if not isinstance(message, MessageSending): raise TypeError("MessageSet只能添加MessageSending类型的消息") self.messages.append(message) - self.messages.sort(key=lambda x: x.message_info.time) + self.messages.sort(key=lambda x: x.message_info.time) # type: ignore def get_message_by_index(self, index: int) -> Optional[MessageSending]: """通过索引获取消息""" - if 0 <= index < len(self.messages): - return self.messages[index] - return None + return self.messages[index] if 0 <= index < len(self.messages) else None def get_message_by_time(self, target_time: float) -> Optional[MessageSending]: """获取最接近指定时间的消息""" @@ -415,7 +412,7 @@ class MessageSet: left, right = 0, len(self.messages) - 1 while left < right: mid = (left + right) // 2 - if self.messages[mid].message_info.time < target_time: + if self.messages[mid].message_info.time < target_time: # type: ignore left = mid + 1 else: right = mid diff --git a/src/chat/message_receive/normal_message_sender.py b/src/chat/message_receive/normal_message_sender.py index aa6721db3..95d296473 100644 --- a/src/chat/message_receive/normal_message_sender.py +++ b/src/chat/message_receive/normal_message_sender.py @@ -1,21 +1,16 @@ -# src/plugins/chat/message_sender.py import asyncio import time from asyncio import Task from typing import Union -from src.common.message.api import get_global_api - -# from ...common.database import db # 数据库依赖似乎不需要了,注释掉 -from .message import MessageSending, MessageThinking, MessageSet - -from src.chat.message_receive.storage import MessageStorage -from ..utils.utils import truncate_message, calculate_typing_time, count_messages_between - -from src.common.logger import get_logger from rich.traceback import install -install(extra_lines=3) +from src.common.logger import get_logger +from src.common.message.api import get_global_api +from src.chat.message_receive.storage import MessageStorage +from src.chat.utils.utils import truncate_message, calculate_typing_time, count_messages_between +from .message import MessageSending, MessageThinking, MessageSet +install(extra_lines=3) logger = get_logger("sender") @@ -79,9 +74,10 @@ class MessageContainer: def count_thinking_messages(self) -> int: """计算当前容器中思考消息的数量""" - return sum(1 for msg in self.messages if isinstance(msg, MessageThinking)) + return sum(isinstance(msg, MessageThinking) for msg in self.messages) def get_timeout_sending_messages(self) -> list[MessageSending]: + # sourcery skip: merge-nested-ifs """获取所有超时的MessageSending对象(思考时间超过20秒),按thinking_start_time排序 - 从旧 sender 合并""" current_time = time.time() timeout_messages = [] @@ -230,9 +226,7 @@ class MessageManager: f"[{message.chat_stream.stream_id}] 处理发送消息 {getattr(message.message_info, 'message_id', 'N/A')} 时出错: {e}" ) logger.exception("详细错误信息:") - # 考虑是否移除出错的消息,防止无限循环 - removed = container.remove_message(message) - if removed: + if container.remove_message(message): logger.warning(f"[{message.chat_stream.stream_id}] 已移除处理出错的消息。") async def _process_chat_messages(self, chat_id: str): @@ -261,10 +255,7 @@ class MessageManager: # --- 处理发送消息 --- await self._handle_sending_message(container, message_earliest) - # --- 处理超时发送消息 (来自旧 sender) --- - # 在处理完最早的消息后,检查是否有超时的发送消息 - timeout_sending_messages = container.get_timeout_sending_messages() - if timeout_sending_messages: + if timeout_sending_messages := container.get_timeout_sending_messages(): logger.debug(f"[{chat_id}] 发现 {len(timeout_sending_messages)} 条超时的发送消息") for msg in timeout_sending_messages: # 确保不是刚刚处理过的最早消息 (虽然理论上应该已被移除,但以防万一) @@ -274,6 +265,7 @@ class MessageManager: await self._handle_sending_message(container, msg) # 复用处理逻辑 async def _start_processor_loop(self): + # sourcery skip: list-comprehension, move-assign-in-block, use-named-expression """消息处理器主循环""" while self._running: tasks = [] @@ -282,10 +274,7 @@ class MessageManager: # 创建 keys 的快照以安全迭代 chat_ids = list(self.containers.keys()) - for chat_id in chat_ids: - # 为每个 chat_id 创建一个处理任务 - tasks.append(asyncio.create_task(self._process_chat_messages(chat_id))) - + tasks.extend(asyncio.create_task(self._process_chat_messages(chat_id)) for chat_id in chat_ids) if tasks: try: # 等待当前批次的所有任务完成 diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index c40c4eb75..d5fc7b514 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,11 +1,10 @@ import re from typing import Union -# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance -from .message import MessageSending, MessageRecv -from .chat_stream import ChatStream -from ...common.database.database_model import Messages, RecalledMessages, Images # Import Peewee models +from src.common.database.database_model import Messages, RecalledMessages, Images from src.common.logger import get_logger +from .chat_stream import ChatStream +from .message import MessageSending, MessageRecv logger = get_logger("message_storage") @@ -44,7 +43,7 @@ class MessageStorage: reply_to = "" chat_info_dict = chat_stream.to_dict() - user_info_dict = message.message_info.user_info.to_dict() + user_info_dict = message.message_info.user_info.to_dict() # type: ignore # message_id 现在是 TextField,直接使用字符串值 msg_id = message.message_info.message_id @@ -56,7 +55,7 @@ class MessageStorage: Messages.create( message_id=msg_id, - time=float(message.message_info.time), + time=float(message.message_info.time), # type: ignore chat_id=chat_stream.stream_id, # Flattened chat_info reply_to=reply_to, @@ -103,7 +102,7 @@ class MessageStorage: try: # Assuming input 'time' is a string timestamp that can be converted to float current_time_float = float(time) - RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() + RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() # type: ignore except Exception: logger.exception("删除撤回消息失败") @@ -115,22 +114,19 @@ class MessageStorage: """更新最新一条匹配消息的message_id""" try: if message.message_segment.type == "notify": - mmc_message_id = message.message_segment.data.get("echo") - qq_message_id = message.message_segment.data.get("actual_id") + mmc_message_id = message.message_segment.data.get("echo") # type: ignore + qq_message_id = message.message_segment.data.get("actual_id") # type: ignore else: logger.info(f"更新消息ID错误,seg类型为{message.message_segment.type}") return if not qq_message_id: logger.info("消息不存在message_id,无法更新") return - # 查询最新一条匹配消息 - matched_message = ( + if matched_message := ( Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first() - ) - - if matched_message: + ): # 更新找到的消息记录 - Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() + Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") else: logger.debug("未找到匹配的消息") @@ -155,10 +151,7 @@ class MessageStorage: image_record = ( Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first() ) - if image_record: - return f"[picid:{image_record.image_id}]" - else: - return match.group(0) # 保持原样 + return f"[picid:{image_record.image_id}]" if image_record else match.group(0) except Exception: return match.group(0) diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 0efcf16d8..663bf23a8 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -1,16 +1,17 @@ import asyncio -from typing import Dict, Optional # 重新导入类型 -from src.chat.message_receive.message import MessageSending, MessageThinking -from src.common.message.api import get_global_api -from src.chat.message_receive.storage import MessageStorage -from src.chat.utils.utils import truncate_message -from src.common.logger import get_logger -from src.chat.utils.utils import calculate_typing_time -from rich.traceback import install import traceback -install(extra_lines=3) +from typing import Dict, Optional +from rich.traceback import install +from src.common.message.api import get_global_api +from src.common.logger import get_logger +from src.chat.message_receive.message import MessageSending, MessageThinking +from src.chat.message_receive.storage import MessageStorage +from src.chat.utils.utils import truncate_message +from src.chat.utils.utils import calculate_typing_time + +install(extra_lines=3) logger = get_logger("sender") @@ -86,10 +87,10 @@ class HeartFCSender: """ if not message.chat_stream: logger.error("消息缺少 chat_stream,无法发送") - raise Exception("消息缺少 chat_stream,无法发送") + raise ValueError("消息缺少 chat_stream,无法发送") if not message.message_info or not message.message_info.message_id: logger.error("消息缺少 message_info 或 message_id,无法发送") - raise Exception("消息缺少 message_info 或 message_id,无法发送") + raise ValueError("消息缺少 message_info 或 message_id,无法发送") chat_id = message.chat_stream.stream_id message_id = message.message_info.message_id diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 63e394c7c..414d607a1 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -1,6 +1,7 @@ import asyncio import time import traceback + from random import random from typing import List, Optional, Dict from maim_message import UserInfo, Seg @@ -40,7 +41,7 @@ class NormalChat: def __init__( self, chat_stream: ChatStream, - interest_dict: dict = None, + interest_dict: Optional[Dict] = None, on_switch_to_focus_callback=None, get_cooldown_progress_callback=None, ): @@ -147,10 +148,7 @@ class NormalChat: while not self._disabled: try: if not self.priority_manager.is_empty(): - # 获取最高优先级的消息 - message = self.priority_manager.get_highest_priority_message() - - if message: + if message := self.priority_manager.get_highest_priority_message(): logger.info( f"[{self.stream_name}] 从队列中取出消息进行处理: User {message.message_info.user_info.user_id}, Time: {time.strftime('%H:%M:%S', time.localtime(message.message_info.time))}" ) diff --git a/src/chat/normal_chat/priority_manager.py b/src/chat/normal_chat/priority_manager.py index 0296017ff..8c1c0e731 100644 --- a/src/chat/normal_chat/priority_manager.py +++ b/src/chat/normal_chat/priority_manager.py @@ -53,7 +53,7 @@ class PriorityManager: """ 添加新消息到合适的队列中。 """ - user_id = message.message_info.user_info.user_id + user_id = message.message_info.user_info.user_id # type: ignore is_vip = message.priority_info.get("message_type") == "vip" if message.priority_info else False message_priority = message.priority_info.get("message_priority", 0.0) if message.priority_info else 0.0 diff --git a/src/chat/normal_chat/willing/mode_classical.py b/src/chat/normal_chat/willing/mode_classical.py index 0b296bbf4..7539274c1 100644 --- a/src/chat/normal_chat/willing/mode_classical.py +++ b/src/chat/normal_chat/willing/mode_classical.py @@ -35,9 +35,7 @@ class ClassicalWillingManager(BaseWillingManager): self.chat_reply_willing[chat_id] = min(current_willing, 3.0) - reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1) - - return reply_probability + return min(max((current_willing - 0.5), 0.01) * 2, 1) async def before_generate_reply_handle(self, message_id): chat_id = self.ongoing_messages[message_id].chat_id diff --git a/src/chat/normal_chat/willing/willing_manager.py b/src/chat/normal_chat/willing/willing_manager.py index 0fa701f94..f797bc3e0 100644 --- a/src/chat/normal_chat/willing/willing_manager.py +++ b/src/chat/normal_chat/willing/willing_manager.py @@ -1,14 +1,16 @@ -from src.common.logger import get_logger +import importlib +import asyncio + +from abc import ABC, abstractmethod +from typing import Dict, Optional +from rich.traceback import install from dataclasses import dataclass + +from src.common.logger import get_logger from src.config.config import global_config from src.chat.message_receive.chat_stream import ChatStream, GroupInfo from src.chat.message_receive.message import MessageRecv from src.person_info.person_info import PersonInfoManager, get_person_info_manager -from abc import ABC, abstractmethod -import importlib -from typing import Dict, Optional -import asyncio -from rich.traceback import install install(extra_lines=3) @@ -92,8 +94,8 @@ class BaseWillingManager(ABC): self.logger = logger def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float): - person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) - self.ongoing_messages[message.message_info.message_id] = WillingInfo( + person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) # type: ignore + self.ongoing_messages[message.message_info.message_id] = WillingInfo( # type: ignore message=message, chat=chat, person_info_manager=get_person_info_manager(), diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 45bdfd72d..ed045436f 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -27,14 +27,11 @@ class ActionManager: # 当前正在使用的动作集合,默认加载默认动作 self._using_actions: Dict[str, ActionInfo] = {} - # 默认动作集,仅作为快照,用于恢复默认 - self._default_actions: Dict[str, ActionInfo] = {} - # 加载插件动作 self._load_plugin_actions() # 初始化时将默认动作加载到使用中的动作 - self._using_actions = self._default_actions.copy() + self._using_actions = component_registry.get_default_actions() def _load_plugin_actions(self) -> None: """ @@ -52,7 +49,7 @@ class ActionManager: """从插件系统的component_registry加载Action组件""" try: # 获取所有Action组件 - action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) + action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore for action_name, action_info in action_components.items(): if action_name in self._registered_actions: @@ -61,10 +58,6 @@ class ActionManager: self._registered_actions[action_name] = action_info - # 如果启用,也添加到默认动作集 - if action_info.enabled: - self._default_actions[action_name] = action_info - logger.debug( f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})" ) @@ -106,7 +99,9 @@ class ActionManager: """ try: # 获取组件类 - 明确指定查询Action类型 - component_class = component_registry.get_component_class(action_name, ComponentType.ACTION) + component_class: Type[BaseAction] = component_registry.get_component_class( + action_name, ComponentType.ACTION + ) # type: ignore if not component_class: logger.warning(f"{log_prefix} 未找到Action组件: {action_name}") return None @@ -146,10 +141,6 @@ class ActionManager: """获取所有已注册的动作集""" return self._registered_actions.copy() - def get_default_actions(self) -> Dict[str, ActionInfo]: - """获取默认动作集""" - return self._default_actions.copy() - def get_using_actions(self) -> Dict[str, ActionInfo]: """获取当前正在使用的动作集合""" return self._using_actions.copy() @@ -217,31 +208,31 @@ class ActionManager: logger.debug(f"已从使用集中移除动作 {action_name}") return True - def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool: - """ - 添加新的动作到注册集 + # def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool: + # """ + # 添加新的动作到注册集 - Args: - action_name: 动作名称 - description: 动作描述 - parameters: 动作参数定义,默认为空字典 - require: 动作依赖项,默认为空列表 + # Args: + # action_name: 动作名称 + # description: 动作描述 + # parameters: 动作参数定义,默认为空字典 + # require: 动作依赖项,默认为空列表 - Returns: - bool: 添加是否成功 - """ - if action_name in self._registered_actions: - return False + # Returns: + # bool: 添加是否成功 + # """ + # if action_name in self._registered_actions: + # return False - if parameters is None: - parameters = {} - if require is None: - require = [] + # if parameters is None: + # parameters = {} + # if require is None: + # require = [] - action_info = {"description": description, "parameters": parameters, "require": require} + # action_info = {"description": description, "parameters": parameters, "require": require} - self._registered_actions[action_name] = action_info - return True + # self._registered_actions[action_name] = action_info + # return True def remove_action(self, action_name: str) -> bool: """从注册集移除指定动作""" @@ -260,10 +251,9 @@ class ActionManager: def restore_actions(self) -> None: """恢复到默认动作集""" - logger.debug( - f"恢复动作集: 从 {list(self._using_actions.keys())} 恢复到默认动作集 {list(self._default_actions.keys())}" - ) - self._using_actions = self._default_actions.copy() + actions_to_restore = list(self._using_actions.keys()) + self._using_actions = component_registry.get_default_actions() + logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") def add_system_action_if_needed(self, action_name: str) -> bool: """ @@ -293,4 +283,4 @@ class ActionManager: """ from src.plugin_system.core.component_registry import component_registry - return component_registry.get_component_class(action_name) + return component_registry.get_component_class(action_name) # type: ignore diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 8aaafc201..21a4ce06e 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -2,7 +2,7 @@ import random import asyncio import hashlib import time -from typing import List, Any, Dict +from typing import List, Any, Dict, TYPE_CHECKING from src.common.logger import get_logger from src.config.config import global_config @@ -13,6 +13,9 @@ from src.chat.planner_actions.action_manager import ActionManager from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages from src.plugin_system.base.component_types import ChatMode, ActionInfo, ActionActivationType +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream + logger = get_logger("action_manager") @@ -27,7 +30,7 @@ class ActionModifier: def __init__(self, action_manager: ActionManager, chat_id: str): """初始化动作处理器""" self.chat_id = chat_id - self.chat_stream = get_chat_manager().get_stream(self.chat_id) + self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]" self.action_manager = action_manager @@ -142,7 +145,7 @@ class ActionModifier: async def _get_deactivated_actions_by_type( self, actions_with_info: Dict[str, ActionInfo], - mode: str = "focus", + mode: ChatMode = ChatMode.FOCUS, chat_content: str = "", ) -> List[tuple[str, str]]: """ @@ -270,7 +273,7 @@ class ActionModifier: task_results = await asyncio.gather(*tasks, return_exceptions=True) # 处理结果并更新缓存 - for _, (action_name, result) in enumerate(zip(task_names, task_results)): + for action_name, result in zip(task_names, task_results): if isinstance(result, Exception): logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}") results[action_name] = False @@ -286,7 +289,7 @@ class ActionModifier: except Exception as e: logger.error(f"{self.log_prefix}并行LLM判定失败: {e}") # 如果并行执行失败,为所有任务返回False - for action_name in tasks_to_run.keys(): + for action_name in tasks_to_run: results[action_name] = False # 清理过期缓存 @@ -297,10 +300,11 @@ class ActionModifier: def _cleanup_expired_cache(self, current_time: float): """清理过期的缓存条目""" expired_keys = [] - for cache_key, cache_data in self._llm_judge_cache.items(): - if current_time - cache_data["timestamp"] > self._cache_expiry_time: - expired_keys.append(cache_key) - + expired_keys.extend( + cache_key + for cache_key, cache_data in self._llm_judge_cache.items() + if current_time - cache_data["timestamp"] > self._cache_expiry_time + ) for key in expired_keys: del self._llm_judge_cache[key] @@ -379,7 +383,7 @@ class ActionModifier: def _check_keyword_activation( self, action_name: str, - action_info: Dict[str, Any], + action_info: ActionInfo, chat_content: str = "", ) -> bool: """ @@ -396,8 +400,8 @@ class ActionModifier: bool: 是否应该激活此action """ - activation_keywords = action_info.get("activation_keywords", []) - case_sensitive = action_info.get("keyword_case_sensitive", False) + activation_keywords = action_info.activation_keywords + case_sensitive = action_info.keyword_case_sensitive if not activation_keywords: logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词") diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index f4c8a9a4a..850f43d12 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -70,7 +70,7 @@ class ActionPlanner: self.last_obs_time_mark = 0.0 - async def plan(self) -> Dict[str, Any]: + async def plan(self) -> Dict[str, Any]: # sourcery skip: dict-comprehension """ 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 """ @@ -162,7 +162,6 @@ class ActionPlanner: reasoning = parsed_json.get("reasoning", "未提供原因") # 将所有其他属性添加到action_data - action_data = {} for key, value in parsed_json.items(): if key not in ["action", "reasoning"]: action_data[key] = value @@ -285,7 +284,7 @@ class ActionPlanner: identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:" planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") - prompt = planner_prompt_template.format( + return planner_prompt_template.format( time_block=time_block, by_what=by_what, chat_context_description=chat_context_description, @@ -295,8 +294,6 @@ class ActionPlanner: moderation_prompt=moderation_prompt_block, identity_block=identity_block, ) - return prompt - except Exception as e: logger.error(f"构建 Planner 提示词时出错: {e}") logger.error(traceback.format_exc()) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 6cb526d11..084dfd58c 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -130,9 +130,7 @@ class DefaultReplyer: # 提取权重,如果模型配置中没有'weight'键,则默认为1.0 weights = [config.get("weight", 1.0) for config in configs] - # random.choices 返回一个列表,我们取第一个元素 - selected_config = random.choices(population=configs, weights=weights, k=1)[0] - return selected_config + return random.choices(population=configs, weights=weights, k=1)[0] async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str): """创建思考消息 (尝试锚定到 anchor_message)""" @@ -314,8 +312,7 @@ class DefaultReplyer: logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID,跳过信息提取") return f"你完全不认识{sender},不理解ta的相关信息。" - relation_info = await relationship_fetcher.build_relation_info(person_id, text, chat_history) - return relation_info + return await relationship_fetcher.build_relation_info(person_id, text, chat_history) async def build_expression_habits(self, chat_history, target): if not global_config.expression.enable_expression: @@ -363,15 +360,13 @@ class DefaultReplyer: target_message=target, chat_history_prompt=chat_history ) - if running_memories: - memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" - for running_memory in running_memories: - memory_str += f"- {running_memory['content']}\n" - memory_block = memory_str - else: - memory_block = "" + if not running_memories: + return "" - return memory_block + memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" + for running_memory in running_memories: + memory_str += f"- {running_memory['content']}\n" + return memory_str async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True): """构建工具信息块 @@ -453,7 +448,7 @@ class DefaultReplyer: for name, content in result.groupdict().items(): reaction = reaction.replace(f"[{name}]", content) logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}") - keywords_reaction_prompt += reaction + "," + keywords_reaction_prompt += f"{reaction}," break except re.error as e: logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}") @@ -477,7 +472,7 @@ class DefaultReplyer: available_actions: Optional[Dict[str, ActionInfo]] = None, enable_timeout: bool = False, enable_tool: bool = True, - ) -> str: + ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if """ 构建回复器上下文 @@ -612,7 +607,7 @@ class DefaultReplyer: short_impression = ["友好活泼", "人类"] personality = short_impression[0] identity = short_impression[1] - prompt_personality = personality + "," + identity + prompt_personality = f"{personality},{identity}" identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" moderation_prompt_block = ( @@ -660,7 +655,7 @@ class DefaultReplyer: "chat_target_private2", sender_name=chat_target_name ) - prompt = await global_prompt_manager.format_prompt( + return await global_prompt_manager.format_prompt( template_name, expression_habits_block=expression_habits_block, chat_target=chat_target_1, @@ -683,8 +678,6 @@ class DefaultReplyer: mood_state=mood_prompt, ) - return prompt - async def build_prompt_rewrite_context( self, reply_data: Dict[str, Any], @@ -745,7 +738,7 @@ class DefaultReplyer: short_impression = ["友好活泼", "人类"] personality = short_impression[0] identity = short_impression[1] - prompt_personality = personality + "," + identity + prompt_personality = f"{personality},{identity}" identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" moderation_prompt_block = ( @@ -790,7 +783,7 @@ class DefaultReplyer: template_name = "default_expressor_prompt" - prompt = await global_prompt_manager.format_prompt( + return await global_prompt_manager.format_prompt( template_name, expression_habits_block=expression_habits_block, relation_info_block=relation_info, @@ -807,8 +800,6 @@ class DefaultReplyer: moderation_prompt=moderation_prompt_block, ) - return prompt - async def send_response_messages( self, anchor_message: Optional[MessageRecv], @@ -816,6 +807,7 @@ class DefaultReplyer: thinking_id: str = "", display_message: str = "", ) -> Optional[MessageSending]: + # sourcery skip: assign-if-exp, boolean-if-exp-identity, remove-unnecessary-cast """发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender""" chat = self.chat_stream chat_id = self.chat_stream.stream_id @@ -849,16 +841,16 @@ class DefaultReplyer: for i, msg_text in enumerate(response_set): # 为每个消息片段生成唯一ID - type = msg_text[0] + msg_type = msg_text[0] data = msg_text[1] - if global_config.debug.debug_show_chat_mode and type == "text": + if global_config.debug.debug_show_chat_mode and msg_type == "text": data += "ᶠ" part_message_id = f"{thinking_id}_{i}" - message_segment = Seg(type=type, data=data) + message_segment = Seg(type=msg_type, data=data) - if type == "emoji": + if msg_type == "emoji": is_emoji = True else: is_emoji = False @@ -871,7 +863,6 @@ class DefaultReplyer: display_message=display_message, reply_to=reply_to, is_emoji=is_emoji, - thinking_id=thinking_id, thinking_start_time=thinking_start_time, ) @@ -895,7 +886,7 @@ class DefaultReplyer: reply_message_ids.append(part_message_id) # 记录我们生成的ID - sent_msg_list.append((type, sent_msg)) + sent_msg_list.append((msg_type, sent_msg)) except Exception as e: logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}") @@ -930,12 +921,9 @@ class DefaultReplyer: ) # await anchor_message.process() - if anchor_message: - sender_info = anchor_message.message_info.user_info - else: - sender_info = None + sender_info = anchor_message.message_info.user_info if anchor_message else None - bot_message = MessageSending( + return MessageSending( message_id=message_id, # 使用片段的唯一ID chat_stream=self.chat_stream, bot_user_info=bot_user_info, @@ -948,8 +936,6 @@ class DefaultReplyer: display_message=display_message, ) - return bot_message - def weighted_sample_no_replacement(items, weights, k) -> list: """ diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py index 6a73b7d4b..a2a2aaaa0 100644 --- a/src/chat/replyer/replyer_manager.py +++ b/src/chat/replyer/replyer_manager.py @@ -1,4 +1,5 @@ from typing import Dict, Any, Optional, List + from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.replyer.default_generator import DefaultReplyer from src.common.logger import get_logger @@ -8,7 +9,7 @@ logger = get_logger("ReplyerManager") class ReplyerManager: def __init__(self): - self._replyers: Dict[str, DefaultReplyer] = {} + self._repliers: Dict[str, DefaultReplyer] = {} def get_replyer( self, @@ -29,17 +30,16 @@ class ReplyerManager: return None # 如果已有缓存实例,直接返回 - if stream_id in self._replyers: + if stream_id in self._repliers: logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。") - return self._replyers[stream_id] + return self._repliers[stream_id] # 如果没有缓存,则创建新实例(首次初始化) logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。") target_stream = chat_stream if not target_stream: - chat_manager = get_chat_manager() - if chat_manager: + if chat_manager := get_chat_manager(): target_stream = chat_manager.get_stream(stream_id) if not target_stream: @@ -52,7 +52,7 @@ class ReplyerManager: model_configs=model_configs, # 可以是None,此时使用默认模型 request_type=request_type, ) - self._replyers[stream_id] = replyer + self._repliers[stream_id] = replyer return replyer diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index ab97f395b..06044defb 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -1,14 +1,15 @@ -from src.config.config import global_config -from typing import List, Dict, Any, Tuple # 确保类型提示被导入 import time # 导入 time 模块以获取当前时间 import random import re -from src.common.message_repository import find_messages, count_messages -from src.person_info.person_info import PersonInfoManager, get_person_info_manager -from src.chat.utils.utils import translate_timestamp_to_human_readable +from typing import List, Dict, Any, Tuple, Optional from rich.traceback import install + +from src.config.config import global_config +from src.common.message_repository import find_messages, count_messages from src.common.database.database_model import ActionRecords from src.common.database.database_model import Images +from src.person_info.person_info import PersonInfoManager, get_person_info_manager +from src.chat.utils.utils import translate_timestamp_to_human_readable install(extra_lines=3) @@ -135,7 +136,7 @@ def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float = None) -> int: +def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: """ 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。 @@ -172,7 +173,7 @@ def _build_readable_messages_internal( merge_messages: bool = False, timestamp_mode: str = "relative", truncate: bool = False, - pic_id_mapping: Dict[str, str] = None, + pic_id_mapping: Optional[Dict[str, str]] = None, pic_counter: int = 1, show_pic: bool = True, ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: @@ -194,7 +195,7 @@ def _build_readable_messages_internal( if not messages: return "", [], pic_id_mapping or {}, pic_counter - message_details_raw: List[Tuple[float, str, str]] = [] + message_details_raw: List[Tuple[float, str, str, bool]] = [] # 使用传入的映射字典,如果没有则创建新的 if pic_id_mapping is None: @@ -225,7 +226,7 @@ def _build_readable_messages_internal( # 检查是否是动作记录 if msg.get("is_action_record", False): is_action = True - timestamp = msg.get("time") + timestamp: float = msg.get("time") # type: ignore content = msg.get("display_message", "") # 对于动作记录,也处理图片ID content = process_pic_ids(content) @@ -249,9 +250,10 @@ def _build_readable_messages_internal( user_nickname = user_info.get("user_nickname") user_cardname = user_info.get("user_cardname") - timestamp = msg.get("time") + timestamp: float = msg.get("time") # type: ignore + content: str if msg.get("display_message"): - content = msg.get("display_message") + content = msg.get("display_message", "") else: content = msg.get("processed_plain_text", "") # 默认空字符串 @@ -271,6 +273,7 @@ def _build_readable_messages_internal( person_id = PersonInfoManager.get_person_id(platform, user_id) person_info_manager = get_person_info_manager() # 根据 replace_bot_name 参数决定是否替换机器人名称 + person_name: str if replace_bot_name and user_id == global_config.bot.qq_account: person_name = f"{global_config.bot.nickname}(你)" else: @@ -289,12 +292,10 @@ def _build_readable_messages_internal( reply_pattern = r"回复<([^:<>]+):([^:<>]+)>" match = re.search(reply_pattern, content) if match: - aaa = match.group(1) - bbb = match.group(2) + aaa: str = match[1] + bbb: str = match[2] reply_person_id = PersonInfoManager.get_person_id(platform, bbb) - reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") - if not reply_person_name: - reply_person_name = aaa + reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") or aaa # 在内容前加上回复信息 content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1) @@ -309,18 +310,15 @@ def _build_readable_messages_internal( aaa = m.group(1) bbb = m.group(2) at_person_id = PersonInfoManager.get_person_id(platform, bbb) - at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") - if not at_person_name: - at_person_name = aaa + at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") or aaa new_content += f"@{at_person_name}" last_end = m.end() new_content += content[last_end:] content = new_content target_str = "这是QQ的一个功能,用于提及某人,但没那么明显" - if target_str in content: - if random.random() < 0.6: - content = content.replace(target_str, "") + if target_str in content and random.random() < 0.6: + content = content.replace(target_str, "") if content != "": message_details_raw.append((timestamp, person_name, content, False)) @@ -470,6 +468,7 @@ def _build_readable_messages_internal( def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: + # sourcery skip: use-contextlib-suppress """ 构建图片映射信息字符串,显示图片的具体描述内容 @@ -518,9 +517,7 @@ async def build_readable_messages_with_list( messages, replace_bot_name, merge_messages, timestamp_mode, truncate ) - # 生成图片映射信息并添加到最前面 - pic_mapping_info = build_pic_mapping_info(pic_id_mapping) - if pic_mapping_info: + if pic_mapping_info := build_pic_mapping_info(pic_id_mapping): formatted_string = f"{pic_mapping_info}\n\n{formatted_string}" return formatted_string, details_list @@ -535,7 +532,7 @@ def build_readable_messages( truncate: bool = False, show_actions: bool = False, show_pic: bool = True, -) -> str: +) -> str: # sourcery skip: extract-method """ 将消息列表转换为可读的文本格式。 如果提供了 read_mark,则在相应位置插入已读标记。 @@ -658,9 +655,7 @@ def build_readable_messages( # 组合结果 result_parts = [] if pic_mapping_info: - result_parts.append(pic_mapping_info) - result_parts.append("\n") - + result_parts.extend((pic_mapping_info, "\n")) if formatted_before and formatted_after: result_parts.extend([formatted_before, read_mark_line, formatted_after]) elif formatted_before: @@ -733,8 +728,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: platform = msg.get("chat_info_platform") user_id = msg.get("user_id") _timestamp = msg.get("time") + content: str = "" if msg.get("display_message"): - content = msg.get("display_message") + content = msg.get("display_message", "") else: content = msg.get("processed_plain_text", "") @@ -829,10 +825,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: if not all([platform, user_id]) or user_id == global_config.bot.qq_account: continue - person_id = PersonInfoManager.get_person_id(platform, user_id) - - # 只有当获取到有效 person_id 时才添加 - if person_id: + if person_id := PersonInfoManager.get_person_id(platform, user_id): person_ids_set.add(person_id) return list(person_ids_set) # 将集合转换为列表返回 diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 17cfb2323..5579ccf84 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -103,7 +103,7 @@ class ImageManager: image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore # 查询缓存的描述 cached_description = self._get_description_from_db(image_hash, "emoji") @@ -154,7 +154,7 @@ class ImageManager: img_obj.description = description img_obj.timestamp = current_timestamp img_obj.save() - except Images.DoesNotExist: + except Images.DoesNotExist: # type: ignore Images.create( emoji_hash=image_hash, path=file_path, @@ -204,7 +204,7 @@ class ImageManager: return f"[图片:{cached_description}]" # 调用AI获取描述 - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来,请留意其主题,直观感受,输出为一段平文本,最多50字" description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) @@ -491,7 +491,7 @@ class ImageManager: return # 获取图片格式 - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore # 构建prompt prompt = """请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本""" diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 7f22fc2d4..f44a88225 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -3,7 +3,7 @@ from src.common.database.database import db from src.common.database.database_model import PersonInfo # 新增导入 import copy import hashlib -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Union import datetime import asyncio from src.llm_models.utils_model import LLMRequest @@ -84,7 +84,7 @@ class PersonInfoManager: logger.error(f"从 Peewee 加载 person_name_list 失败: {e}") @staticmethod - def get_person_id(platform: str, user_id: int): + def get_person_id(platform: str, user_id: Union[int, str]) -> str: """获取唯一id""" if "-" in platform: platform = platform.split("-")[1] diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 42e36b64d..73c883e0a 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -32,10 +32,10 @@ class BaseAction(ABC): reasoning: str, cycle_timers: dict, thinking_id: str, - chat_stream: ChatStream = None, + chat_stream: ChatStream, log_prefix: str = "", shutting_down: bool = False, - plugin_config: dict = None, + plugin_config: Optional[dict] = None, **kwargs, ): """初始化Action组件 diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 8977c5e70..2c2ddf81e 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -29,7 +29,7 @@ class BaseCommand(ABC): command_examples: List[str] = [] intercept_message: bool = True # 默认拦截消息,不继续处理 - def __init__(self, message: MessageRecv, plugin_config: dict = None): + def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): """初始化Command组件 Args: diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index b8112a490..fe3813b88 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -66,7 +66,7 @@ class BasePlugin(ABC): config_section_descriptions: Dict[str, str] = {} - def __init__(self, plugin_dir: str = None): + def __init__(self, plugin_dir: str): """初始化插件 Args: @@ -526,7 +526,7 @@ class BasePlugin(ABC): # 从配置中更新 enable_plugin if "plugin" in self.config and "enabled" in self.config["plugin"]: - self.enable_plugin = self.config["plugin"]["enabled"] + self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}") else: logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml") diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index bc66100d9..2bac36e5c 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -81,7 +81,9 @@ class ComponentInfo: class ActionInfo(ComponentInfo): """动作组件信息""" - action_parameters: Dict[str, str] = field(default_factory=dict) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"} + action_parameters: Dict[str, str] = field( + default_factory=dict + ) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"} action_require: List[str] = field(default_factory=list) # 动作需求说明 associated_types: List[str] = field(default_factory=list) # 关联的消息类型 # 激活类型相关 diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 2ec77c7b7..b152a1abc 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Any, Pattern, Union +from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type import re from src.common.logger import get_logger from src.plugin_system.base.component_types import ( @@ -28,25 +28,25 @@ class ComponentRegistry: ComponentType.ACTION: {}, ComponentType.COMMAND: {}, } - self._component_classes: Dict[str, Union[BaseCommand, BaseAction]] = {} # 组件名 -> 组件类 + self._component_classes: Dict[str, Union[Type[BaseCommand], Type[BaseAction]]] = {} # 组件名 -> 组件类 # 插件注册表 self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息 # Action特定注册表 - self._action_registry: Dict[str, BaseAction] = {} # action名 -> action类 - # self._action_descriptions: Dict[str, str] = {} # 启用的action名 -> 描述 + self._action_registry: Dict[str, Type[BaseAction]] = {} # action名 -> action类 + self._default_actions: Dict[str, ActionInfo] = {} # 默认动作集,即启用的Action集,用于重置ActionManager状态 # Command特定注册表 - self._command_registry: Dict[str, BaseCommand] = {} # command名 -> command类 - self._command_patterns: Dict[Pattern, BaseCommand] = {} # 编译后的正则 -> command类 + self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类 + self._command_patterns: Dict[Pattern, Type[BaseCommand]] = {} # 编译后的正则 -> command类 logger.info("组件注册中心初始化完成") # === 通用组件注册方法 === def register_component( - self, component_info: ComponentInfo, component_class: Union[BaseCommand, BaseAction] + self, component_info: ComponentInfo, component_class: Union[Type[BaseCommand], Type[BaseAction]] ) -> bool: """注册组件 @@ -88,9 +88,9 @@ class ComponentRegistry: # 根据组件类型进行特定注册(使用原始名称) if component_type == ComponentType.ACTION: - self._register_action_component(component_info, component_class) + self._register_action_component(component_info, component_class) # type: ignore elif component_type == ComponentType.COMMAND: - self._register_command_component(component_info, component_class) + self._register_command_component(component_info, component_class) # type: ignore logger.debug( f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' " @@ -98,7 +98,7 @@ class ComponentRegistry: ) return True - def _register_action_component(self, action_info: ActionInfo, action_class: BaseAction): + def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]): # -------------------------------- NEED REFACTORING -------------------------------- # -------------------------------- LOGIC ERROR ------------------------------------- """注册Action组件到Action特定注册表""" @@ -106,11 +106,10 @@ class ComponentRegistry: self._action_registry[action_name] = action_class # 如果启用,添加到默认动作集 - # ---- HERE ---- - # if action_info.enabled: - # self._action_descriptions[action_name] = action_info.description + if action_info.enabled: + self._default_actions[action_name] = action_info - def _register_command_component(self, command_info: CommandInfo, command_class: BaseCommand): + def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]): """注册Command组件到Command特定注册表""" command_name = command_info.name self._command_registry[command_name] = command_class @@ -122,7 +121,7 @@ class ComponentRegistry: # === 组件查询方法 === - def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]: + def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]: # type: ignore # sourcery skip: class-extract-method """获取组件信息,支持自动命名空间解析 @@ -170,8 +169,10 @@ class ComponentRegistry: return None def get_component_class( - self, component_name: str, component_type: ComponentType = None - ) -> Optional[Union[BaseCommand, BaseAction]]: + self, + component_name: str, + component_type: ComponentType = None, # type: ignore + ) -> Optional[Union[Type[BaseCommand], Type[BaseAction]]]: """获取组件类,支持自动命名空间解析 Args: @@ -230,7 +231,7 @@ class ComponentRegistry: # === Action特定查询方法 === - def get_action_registry(self) -> Dict[str, BaseAction]: + def get_action_registry(self) -> Dict[str, Type[BaseAction]]: """获取Action注册表(用于兼容现有系统)""" return self._action_registry.copy() @@ -239,13 +240,17 @@ class ComponentRegistry: info = self.get_component_info(action_name, ComponentType.ACTION) return info if isinstance(info, ActionInfo) else None + def get_default_actions(self) -> Dict[str, ActionInfo]: + """获取默认动作集""" + return self._default_actions.copy() + # === Command特定查询方法 === - def get_command_registry(self) -> Dict[str, BaseCommand]: + def get_command_registry(self) -> Dict[str, Type[BaseCommand]]: """获取Command注册表(用于兼容现有系统)""" return self._command_registry.copy() - def get_command_patterns(self) -> Dict[Pattern, BaseCommand]: + def get_command_patterns(self) -> Dict[Pattern, Type[BaseCommand]]: """获取Command模式注册表(用于兼容现有系统)""" return self._command_patterns.copy() @@ -254,7 +259,7 @@ class ComponentRegistry: info = self.get_component_info(command_name, ComponentType.COMMAND) return info if isinstance(info, CommandInfo) else None - def find_command_by_text(self, text: str) -> Optional[tuple[BaseCommand, dict, bool, str]]: + def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]: # sourcery skip: use-named-expression, use-next """根据文本查找匹配的命令 @@ -262,7 +267,7 @@ class ComponentRegistry: text: 输入文本 Returns: - Optional[tuple[BaseCommand, dict, bool, str]]: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None + Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None """ for pattern, command_class in self._command_patterns.items(): From d2ad6ea1d8f84a73ef43133e85aefe9c4c69fc1a Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sat, 12 Jul 2025 10:18:16 +0800 Subject: [PATCH 11/13] fix typo --- .../{exprssion_learner.py => expression_learner.py} | 0 src/chat/express/expression_selector.py | 4 +++- src/chat/heart_flow/heartflow_message_processor.py | 7 +++++-- src/chat/message_receive/message.py | 2 +- src/main.py | 2 +- 5 files changed, 10 insertions(+), 5 deletions(-) rename src/chat/express/{exprssion_learner.py => expression_learner.py} (100%) diff --git a/src/chat/express/exprssion_learner.py b/src/chat/express/expression_learner.py similarity index 100% rename from src/chat/express/exprssion_learner.py rename to src/chat/express/expression_learner.py diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 0b1eaef7a..03456e27e 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -10,7 +10,7 @@ from src.llm_models.utils_model import LLMRequest from src.config.config import global_config from src.common.logger import get_logger from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from .exprssion_learner import get_expression_learner +from .expression_learner import get_expression_learner logger = get_logger("expression_selector") @@ -84,6 +84,7 @@ class ExpressionSelector: def get_random_expressions( self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: + # sourcery skip: extract-duplicate-method, move-assign ( learnt_style_expressions, learnt_grammar_expressions, @@ -174,6 +175,7 @@ class ExpressionSelector: min_num: int = 5, target_message: Optional[str] = None, ) -> List[Dict[str, str]]: + # sourcery skip: inline-variable, list-comprehension """使用LLM选择适合的表达方式""" # 1. 获取35个随机表达方式(现在按权重抽取) diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index 4ab29b38e..dd267b079 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -3,7 +3,7 @@ import re import math import traceback -from typing import Tuple +from typing import Tuple, TYPE_CHECKING from src.config.config import global_config from src.chat.memory_system.Hippocampus import hippocampus_manager @@ -16,6 +16,9 @@ from src.common.logger import get_logger from src.person_info.relationship_manager import get_relationship_manager from src.mood.mood_manager import mood_manager +if TYPE_CHECKING: + from src.chat.heart_flow.sub_heartflow import SubHeartflow + logger = get_logger("chat") @@ -104,7 +107,7 @@ class HeartFCMessageReceiver: await self.storage.store_message(message, chat) - subheartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) + subheartflow: SubHeartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) # type: ignore subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 44a1da26f..f444c768f 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -112,7 +112,7 @@ class MessageRecv(Message): self.is_mentioned = None self.priority_mode = "interest" self.priority_info = None - self.interest_value = None + self.interest_value: float = None # type: ignore def update_chat_stream(self, chat_stream: "ChatStream"): self.chat_stream = chat_stream diff --git a/src/main.py b/src/main.py index 64129814e..d481c7d03 100644 --- a/src/main.py +++ b/src/main.py @@ -2,7 +2,7 @@ import asyncio import time from maim_message import MessageServer -from src.chat.express.exprssion_learner import get_expression_learner +from src.chat.express.expression_learner import get_expression_learner from src.common.remote import TelemetryHeartBeatTask from src.manager.async_task_manager import async_task_manager from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask From 7ef0bfb7c8c8763aafa398a49fdac7d966ec12f3 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sun, 13 Jul 2025 00:19:54 +0800 Subject: [PATCH 12/13] =?UTF-8?q?=E5=AE=8C=E6=88=90=E6=89=80=E6=9C=89?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E6=B3=A8=E8=A7=A3=E7=9A=84=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/default_generator.py | 23 ++--- src/chat/replyer/replyer_manager.py | 2 +- src/chat/utils/chat_message_builder.py | 15 +-- src/chat/utils/json_utils.py | 29 +++--- src/chat/utils/prompt_builder.py | 30 +++--- src/chat/utils/statistic.py | 95 ++++++++----------- src/chat/utils/timer_calculator.py | 9 +- src/chat/utils/typo_generator.py | 8 +- src/chat/utils/utils.py | 90 +++++++----------- src/chat/utils/utils_image.py | 38 +++----- src/common/database/database.py | 4 +- src/common/database/database_model.py | 9 +- src/common/logger.py | 48 +++++----- src/common/message/api.py | 7 +- src/common/message_repository.py | 6 +- src/common/remote.py | 12 +-- src/config/auto_update.py | 12 +-- src/config/config.py | 19 ++-- src/config/config_base.py | 2 +- src/config/official_configs.py | 13 +-- src/individuality/identity.py | 4 +- src/individuality/individuality.py | 31 +++--- src/individuality/personality.py | 10 +- src/manager/async_task_manager.py | 7 +- src/mood/mood_manager.py | 16 ++-- src/person_info/person_info.py | 59 ++++++------ src/person_info/relationship_builder.py | 22 ++--- .../relationship_builder_manager.py | 9 +- src/person_info/relationship_fetcher.py | 43 +++++---- src/person_info/relationship_manager.py | 45 ++++----- src/plugin_system/apis/generator_api.py | 38 ++++---- src/tools/tool_executor.py | 37 ++++---- 32 files changed, 358 insertions(+), 434 deletions(-) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 3ad3fe4cf..a9214a9af 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -4,6 +4,7 @@ import asyncio import random import ast import re + from typing import List, Optional, Dict, Any, Tuple from datetime import datetime @@ -161,13 +162,13 @@ class DefaultReplyer: async def generate_reply_with_context( self, - reply_data: Dict[str, Any] = None, + reply_data: Optional[Dict[str, Any]] = None, reply_to: str = "", extra_info: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, enable_tool: bool = True, enable_timeout: bool = False, - ) -> Tuple[bool, Optional[str]]: + ) -> Tuple[bool, Optional[str], Optional[str]]: """ 回复器 (Replier): 核心逻辑,负责生成回复文本。 (已整合原 HeartFCGenerator 的功能) @@ -225,14 +226,14 @@ class DefaultReplyer: except Exception as llm_e: # 精简报错信息 logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}") - return False, None # LLM 调用失败则无法生成回复 + return False, None, prompt # LLM 调用失败则无法生成回复 return True, content, prompt except Exception as e: logger.error(f"{self.log_prefix}回复生成意外失败: {e}") traceback.print_exc() - return False, None + return False, None, prompt async def rewrite_reply_with_context( self, @@ -368,7 +369,7 @@ class DefaultReplyer: memory_str += f"- {running_memory['content']}\n" return memory_str - async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True): + async def build_tool_info(self, chat_history, reply_data: Optional[Dict], enable_tool: bool = True): """构建工具信息块 Args: @@ -393,7 +394,7 @@ class DefaultReplyer: try: # 使用工具执行器获取信息 - tool_results = await self.tool_executor.execute_from_chat_message( + tool_results, _, _ = await self.tool_executor.execute_from_chat_message( sender=sender, target_message=text, chat_history=chat_history, return_details=False ) @@ -468,7 +469,7 @@ class DefaultReplyer: async def build_prompt_reply_context( self, - reply_data=None, + reply_data: Dict[str, Any], available_actions: Optional[Dict[str, ActionInfo]] = None, enable_timeout: bool = False, enable_tool: bool = True, @@ -549,7 +550,7 @@ class DefaultReplyer: ), self._time_and_run_task(self.build_memory_block(chat_talking_prompt_half, target), "build_memory_block"), self._time_and_run_task( - self.build_tool_info(reply_data, chat_talking_prompt_half, enable_tool=enable_tool), "build_tool_info" + self.build_tool_info(chat_talking_prompt_half, reply_data, enable_tool=enable_tool), "build_tool_info" ), ) @@ -806,7 +807,7 @@ class DefaultReplyer: response_set: List[Tuple[str, str]], thinking_id: str = "", display_message: str = "", - ) -> Optional[MessageSending]: + ) -> Optional[List[Tuple[str, bool]]]: # sourcery skip: assign-if-exp, boolean-if-exp-identity, remove-unnecessary-cast """发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender""" chat = self.chat_stream @@ -869,7 +870,7 @@ class DefaultReplyer: try: if ( bot_message.is_private_message() - or bot_message.reply.processed_plain_text != "[System Trigger Context]" + or bot_message.reply.processed_plain_text != "[System Trigger Context]" # type: ignore or mark_head ): set_reply = False @@ -910,7 +911,7 @@ class DefaultReplyer: is_emoji: bool, thinking_start_time: float, display_message: str, - anchor_message: MessageRecv = None, + anchor_message: Optional[MessageRecv] = None, ) -> MessageSending: """构建单个发送消息""" diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py index a2a2aaaa0..3f1c731b4 100644 --- a/src/chat/replyer/replyer_manager.py +++ b/src/chat/replyer/replyer_manager.py @@ -1,8 +1,8 @@ from typing import Dict, Any, Optional, List +from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.replyer.default_generator import DefaultReplyer -from src.common.logger import get_logger logger = get_logger("ReplyerManager") diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 8c579e6d3..6bdf7f58d 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -1,6 +1,7 @@ import time # 导入 time 模块以获取当前时间 import random import re + from typing import List, Dict, Any, Tuple, Optional from rich.traceback import install @@ -88,8 +89,8 @@ def get_actions_by_timestamp_with_chat( """获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表""" query = ActionRecords.select().where( (ActionRecords.chat_id == chat_id) - & (ActionRecords.time > timestamp_start) - & (ActionRecords.time < timestamp_end) + & (ActionRecords.time > timestamp_start) # type: ignore + & (ActionRecords.time < timestamp_end) # type: ignore ) if limit > 0: @@ -113,8 +114,8 @@ def get_actions_by_timestamp_with_chat_inclusive( """获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表""" query = ActionRecords.select().where( (ActionRecords.chat_id == chat_id) - & (ActionRecords.time >= timestamp_start) - & (ActionRecords.time <= timestamp_end) + & (ActionRecords.time >= timestamp_start) # type: ignore + & (ActionRecords.time <= timestamp_end) # type: ignore ) if limit > 0: @@ -331,7 +332,7 @@ def _build_readable_messages_internal( if replace_bot_name and user_id == global_config.bot.qq_account: person_name = f"{global_config.bot.nickname}(你)" else: - person_name = person_info_manager.get_value_sync(person_id, "person_name") + person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore # 如果 person_name 未设置,则使用消息中的 nickname 或默认名称 if not person_name: @@ -911,8 +912,8 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: person_ids_set = set() # 使用集合来自动去重 for msg in messages: - platform = msg.get("user_platform") - user_id = msg.get("user_id") + platform: str = msg.get("user_platform") # type: ignore + user_id: str = msg.get("user_id") # type: ignore # 检查必要信息是否存在 且 不是机器人自己 if not all([platform, user_id]) or user_id == global_config.bot.qq_account: diff --git a/src/chat/utils/json_utils.py b/src/chat/utils/json_utils.py index 6226e6e96..892deac4f 100644 --- a/src/chat/utils/json_utils.py +++ b/src/chat/utils/json_utils.py @@ -1,7 +1,8 @@ +import ast import json import logging -from typing import Any, Dict, TypeVar, List, Union, Tuple -import ast + +from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional # 定义类型变量用于泛型类型提示 T = TypeVar("T") @@ -30,18 +31,14 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]: # 尝试标准的 JSON 解析 return json.loads(json_str) except json.JSONDecodeError: - # 如果标准解析失败,尝试将单引号替换为双引号再解析 - # (注意:这种替换可能不安全,如果字符串内容本身包含引号) - # 更安全的方式是用 ast.literal_eval + # 如果标准解析失败,尝试用 ast.literal_eval 解析 try: # logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...") result = ast.literal_eval(json_str) - # 确保结果是字典(因为我们通常期望参数是字典) if isinstance(result, dict): return result - else: - logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}") - return default_value + logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}") + return default_value except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e: logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...") return default_value @@ -53,7 +50,9 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]: return default_value -def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]: +def extract_tool_call_arguments( + tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: """ 从LLM工具调用对象中提取参数 @@ -77,14 +76,12 @@ def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[s logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}") return default_result - # 提取arguments - arguments_str = function_data.get("arguments", "{}") - if not arguments_str: + if arguments_str := function_data.get("arguments", "{}"): + # 解析JSON + return safe_json_loads(arguments_str, default_result) + else: return default_result - # 解析JSON - return safe_json_loads(arguments_str, default_result) - except Exception as e: logger.error(f"提取工具调用参数时出错: {e}") return default_result diff --git a/src/chat/utils/prompt_builder.py b/src/chat/utils/prompt_builder.py index 26f8ffbad..1b107904c 100644 --- a/src/chat/utils/prompt_builder.py +++ b/src/chat/utils/prompt_builder.py @@ -1,12 +1,12 @@ -from typing import Dict, Any, Optional, List, Union import re -from contextlib import asynccontextmanager import asyncio import contextvars -from src.common.logger import get_logger -# import traceback from rich.traceback import install +from contextlib import asynccontextmanager +from typing import Dict, Any, Optional, List, Union + +from src.common.logger import get_logger install(extra_lines=3) @@ -32,6 +32,7 @@ class PromptContext: @asynccontextmanager async def async_scope(self, context_id: Optional[str] = None): + # sourcery skip: hoist-statement-from-if, use-contextlib-suppress """创建一个异步的临时提示模板作用域""" # 保存当前上下文并设置新上下文 if context_id is not None: @@ -88,8 +89,7 @@ class PromptContext: async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None: """异步注册提示模板到指定作用域""" async with self._context_lock: - target_context = context_id or self._current_context - if target_context: + if target_context := context_id or self._current_context: self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt @@ -151,7 +151,7 @@ class Prompt(str): @staticmethod def _process_escaped_braces(template) -> str: - """处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记""" + """处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记""" # type: ignore # 如果传入的是列表,将其转换为字符串 if isinstance(template, list): template = "\n".join(str(item) for item in template) @@ -195,14 +195,8 @@ class Prompt(str): obj._kwargs = kwargs # 修改自动注册逻辑 - if should_register: - if global_prompt_manager._context._current_context: - # 如果存在当前上下文,则注册到上下文中 - # asyncio.create_task(global_prompt_manager._context.register_async(obj)) - pass - else: - # 否则注册到全局管理器 - global_prompt_manager.register(obj) + if should_register and not global_prompt_manager._context._current_context: + global_prompt_manager.register(obj) return obj @classmethod @@ -276,15 +270,13 @@ class Prompt(str): self.name, args=list(args) if args else self._args, _should_register=False, - **kwargs if kwargs else self._kwargs, + **kwargs or self._kwargs, ) # print(f"prompt build result: {ret} name: {ret.name} ") return str(ret) def __str__(self) -> str: - if self._kwargs or self._args: - return super().__str__() - return self.template + return super().__str__() if self._kwargs or self._args else self.template def __repr__(self) -> str: return f"Prompt(template='{self.template}', name='{self.name}')" diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 25d231c01..4e0edd31f 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -1,18 +1,17 @@ -from collections import defaultdict -from datetime import datetime, timedelta -from typing import Any, Dict, Tuple, List import asyncio import concurrent.futures import json import os import glob +from collections import defaultdict +from datetime import datetime, timedelta +from typing import Any, Dict, Tuple, List from src.common.logger import get_logger +from src.common.database.database import db +from src.common.database.database_model import OnlineTime, LLMUsage, Messages from src.manager.async_task_manager import AsyncTask - -from ...common.database.database import db # This db is the Peewee database instance -from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model from src.manager.local_store_manager import local_storage logger = get_logger("maibot_statistic") @@ -76,14 +75,14 @@ class OnlineTimeRecordTask(AsyncTask): with db.atomic(): # Use atomic operations for schema changes OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model - async def run(self): + async def run(self): # sourcery skip: use-named-expression try: current_time = datetime.now() extended_end_time = current_time + timedelta(minutes=1) if self.record_id: # 如果有记录,则更新结束时间 - query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) + query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore updated_rows = query.execute() if updated_rows == 0: # Record might have been deleted or ID is stale, try to find/create @@ -94,7 +93,7 @@ class OnlineTimeRecordTask(AsyncTask): # Look for a record whose end_timestamp is recent enough to be considered ongoing recent_record = ( OnlineTime.select() - .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) + .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore .order_by(OnlineTime.end_timestamp.desc()) .first() ) @@ -123,15 +122,15 @@ def _format_online_time(online_seconds: int) -> str: :param online_seconds: 在线时间(秒) :return: 格式化后的在线时间字符串 """ - total_oneline_time = timedelta(seconds=online_seconds) + total_online_time = timedelta(seconds=online_seconds) - days = total_oneline_time.days - hours = total_oneline_time.seconds // 3600 - minutes = (total_oneline_time.seconds // 60) % 60 - seconds = total_oneline_time.seconds % 60 + days = total_online_time.days + hours = total_online_time.seconds // 3600 + minutes = (total_online_time.seconds // 60) % 60 + seconds = total_online_time.seconds % 60 if days > 0: # 如果在线时间超过1天,则格式化为"X天X小时X分钟" - return f"{total_oneline_time.days}天{hours}小时{minutes}分钟{seconds}秒" + return f"{total_online_time.days}天{hours}小时{minutes}分钟{seconds}秒" elif hours > 0: # 如果在线时间超过1小时,则格式化为"X小时X分钟X秒" return f"{hours}小时{minutes}分钟{seconds}秒" @@ -163,7 +162,7 @@ class StatisticOutputTask(AsyncTask): now = datetime.now() if "deploy_time" in local_storage: # 如果存在部署时间,则使用该时间作为全量统计的起始时间 - deploy_time = datetime.fromtimestamp(local_storage["deploy_time"]) + deploy_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore else: # 否则,使用最大时间范围,并记录部署时间为当前时间 deploy_time = datetime(2000, 1, 1) @@ -252,7 +251,7 @@ class StatisticOutputTask(AsyncTask): # 创建后台任务,不等待完成 collect_task = asyncio.create_task( - loop.run_in_executor(executor, self._collect_all_statistics, now) + loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore ) stats = await collect_task @@ -260,8 +259,8 @@ class StatisticOutputTask(AsyncTask): # 创建并发的输出任务 output_tasks = [ - asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), - asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), + asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore + asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore ] # 等待所有输出任务完成 @@ -320,7 +319,7 @@ class StatisticOutputTask(AsyncTask): # 以最早的时间戳为起始时间获取记录 # Assuming LLMUsage.timestamp is a DateTimeField query_start_time = collect_period[-1][1] - for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): + for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore record_timestamp = record.timestamp # This is already a datetime object for idx, (_, period_start) in enumerate(collect_period): if record_timestamp >= period_start: @@ -388,7 +387,7 @@ class StatisticOutputTask(AsyncTask): query_start_time = collect_period[-1][1] # Assuming OnlineTime.end_timestamp is a DateTimeField - for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): + for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore # record.end_timestamp and record.start_timestamp are datetime objects record_end_timestamp = record.end_timestamp record_start_timestamp = record.start_timestamp @@ -428,7 +427,7 @@ class StatisticOutputTask(AsyncTask): } query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) - for message in Messages.select().where(Messages.time >= query_start_timestamp): + for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore message_time_ts = message.time # This is a float timestamp chat_id = None @@ -661,7 +660,7 @@ class StatisticOutputTask(AsyncTask): if "last_full_statistics" in local_storage: # 如果存在上次完整统计数据,则使用该数据进行增量统计 - last_stat = local_storage["last_full_statistics"] # 上次完整统计数据 + last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射 last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据 @@ -727,6 +726,7 @@ class StatisticOutputTask(AsyncTask): return stat def _convert_defaultdict_to_dict(self, data): + # sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks """递归转换defaultdict为普通dict""" if isinstance(data, defaultdict): # 转换defaultdict为普通dict @@ -812,8 +812,7 @@ class StatisticOutputTask(AsyncTask): # 全局阶段平均时间 if stats[FOCUS_AVG_TIMES_BY_STAGE]: output.append("全局阶段平均时间:") - for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items(): - output.append(f" {stage}: {avg_time:.3f}秒") + output.extend(f" {stage}: {avg_time:.3f}秒" for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items()) output.append("") # Action类型比例 @@ -1050,7 +1049,7 @@ class StatisticOutputTask(AsyncTask): ] tab_content_list.append( - _format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"])) + _format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"])) # type: ignore ) # 添加Focus统计内容 @@ -1212,6 +1211,7 @@ class StatisticOutputTask(AsyncTask): f.write(html_template) def _generate_focus_tab(self, stat: dict[str, Any]) -> str: + # sourcery skip: for-append-to-extend, list-comprehension, use-any """生成Focus统计独立分页的HTML内容""" # 为每个时间段准备Focus数据 @@ -1313,12 +1313,11 @@ class StatisticOutputTask(AsyncTask): # 聊天流Action选择比例对比表(横向表格) focus_chat_action_ratios_rows = "" if stat_data.get("focus_action_ratios_by_chat"): - # 获取所有action类型(按全局频率排序) - all_action_types_for_ratio = sorted( - stat_data[FOCUS_ACTION_RATIOS].keys(), key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x], reverse=True - ) - - if all_action_types_for_ratio: + if all_action_types_for_ratio := sorted( + stat_data[FOCUS_ACTION_RATIOS].keys(), + key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x], + reverse=True, + ): # 为每个聊天流生成数据行(按循环数排序) chat_ratio_rows = [] for chat_id in sorted( @@ -1379,16 +1378,11 @@ class StatisticOutputTask(AsyncTask): if period_name == "all_time": from src.manager.local_store_manager import local_storage - start_time = datetime.fromtimestamp(local_storage["deploy_time"]) - time_range = ( - f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - ) + start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore else: start_time = datetime.now() - period_delta - time_range = ( - f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - ) + time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" # 生成该时间段的Focus统计HTML section_html = f"""
@@ -1681,16 +1675,10 @@ class StatisticOutputTask(AsyncTask): if period_name == "all_time": from src.manager.local_store_manager import local_storage - start_time = datetime.fromtimestamp(local_storage["deploy_time"]) - time_range = ( - f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - ) + start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore else: start_time = datetime.now() - period_delta - time_range = ( - f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - ) - + time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" # 生成该时间段的版本对比HTML section_html = f"""
@@ -1865,7 +1853,7 @@ class StatisticOutputTask(AsyncTask): # 查询LLM使用记录 query_start_time = start_time - for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): + for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore record_time = record.timestamp # 找到对应的时间间隔索引 @@ -1875,7 +1863,7 @@ class StatisticOutputTask(AsyncTask): if 0 <= interval_index < len(time_points): # 累加总花费数据 cost = record.cost or 0.0 - total_cost_data[interval_index] += cost + total_cost_data[interval_index] += cost # type: ignore # 累加按模型分类的花费 model_name = record.model_name or "unknown" @@ -1892,7 +1880,7 @@ class StatisticOutputTask(AsyncTask): # 查询消息记录 query_start_timestamp = start_time.timestamp() - for message in Messages.select().where(Messages.time >= query_start_timestamp): + for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore message_time_ts = message.time # 找到对应的时间间隔索引 @@ -1982,6 +1970,7 @@ class StatisticOutputTask(AsyncTask): } def _generate_chart_tab(self, chart_data: dict) -> str: + # sourcery skip: extract-duplicate-method, move-assign-in-block """生成图表选项卡HTML内容""" # 生成不同颜色的调色板 @@ -2293,7 +2282,7 @@ class AsyncStatisticOutputTask(AsyncTask): # 数据收集任务 collect_task = asyncio.create_task( - loop.run_in_executor(executor, self._collect_all_statistics, now) + loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore ) stats = await collect_task @@ -2301,8 +2290,8 @@ class AsyncStatisticOutputTask(AsyncTask): # 创建并发的输出任务 output_tasks = [ - asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), - asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), + asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore + asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore ] # 等待所有输出任务完成 diff --git a/src/chat/utils/timer_calculator.py b/src/chat/utils/timer_calculator.py index df2b9f778..d9479af16 100644 --- a/src/chat/utils/timer_calculator.py +++ b/src/chat/utils/timer_calculator.py @@ -1,7 +1,8 @@ +import asyncio + from time import perf_counter from functools import wraps from typing import Optional, Dict, Callable -import asyncio from rich.traceback import install install(extra_lines=3) @@ -88,10 +89,10 @@ class Timer: self.name = name self.storage = storage - self.elapsed = None + self.elapsed: float = None # type: ignore self.auto_unit = auto_unit - self.start = None + self.start: float = None # type: ignore @staticmethod def _validate_types(name, storage): @@ -120,7 +121,7 @@ class Timer: return None wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper - wrapper.__timer__ = self # 保留计时器引用 + wrapper.__timer__ = self # 保留计时器引用 # type: ignore return wrapper def __enter__(self): diff --git a/src/chat/utils/typo_generator.py b/src/chat/utils/typo_generator.py index 7c373f132..4de219464 100644 --- a/src/chat/utils/typo_generator.py +++ b/src/chat/utils/typo_generator.py @@ -7,10 +7,10 @@ import math import os import random import time +import jieba + from collections import defaultdict from pathlib import Path - -import jieba from pypinyin import Style, pinyin from src.common.logger import get_logger @@ -104,7 +104,7 @@ class ChineseTypoGenerator: try: return "\u4e00" <= char <= "\u9fff" except Exception as e: - logger.debug(e) + logger.debug(str(e)) return False def _get_pinyin(self, sentence): @@ -138,7 +138,7 @@ class ChineseTypoGenerator: # 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况 if not py[-1].isdigit(): # 为非数字结尾的拼音添加数字声调1 - return py + "1" + return f"{py}1" base = py[:-1] # 去掉声调 tone = int(py[-1]) # 获取声调 diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index f3226b2e1..2fbc69559 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -1,23 +1,21 @@ import random import re import time -from collections import Counter - import jieba import numpy as np + +from collections import Counter from maim_message import UserInfo +from typing import Optional, Tuple, Dict from src.common.logger import get_logger - -# from src.mood.mood_manager import mood_manager -from ..message_receive.message import MessageRecv -from src.llm_models.utils_model import LLMRequest -from .typo_generator import ChineseTypoGenerator -from ...config.config import global_config -from ...common.message_repository import find_messages, count_messages -from typing import Optional, Tuple, Dict +from src.common.message_repository import find_messages, count_messages +from src.config.config import global_config +from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.chat_stream import get_chat_manager +from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import PersonInfoManager, get_person_info_manager +from .typo_generator import ChineseTypoGenerator logger = get_logger("chat_utils") @@ -31,11 +29,7 @@ def db_message_to_str(message_dict: dict) -> str: logger.debug(f"message_dict: {message_dict}") time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"])) try: - name = "[(%s)%s]%s" % ( - message_dict["user_id"], - message_dict.get("user_nickname", ""), - message_dict.get("user_cardname", ""), - ) + name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}" except Exception: name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}" content = message_dict.get("processed_plain_text", "") @@ -58,11 +52,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: and message.message_info.additional_config.get("is_mentioned") is not None ): try: - reply_probability = float(message.message_info.additional_config.get("is_mentioned")) + reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore is_mentioned = True return is_mentioned, reply_probability except Exception as e: - logger.warning(e) + logger.warning(str(e)) logger.warning( f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}" ) @@ -135,20 +129,17 @@ def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, c if not recent_messages: return [] - message_detailed_plain_text = "" - message_detailed_plain_text_list = [] - # 反转消息列表,使最新的消息在最后 recent_messages.reverse() if combine: - for msg_db_data in recent_messages: - message_detailed_plain_text += str(msg_db_data["detailed_plain_text"]) - return message_detailed_plain_text - else: - for msg_db_data in recent_messages: - message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"]) - return message_detailed_plain_text_list + return "".join(str(msg_db_data["detailed_plain_text"]) for msg_db_data in recent_messages) + + message_detailed_plain_text_list = [] + + for msg_db_data in recent_messages: + message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"]) + return message_detailed_plain_text_list def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list: @@ -204,10 +195,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]: len_text = len(text) if len_text < 3: - if random.random() < 0.01: - return list(text) # 如果文本很短且触发随机条件,直接按字符分割 - else: - return [text] + return list(text) if random.random() < 0.01 else [text] # 定义分隔符 separators = {",", ",", " ", "。", ";"} @@ -352,10 +340,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese max_length = global_config.response_splitter.max_length * 2 max_sentence_num = global_config.response_splitter.max_sentence_num # 如果基本上是中文,则进行长度过滤 - if get_western_ratio(cleaned_text) < 0.1: - if len(cleaned_text) > max_length: - logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复") - return ["懒得说"] + if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length: + logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复") + return ["懒得说"] typo_generator = ChineseTypoGenerator( error_rate=global_config.chinese_typo.error_rate, @@ -420,7 +407,7 @@ def calculate_typing_time( # chinese_time *= 1 / typing_speed_multiplier # english_time *= 1 / typing_speed_multiplier # 计算中文字符数 - chinese_chars = sum(1 for char in input_string if "\u4e00" <= char <= "\u9fff") + chinese_chars = sum("\u4e00" <= char <= "\u9fff" for char in input_string) # 如果只有一个中文字符,使用3倍时间 if chinese_chars == 1 and len(input_string.strip()) == 1: @@ -429,11 +416,7 @@ def calculate_typing_time( # 正常计算所有字符的输入时间 total_time = 0.0 for char in input_string: - if "\u4e00" <= char <= "\u9fff": # 判断是否为中文字符 - total_time += chinese_time - else: # 其他字符(如英文) - total_time += english_time - + total_time += chinese_time if "\u4e00" <= char <= "\u9fff" else english_time if is_emoji: total_time = 1 @@ -453,18 +436,14 @@ def cosine_similarity(v1, v2): dot_product = np.dot(v1, v2) norm1 = np.linalg.norm(v1) norm2 = np.linalg.norm(v2) - if norm1 == 0 or norm2 == 0: - return 0 - return dot_product / (norm1 * norm2) + return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2) def text_to_vector(text): """将文本转换为词频向量""" # 分词 words = jieba.lcut(text) - # 统计词频 - word_freq = Counter(words) - return word_freq + return Counter(words) def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list: @@ -491,9 +470,7 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list: def truncate_message(message: str, max_length=20) -> str: """截断消息,使其不超过指定长度""" - if len(message) > max_length: - return message[:max_length] + "..." - return message + return f"{message[:max_length]}..." if len(message) > max_length else message def protect_kaomoji(sentence): @@ -522,7 +499,7 @@ def protect_kaomoji(sentence): placeholder_to_kaomoji = {} for idx, match in enumerate(kaomoji_matches): - kaomoji = match[0] if match[0] else match[1] + kaomoji = match[0] or match[1] placeholder = f"__KAOMOJI_{idx}__" sentence = sentence.replace(kaomoji, placeholder, 1) placeholder_to_kaomoji[placeholder] = kaomoji @@ -563,7 +540,7 @@ def get_western_ratio(paragraph): if not alnum_chars: return 0.0 - western_count = sum(1 for char in alnum_chars if is_english_letter(char)) + western_count = sum(bool(is_english_letter(char)) for char in alnum_chars) return western_count / len(alnum_chars) @@ -610,6 +587,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) - def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str: + # sourcery skip: merge-comparisons, merge-duplicate-blocks, switch """将时间戳转换为人类可读的时间格式 Args: @@ -621,7 +599,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal" """ if mode == "normal": return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) - if mode == "normal_no_YMD": + elif mode == "normal_no_YMD": return time.strftime("%H:%M:%S", time.localtime(timestamp)) elif mode == "relative": now = time.time() @@ -640,7 +618,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal" else: return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":" else: # mode = "lite" or unknown - # 只返回时分秒格式,喵~ + # 只返回时分秒格式 return time.strftime("%H:%M:%S", time.localtime(timestamp)) @@ -670,8 +648,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: elif chat_stream.user_info: # It's a private chat is_group_chat = False user_info = chat_stream.user_info - platform = chat_stream.platform - user_id = user_info.user_id + platform: str = chat_stream.platform # type: ignore + user_id: str = user_info.user_id # type: ignore # Initialize target_info with basic info target_info = { diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 5579ccf84..d5fa301bb 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -3,21 +3,20 @@ import os import time import hashlib import uuid +import io +import asyncio +import numpy as np + from typing import Optional, Tuple from PIL import Image -import io -import numpy as np -import asyncio - +from rich.traceback import install +from src.common.logger import get_logger from src.common.database.database import db from src.common.database.database_model import Images, ImageDescriptions from src.config.config import global_config from src.llm_models.utils_model import LLMRequest -from src.common.logger import get_logger -from rich.traceback import install - install(extra_lines=3) logger = get_logger("chat_image") @@ -111,7 +110,7 @@ class ImageManager: return f"[表情包,含义看起来是:{cached_description}]" # 调用AI获取描述 - if image_format == "gif" or image_format == "GIF": + if image_format in ["gif", "GIF"]: image_base64_processed = self.transform_gif(image_base64) if image_base64_processed is None: logger.warning("GIF转换失败,无法获取描述") @@ -258,6 +257,7 @@ class ImageManager: @staticmethod def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]: + # sourcery skip: use-contextlib-suppress """将GIF转换为水平拼接的静态图像, 跳过相似的帧 Args: @@ -351,7 +351,7 @@ class ImageManager: # 创建拼接图像 total_width = target_width * len(resized_frames) # 防止总宽度为0 - if total_width == 0 and len(resized_frames) > 0: + if total_width == 0 and resized_frames: logger.warning("计算出的总宽度为0,但有选中帧,可能目标宽度太小") # 至少给点宽度吧 total_width = len(resized_frames) @@ -368,10 +368,7 @@ class ImageManager: # 转换为base64 buffer = io.BytesIO() combined_image.save(buffer, format="JPEG", quality=85) # 保存为JPEG - result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") - - return result_base64 - + return base64.b64encode(buffer.getvalue()).decode("utf-8") except MemoryError: logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多") return None # 内存不够啦 @@ -380,6 +377,7 @@ class ImageManager: return None # 其他错误也返回None async def process_image(self, image_base64: str) -> Tuple[str, str]: + # sourcery skip: hoist-if-from-if """处理图片并返回图片ID和描述 Args: @@ -418,17 +416,9 @@ class ImageManager: if existing_image.vlm_processed is None: existing_image.vlm_processed = False - existing_image.count += 1 - existing_image.save() - return existing_image.image_id, f"[picid:{existing_image.image_id}]" - else: - # print(f"图片已存在: {existing_image.image_id}") - # print(f"图片描述: {existing_image.description}") - # print(f"图片计数: {existing_image.count}") - # 更新计数 - existing_image.count += 1 - existing_image.save() - return existing_image.image_id, f"[picid:{existing_image.image_id}]" + existing_image.count += 1 + existing_image.save() + return existing_image.image_id, f"[picid:{existing_image.image_id}]" else: # print(f"图片不存在: {image_hash}") image_id = str(uuid.uuid4()) diff --git a/src/common/database/database.py b/src/common/database/database.py index 249664155..ca3614816 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -54,11 +54,11 @@ class DBWrapper: return getattr(get_db(), name) def __getitem__(self, key): - return get_db()[key] + return get_db()[key] # type: ignore # 全局数据库访问点 -memory_db: Database = DBWrapper() +memory_db: Database = DBWrapper() # type: ignore # 定义数据库文件路径 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 3485fedeb..b411e1b3a 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -406,9 +406,7 @@ def initialize_database(): existing_columns = {row[1] for row in cursor.fetchall()} model_fields = set(model._meta.fields.keys()) - # 检查并添加缺失字段(原有逻辑) - missing_fields = model_fields - existing_columns - if missing_fields: + if missing_fields := model_fields - existing_columns: logger.warning(f"表 '{table_name}' 缺失字段: {missing_fields}") for field_name, field_obj in model._meta.fields.items(): @@ -424,10 +422,7 @@ def initialize_database(): "DateTimeField": "DATETIME", }.get(field_type, "TEXT") alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}" - if field_obj.null: - alter_sql += " NULL" - else: - alter_sql += " NOT NULL" + alter_sql += " NULL" if field_obj.null else " NOT NULL" if hasattr(field_obj, "default") and field_obj.default is not None: # 正确处理不同类型的默认值 default_value = field_obj.default diff --git a/src/common/logger.py b/src/common/logger.py index 40fd15070..a235cf341 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -1,16 +1,16 @@ -import logging - # 使用基于时间戳的文件处理器,简单的轮转份数限制 -from pathlib import Path -from typing import Callable, Optional + +import logging import json import threading import time -from datetime import datetime, timedelta - import structlog import toml +from pathlib import Path +from typing import Callable, Optional +from datetime import datetime, timedelta + # 创建logs目录 LOG_DIR = Path("logs") LOG_DIR.mkdir(exist_ok=True) @@ -160,7 +160,7 @@ def close_handlers(): _console_handler = None -def remove_duplicate_handlers(): +def remove_duplicate_handlers(): # sourcery skip: for-append-to-extend, list-comprehension """移除重复的handler,特别是文件handler""" root_logger = logging.getLogger() @@ -184,7 +184,7 @@ def remove_duplicate_handlers(): # 读取日志配置 -def load_log_config(): +def load_log_config(): # sourcery skip: use-contextlib-suppress """从配置文件加载日志设置""" config_path = Path("config/bot_config.toml") default_config = { @@ -365,7 +365,7 @@ MODULE_COLORS = { "component_registry": "\033[38;5;214m", # 橙黄色 "stream_api": "\033[38;5;220m", # 黄色 "config_api": "\033[38;5;226m", # 亮黄色 - "hearflow_api": "\033[38;5;154m", # 黄绿色 + "heartflow_api": "\033[38;5;154m", # 黄绿色 "action_apis": "\033[38;5;118m", # 绿色 "independent_apis": "\033[38;5;82m", # 绿色 "llm_api": "\033[38;5;46m", # 亮绿色 @@ -412,6 +412,7 @@ class ModuleColoredConsoleRenderer: """自定义控制台渲染器,为不同模块提供不同颜色""" def __init__(self, colors=True): + # sourcery skip: merge-duplicate-blocks, remove-redundant-if self._colors = colors self._config = LOG_CONFIG @@ -443,6 +444,7 @@ class ModuleColoredConsoleRenderer: self._enable_full_content_colors = False def __call__(self, logger, method_name, event_dict): + # sourcery skip: merge-duplicate-blocks """渲染日志消息""" # 获取基本信息 timestamp = event_dict.get("timestamp", "") @@ -662,7 +664,7 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger: """获取logger实例,支持按名称绑定""" if name is None: return raw_logger - logger = binds.get(name) + logger = binds.get(name) # type: ignore if logger is None: logger: structlog.stdlib.BoundLogger = structlog.get_logger(name).bind(logger_name=name) binds[name] = logger @@ -671,8 +673,8 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger: def configure_logging( level: str = "INFO", - console_level: str = None, - file_level: str = None, + console_level: Optional[str] = None, + file_level: Optional[str] = None, max_bytes: int = 5 * 1024 * 1024, backup_count: int = 30, log_dir: str = "logs", @@ -729,14 +731,11 @@ def reload_log_config(): global LOG_CONFIG LOG_CONFIG = load_log_config() - # 重新设置handler的日志级别 - file_handler = get_file_handler() - if file_handler: + if file_handler := get_file_handler(): file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO")) file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO)) - console_handler = get_console_handler() - if console_handler: + if console_handler := get_console_handler(): console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO")) console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO)) @@ -780,8 +779,7 @@ def set_console_log_level(level: str): global LOG_CONFIG LOG_CONFIG["console_log_level"] = level.upper() - console_handler = get_console_handler() - if console_handler: + if console_handler := get_console_handler(): console_handler.setLevel(getattr(logging, level.upper(), logging.INFO)) # 重新设置root logger级别 @@ -800,8 +798,7 @@ def set_file_log_level(level: str): global LOG_CONFIG LOG_CONFIG["file_log_level"] = level.upper() - file_handler = get_file_handler() - if file_handler: + if file_handler := get_file_handler(): file_handler.setLevel(getattr(logging, level.upper(), logging.INFO)) # 重新设置root logger级别 @@ -933,13 +930,12 @@ def format_json_for_logging(data, indent=2, ensure_ascii=False): Returns: str: 格式化后的JSON字符串 """ - if isinstance(data, str): - # 如果是JSON字符串,先解析再格式化 - parsed_data = json.loads(data) - return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii) - else: + if not isinstance(data, str): # 如果是对象,直接格式化 return json.dumps(data, indent=indent, ensure_ascii=ensure_ascii) + # 如果是JSON字符串,先解析再格式化 + parsed_data = json.loads(data) + return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii) def cleanup_old_logs(): diff --git a/src/common/message/api.py b/src/common/message/api.py index 59ba9d1e2..eed85c0a9 100644 --- a/src/common/message/api.py +++ b/src/common/message/api.py @@ -8,7 +8,7 @@ from src.config.config import global_config global_api = None -def get_global_api() -> MessageServer: +def get_global_api() -> MessageServer: # sourcery skip: extract-method """获取全局MessageServer实例""" global global_api if global_api is None: @@ -36,9 +36,8 @@ def get_global_api() -> MessageServer: kwargs["custom_logger"] = maim_message_logger # 添加token认证 - if maim_message_config.auth_token: - if len(maim_message_config.auth_token) > 0: - kwargs["enable_token"] = True + if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0: + kwargs["enable_token"] = True if maim_message_config.use_custom: # 添加WSS模式支持 diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 107ee1c5e..dc5d8b7df 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,9 +1,11 @@ -from src.common.database.database_model import Messages # 更改导入 -from src.common.logger import get_logger import traceback + from typing import List, Any, Optional from peewee import Model # 添加 Peewee Model 导入 +from src.common.database.database_model import Messages +from src.common.logger import get_logger + logger = get_logger(__name__) diff --git a/src/common/remote.py b/src/common/remote.py index 955e760b0..5380cd01e 100644 --- a/src/common/remote.py +++ b/src/common/remote.py @@ -23,7 +23,7 @@ class TelemetryHeartBeatTask(AsyncTask): self.server_url = TELEMETRY_SERVER_URL """遥测服务地址""" - self.client_uuid = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None + self.client_uuid: str | None = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None # type: ignore """客户端UUID""" self.info_dict = self._get_sys_info() @@ -72,7 +72,7 @@ class TelemetryHeartBeatTask(AsyncTask): timeout=aiohttp.ClientTimeout(total=5), # 设置超时时间为5秒 ) as response: logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client") - logger.debug(local_storage["deploy_time"]) + logger.debug(local_storage["deploy_time"]) # type: ignore logger.debug(f"Response status: {response.status}") if response.status == 200: @@ -93,7 +93,7 @@ class TelemetryHeartBeatTask(AsyncTask): except Exception as e: import traceback - error_msg = str(e) if str(e) else "未知错误" + error_msg = str(e) or "未知错误" logger.warning( f"请求UUID出错,不过你还是可以正常使用麦麦: {type(e).__name__}: {error_msg}" ) # 可能是网络问题 @@ -114,11 +114,11 @@ class TelemetryHeartBeatTask(AsyncTask): """向服务器发送心跳""" headers = { "Client-UUID": self.client_uuid, - "User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}", + "User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}", # type: ignore } logger.debug(f"正在发送心跳到服务器: {self.server_url}") - logger.debug(headers) + logger.debug(str(headers)) try: async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session: @@ -151,7 +151,7 @@ class TelemetryHeartBeatTask(AsyncTask): except Exception as e: import traceback - error_msg = str(e) if str(e) else "未知错误" + error_msg = str(e) or "未知错误" logger.warning(f"(此消息不会影响正常使用)状态未发生: {type(e).__name__}: {error_msg}") logger.debug(f"完整错误信息: {traceback.format_exc()}") diff --git a/src/config/auto_update.py b/src/config/auto_update.py index 2088e3628..139003a84 100644 --- a/src/config/auto_update.py +++ b/src/config/auto_update.py @@ -1,5 +1,6 @@ import shutil import tomlkit +from tomlkit.items import Table from pathlib import Path from datetime import datetime @@ -45,8 +46,8 @@ def update_config(): # 检查version是否相同 if old_config and "inner" in old_config and "inner" in new_config: - old_version = old_config["inner"].get("version") - new_version = new_config["inner"].get("version") + old_version = old_config["inner"].get("version") # type: ignore + new_version = new_config["inner"].get("version") # type: ignore if old_version and new_version and old_version == new_version: print(f"检测到版本号相同 (v{old_version}),跳过更新") # 如果version相同,恢复旧配置文件并返回 @@ -62,7 +63,7 @@ def update_config(): if key == "version": continue if key in target: - if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)): + if isinstance(value, dict) and isinstance(target[key], (dict, Table)): update_dict(target[key], value) else: try: @@ -85,10 +86,7 @@ def update_config(): if value and isinstance(value[0], dict) and "regex" in value[0]: contains_regex = True - if contains_regex: - target[key] = value - else: - target[key] = tomlkit.array(value) + target[key] = value if contains_regex else tomlkit.array(str(value)) else: # 其他类型使用item方法创建新值 target[key] = tomlkit.item(value) diff --git a/src/config/config.py b/src/config/config.py index de173a520..b61111ec3 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,16 +1,14 @@ import os -from dataclasses import field, dataclass - import tomlkit import shutil -from datetime import datetime +from datetime import datetime from tomlkit import TOMLDocument from tomlkit.items import Table - -from src.common.logger import get_logger +from dataclasses import field, dataclass from rich.traceback import install +from src.common.logger import get_logger from src.config.config_base import ConfigBase from src.config.official_configs import ( BotConfig, @@ -80,8 +78,8 @@ def update_config(): # 检查version是否相同 if old_config and "inner" in old_config and "inner" in new_config: - old_version = old_config["inner"].get("version") - new_version = new_config["inner"].get("version") + old_version = old_config["inner"].get("version") # type: ignore + new_version = new_config["inner"].get("version") # type: ignore if old_version and new_version and old_version == new_version: logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") return @@ -103,7 +101,7 @@ def update_config(): shutil.copy2(template_path, new_config_path) logger.info(f"已创建新配置文件: {new_config_path}") - def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict): + def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): """ 将source字典的值更新到target字典中(如果target中存在相同的键) """ @@ -112,8 +110,9 @@ def update_config(): if key == "version": continue if key in target: - if isinstance(value, dict) and isinstance(target[key], (dict, Table)): - update_dict(target[key], value) + target_value = target[key] + if isinstance(value, dict) and isinstance(target_value, (dict, Table)): + update_dict(target_value, value) else: try: # 对数组类型进行特殊处理 diff --git a/src/config/config_base.py b/src/config/config_base.py index 129f5a1c0..5fb398190 100644 --- a/src/config/config_base.py +++ b/src/config/config_base.py @@ -43,7 +43,7 @@ class ConfigBase: field_type = f.type try: - init_args[field_name] = cls._convert_field(value, field_type) + init_args[field_name] = cls._convert_field(value, field_type) # type: ignore except TypeError as e: raise TypeError(f"Field '{field_name}' has a type error: {e}") from e except Exception as e: diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 7e2efbeba..6838df1d1 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1,7 +1,8 @@ -from dataclasses import dataclass, field -from typing import Any, Literal import re +from dataclasses import dataclass, field +from typing import Any, Literal, Optional + from src.config.config_base import ConfigBase """ @@ -113,7 +114,7 @@ class ChatConfig(ConfigBase): exit_focus_threshold: float = 1.0 """自动退出专注聊天的阈值,越低越容易退出专注聊天""" - def get_current_talk_frequency(self, chat_stream_id: str = None) -> float: + def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float: """ 根据当前时间和聊天流获取对应的 talk_frequency @@ -138,7 +139,7 @@ class ChatConfig(ConfigBase): # 如果都没有匹配,返回默认值 return self.talk_frequency - def _get_time_based_frequency(self, time_freq_list: list[str]) -> float: + def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]: """ 根据时间配置列表获取当前时段的频率 @@ -186,7 +187,7 @@ class ChatConfig(ConfigBase): return current_frequency - def _get_stream_specific_frequency(self, chat_stream_id: str) -> float: + def _get_stream_specific_frequency(self, chat_stream_id: str): """ 获取特定聊天流在当前时间的频率 @@ -217,7 +218,7 @@ class ChatConfig(ConfigBase): return None - def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> str: + def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]: """ 解析流配置字符串并生成对应的 chat_id diff --git a/src/individuality/identity.py b/src/individuality/identity.py index bb3125985..730615e3d 100644 --- a/src/individuality/identity.py +++ b/src/individuality/identity.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List +from typing import List, Optional @dataclass @@ -8,7 +8,7 @@ class Identity: identity_detail: List[str] # 身份细节描述 - def __init__(self, identity_detail: List[str] = None): + def __init__(self, identity_detail: Optional[List[str]] = None): """初始化身份特征 Args: diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index 8365c0888..532b203fd 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -1,17 +1,18 @@ -from typing import Optional import ast - -from src.llm_models.utils_model import LLMRequest -from .personality import Personality -from .identity import Identity import random import json import os import hashlib + +from typing import Optional from rich.traceback import install + from src.common.logger import get_logger -from src.person_info.person_info import get_person_info_manager from src.config.config import global_config +from src.llm_models.utils_model import LLMRequest +from src.person_info.person_info import get_person_info_manager +from .personality import Personality +from .identity import Identity install(extra_lines=3) @@ -23,7 +24,7 @@ class Individuality: def __init__(self): # 正常初始化实例属性 - self.personality: Optional[Personality] = None + self.personality: Personality = None # type: ignore self.identity: Optional[Identity] = None self.name = "" @@ -109,7 +110,7 @@ class Individuality: existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression") if existing_short_impression: try: - existing_data = ast.literal_eval(existing_short_impression) + existing_data = ast.literal_eval(existing_short_impression) # type: ignore if isinstance(existing_data, list) and len(existing_data) >= 1: personality_result = existing_data[0] except (json.JSONDecodeError, TypeError, IndexError): @@ -128,7 +129,7 @@ class Individuality: existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression") if existing_short_impression: try: - existing_data = ast.literal_eval(existing_short_impression) + existing_data = ast.literal_eval(existing_short_impression) # type: ignore if isinstance(existing_data, list) and len(existing_data) >= 2: identity_result = existing_data[1] except (json.JSONDecodeError, TypeError, IndexError): @@ -204,6 +205,7 @@ class Individuality: return prompt_personality def get_identity_prompt(self, level: int, x_person: int = 2) -> str: + # sourcery skip: assign-if-exp, merge-else-if-into-elif """ 获取身份特征的prompt @@ -240,13 +242,13 @@ class Individuality: if identity_parts: details_str = ",".join(identity_parts) - if x_person in [1, 2]: + if x_person in {1, 2}: return f"{i_pronoun},{details_str}。" else: # x_person == 0 # 无人称时,直接返回细节,不加代词和开头的逗号 return f"{details_str}。" else: - if x_person in [1, 2]: + if x_person in {1, 2}: return f"{i_pronoun}的身份信息不完整。" else: # x_person == 0 return "身份信息不完整。" @@ -441,14 +443,15 @@ class Individuality: if info_list_json: try: info_list = json.loads(info_list_json) if isinstance(info_list_json, str) else info_list_json - for item in info_list: - if isinstance(item, dict) and "info_type" in item: - keywords.append(item["info_type"]) + keywords.extend( + item["info_type"] for item in info_list if isinstance(item, dict) and "info_type" in item + ) except (json.JSONDecodeError, TypeError): logger.error(f"解析info_list失败: {info_list_json}") return keywords async def _create_personality(self, personality_core: str, personality_sides: list) -> str: + # sourcery skip: merge-list-append, move-assign """使用LLM创建压缩版本的impression Args: diff --git a/src/individuality/personality.py b/src/individuality/personality.py index 0ee46a3d0..ace719331 100644 --- a/src/individuality/personality.py +++ b/src/individuality/personality.py @@ -1,6 +1,7 @@ -from dataclasses import dataclass -from typing import Dict, List import json + +from dataclasses import dataclass +from typing import Dict, List, Optional from pathlib import Path @@ -24,7 +25,7 @@ class Personality: cls._instance = super().__new__(cls) return cls._instance - def __init__(self, personality_core: str = "", personality_sides: List[str] = None): + def __init__(self, personality_core: str = "", personality_sides: Optional[List[str]] = None): if personality_sides is None: personality_sides = [] self.personality_core = personality_core @@ -41,7 +42,7 @@ class Personality: cls._instance = cls() return cls._instance - def _init_big_five_personality(self): + def _init_big_five_personality(self): # sourcery skip: extract-method """初始化大五人格特质""" # 构建文件路径 personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per" @@ -63,7 +64,6 @@ class Personality: else: self.extraversion = 0.3 self.neuroticism = 0.5 - if "认真" in self.personality_core or "负责" in self.personality_sides: self.conscientiousness = 0.9 else: diff --git a/src/manager/async_task_manager.py b/src/manager/async_task_manager.py index 1e1e9132f..0a2c0d215 100644 --- a/src/manager/async_task_manager.py +++ b/src/manager/async_task_manager.py @@ -120,12 +120,7 @@ class AsyncTaskManager: """ 获取所有任务的状态 """ - tasks_status = {} - for task_name, task in self.tasks.items(): - tasks_status[task_name] = { - "status": "running" if not task.done() else "done", - } - return tasks_status + return {task_name: {"status": "done" if task.done() else "running"} for task_name, task in self.tasks.items()} async def stop_and_wait_all_tasks(self): """ diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index ffdf8ff36..e3a66370b 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -2,12 +2,12 @@ import math import random import time -from src.chat.message_receive.message import MessageRecv -from src.llm_models.utils_model import LLMRequest -from ..common.logger import get_logger -from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive +from src.common.logger import get_logger from src.config.config import global_config +from src.chat.message_receive.message import MessageRecv from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive +from src.llm_models.utils_model import LLMRequest from src.manager.async_task_manager import AsyncTask, async_task_manager logger = get_logger("mood") @@ -55,12 +55,12 @@ class ChatMood: request_type="mood", ) - self.last_change_time = 0 + self.last_change_time: float = 0 async def update_mood_by_message(self, message: MessageRecv, interested_rate: float): self.regression_count = 0 - during_last_time = message.message_info.time - self.last_change_time + during_last_time = message.message_info.time - self.last_change_time # type: ignore base_probability = 0.05 time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time)) @@ -78,7 +78,7 @@ class ChatMood: if random.random() > update_probability: return - message_time = message.message_info.time + message_time: float = message.message_info.time # type: ignore message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, @@ -119,7 +119,7 @@ class ChatMood: self.mood_state = response - self.last_change_time = message_time + self.last_change_time = message_time # type: ignore async def regress_mood(self): message_time = time.time() diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index f44a88225..5e5f033f9 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -1,17 +1,18 @@ -from src.common.logger import get_logger -from src.common.database.database import db -from src.common.database.database_model import PersonInfo # 新增导入 import copy import hashlib -from typing import Any, Callable, Dict, Union import datetime import asyncio +import json + +from json_repair import repair_json +from typing import Any, Callable, Dict, Union, Optional + +from src.common.logger import get_logger +from src.common.database.database import db +from src.common.database.database_model import PersonInfo from src.llm_models.utils_model import LLMRequest from src.config.config import global_config -import json # 新增导入 -from json_repair import repair_json - """ PersonInfoManager 类方法功能摘要: @@ -42,7 +43,7 @@ person_info_default = { "last_know": None, # "user_cardname": None, # This field is not in Peewee model PersonInfo # "user_avatar": None, # This field is not in Peewee model PersonInfo - "impression": None, # Corrected from persion_impression + "impression": None, # Corrected from person_impression "short_impression": None, "info_list": None, "points": None, @@ -106,27 +107,24 @@ class PersonInfoManager: logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}") return False - def get_person_id_by_person_name(self, person_name: str): + def get_person_id_by_person_name(self, person_name: str) -> str: """根据用户名获取用户ID""" try: record = PersonInfo.get_or_none(PersonInfo.person_name == person_name) - if record: - return record.person_id - else: - return "" + return record.person_id if record else "" except Exception as e: logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}") return "" @staticmethod - async def create_person_info(person_id: str, data: dict = None): + async def create_person_info(person_id: str, data: Optional[dict] = None): """创建一个项""" if not person_id: - logger.debug("创建失败,personid不存在") + logger.debug("创建失败,person_id不存在") return _person_info_default = copy.deepcopy(person_info_default) - model_fields = PersonInfo._meta.fields.keys() + model_fields = PersonInfo._meta.fields.keys() # type: ignore final_data = {"person_id": person_id} @@ -163,9 +161,9 @@ class PersonInfoManager: await asyncio.to_thread(_db_create_sync, final_data) - async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None): + async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None): """更新某一个字段,会补全""" - if field_name not in PersonInfo._meta.fields: + if field_name not in PersonInfo._meta.fields: # type: ignore logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。") return @@ -228,15 +226,13 @@ class PersonInfoManager: @staticmethod async def has_one_field(person_id: str, field_name: str): """判断是否存在某一个字段""" - if field_name not in PersonInfo._meta.fields: + if field_name not in PersonInfo._meta.fields: # type: ignore logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。") return False def _db_has_field_sync(p_id: str, f_name: str): record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) - if record: - return True - return False + return bool(record) try: return await asyncio.to_thread(_db_has_field_sync, person_id, field_name) @@ -435,9 +431,7 @@ class PersonInfoManager: except Exception as e: logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}") # Fallback to default in case of any error during DB access - if field_name in person_info_default: - return default_value_for_field - return None + return default_value_for_field if field_name in person_info_default else None @staticmethod def get_value_sync(person_id: str, field_name: str): @@ -446,8 +440,7 @@ class PersonInfoManager: if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None: default_value_for_field = [] - record = PersonInfo.get_or_none(PersonInfo.person_id == person_id) - if record: + if record := PersonInfo.get_or_none(PersonInfo.person_id == person_id): val = getattr(record, field_name, None) if field_name in JSON_SERIALIZED_FIELDS: if isinstance(val, str): @@ -481,7 +474,7 @@ class PersonInfoManager: record = await asyncio.to_thread(_db_get_record_sync, person_id) for field_name in field_names: - if field_name not in PersonInfo._meta.fields: + if field_name not in PersonInfo._meta.fields: # type: ignore if field_name in person_info_default: result[field_name] = copy.deepcopy(person_info_default[field_name]) logger.debug(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。") @@ -509,7 +502,7 @@ class PersonInfoManager: """ 获取满足条件的字段值字典 """ - if field_name not in PersonInfo._meta.fields: + if field_name not in PersonInfo._meta.fields: # type: ignore logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义") return {} @@ -531,7 +524,7 @@ class PersonInfoManager: return {} async def get_or_create_person( - self, platform: str, user_id: int, nickname: str = None, user_cardname: str = None, user_avatar: str = None + self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None ) -> str: """ 根据 platform 和 user_id 获取 person_id。 @@ -561,7 +554,7 @@ class PersonInfoManager: "points": [], "forgotten_points": [], } - model_fields = PersonInfo._meta.fields.keys() + model_fields = PersonInfo._meta.fields.keys() # type: ignore filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} await self.create_person_info(person_id, data=filtered_initial_data) @@ -610,7 +603,9 @@ class PersonInfoManager: "name_reason", ] valid_fields_to_get = [ - f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default + f + for f in required_fields + if f in PersonInfo._meta.fields or f in person_info_default # type: ignore ] person_data = await self.get_values(found_person_id, valid_fields_to_get) diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index 0b443850f..7b69b47bb 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -3,12 +3,12 @@ import traceback import os import pickle import random -from typing import List, Dict +from typing import List, Dict, Any from src.config.config import global_config from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import get_chat_manager from src.person_info.relationship_manager import get_relationship_manager from src.person_info.person_info import get_person_info_manager, PersonInfoManager +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import ( get_raw_msg_by_timestamp_with_chat, get_raw_msg_by_timestamp_with_chat_inclusive, @@ -45,7 +45,7 @@ class RelationshipBuilder: self.chat_id = chat_id # 新的消息段缓存结构: # {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]} - self.person_engaged_cache: Dict[str, List[Dict[str, any]]] = {} + self.person_engaged_cache: Dict[str, List[Dict[str, Any]]] = {} # 持久化存储文件路径 self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl") @@ -210,11 +210,7 @@ class RelationshipBuilder: if person_id not in self.person_engaged_cache: return 0 - total_count = 0 - for segment in self.person_engaged_cache[person_id]: - total_count += segment["message_count"] - - return total_count + return sum(segment["message_count"] for segment in self.person_engaged_cache[person_id]) def _cleanup_old_segments(self) -> bool: """清理老旧的消息段""" @@ -289,7 +285,7 @@ class RelationshipBuilder: self.last_cleanup_time = current_time # 保存缓存 - if cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0: + if cleanup_stats["segments_removed"] > 0 or users_to_remove: self._save_cache() logger.info( f"{self.log_prefix} 清理完成 - 影响用户: {cleanup_stats['users_cleaned']}, 移除消息段: {cleanup_stats['segments_removed']}, 移除用户: {len(users_to_remove)}" @@ -313,6 +309,7 @@ class RelationshipBuilder: return False def get_cache_status(self) -> str: + # sourcery skip: merge-list-append, merge-list-appends-into-extend """获取缓存状态信息,用于调试和监控""" if not self.person_engaged_cache: return f"{self.log_prefix} 关系缓存为空" @@ -357,13 +354,12 @@ class RelationshipBuilder: self._cleanup_old_segments() current_time = time.time() - latest_messages = get_raw_msg_by_timestamp_with_chat( + if latest_messages := get_raw_msg_by_timestamp_with_chat( self.chat_id, self.last_processed_message_time, current_time, limit=50, # 获取自上次处理后的消息 - ) - if latest_messages: + ): # 处理所有新的非bot消息 for latest_msg in latest_messages: user_id = latest_msg.get("user_id") @@ -414,7 +410,7 @@ class RelationshipBuilder: # 负责触发关系构建、整合消息段、更新用户印象 # ================================ - async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, any]]): + async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]): """基于消息段更新用户印象""" original_segment_count = len(segments) logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象") diff --git a/src/person_info/relationship_builder_manager.py b/src/person_info/relationship_builder_manager.py index 926d67fca..f3bca25d2 100644 --- a/src/person_info/relationship_builder_manager.py +++ b/src/person_info/relationship_builder_manager.py @@ -1,4 +1,5 @@ -from typing import Dict, Optional, List +from typing import Dict, Optional, List, Any + from src.common.logger import get_logger from .relationship_builder import RelationshipBuilder @@ -63,7 +64,7 @@ class RelationshipBuilderManager: """ return list(self.builders.keys()) - def get_status(self) -> Dict[str, any]: + def get_status(self) -> Dict[str, Any]: """获取管理器状态 Returns: @@ -94,9 +95,7 @@ class RelationshipBuilderManager: bool: 是否成功清理 """ builder = self.get_builder(chat_id) - if builder: - return builder.force_cleanup_user_segments(person_id) - return False + return builder.force_cleanup_user_segments(person_id) if builder else False # 全局管理器实例 diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 65be0b3af..5e369e752 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -1,16 +1,19 @@ -from src.config.config import global_config -from src.llm_models.utils_model import LLMRequest import time import traceback -from src.common.logger import get_logger -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.person_info.person_info import get_person_info_manager -from typing import List, Dict -from json_repair import repair_json -from src.chat.message_receive.chat_stream import get_chat_manager import json import random +from typing import List, Dict, Any +from json_repair import repair_json + +from src.common.logger import get_logger +from src.config.config import global_config +from src.llm_models.utils_model import LLMRequest +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.message_receive.chat_stream import get_chat_manager +from src.person_info.person_info import get_person_info_manager + + logger = get_logger("relationship_fetcher") @@ -62,11 +65,11 @@ class RelationshipFetcher: self.chat_id = chat_id # 信息获取缓存:记录正在获取的信息请求 - self.info_fetching_cache: List[Dict[str, any]] = [] + self.info_fetching_cache: List[Dict[str, Any]] = [] # 信息结果缓存:存储已获取的信息结果,带TTL - self.info_fetched_cache: Dict[str, Dict[str, any]] = {} - # 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknow": bool}}} + self.info_fetched_cache: Dict[str, Dict[str, Any]] = {} + # 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}} # LLM模型配置 self.llm_model = LLMRequest( @@ -184,7 +187,7 @@ class RelationshipFetcher: nickname_str = ",".join(global_config.bot.alias_names) name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。" person_info_manager = get_person_info_manager() - person_name = await person_info_manager.get_value(person_id, "person_name") + person_name: str = await person_info_manager.get_value(person_id, "person_name") # type: ignore info_cache_block = self._build_info_cache_block() @@ -208,8 +211,7 @@ class RelationshipFetcher: logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息:{content_json.get('none', '')}") return None - info_type = content_json.get("info_type") - if info_type: + if info_type := content_json.get("info_type"): # 记录信息获取请求 self.info_fetching_cache.append( { @@ -287,7 +289,7 @@ class RelationshipFetcher: "ttl": 2, "start_time": start_time, "person_name": person_name, - "unknow": cached_info == "none", + "unknown": cached_info == "none", } logger.info(f"{self.log_prefix} 记得 {person_name} 的 {info_type}: {cached_info}") return @@ -321,7 +323,7 @@ class RelationshipFetcher: "ttl": 2, "start_time": start_time, "person_name": person_name, - "unknow": True, + "unknown": True, } logger.info(f"{self.log_prefix} 完全不认识 {person_name}") await self._save_info_to_cache(person_id, info_type, "none") @@ -353,15 +355,15 @@ class RelationshipFetcher: if person_id not in self.info_fetched_cache: self.info_fetched_cache[person_id] = {} self.info_fetched_cache[person_id][info_type] = { - "info": "unknow" if is_unknown else info_content, + "info": "unknown" if is_unknown else info_content, "ttl": 3, "start_time": start_time, "person_name": person_name, - "unknow": is_unknown, + "unknown": is_unknown, } # 保存到持久化缓存 (info_list) - await self._save_info_to_cache(person_id, info_type, info_content if not is_unknown else "none") + await self._save_info_to_cache(person_id, info_type, "none" if is_unknown else info_content) if not is_unknown: logger.info(f"{self.log_prefix} 思考得到,{person_name} 的 {info_type}: {info_content}") @@ -393,7 +395,7 @@ class RelationshipFetcher: for info_type in self.info_fetched_cache[person_id]: person_name = self.info_fetched_cache[person_id][info_type]["person_name"] - if not self.info_fetched_cache[person_id][info_type]["unknow"]: + if not self.info_fetched_cache[person_id][info_type]["unknown"]: info_content = self.info_fetched_cache[person_id][info_type]["info"] person_known_infos.append(f"[{info_type}]:{info_content}") else: @@ -430,6 +432,7 @@ class RelationshipFetcher: return persons_infos_str async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str): + # sourcery skip: use-next """将提取到的信息保存到 person_info 的 info_list 字段中 Args: diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index 039197250..2c544fe46 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -1,5 +1,5 @@ from src.common.logger import get_logger -from src.person_info.person_info import PersonInfoManager, get_person_info_manager +from .person_info import PersonInfoManager, get_person_info_manager import time import random from src.llm_models.utils_model import LLMRequest @@ -12,7 +12,7 @@ from difflib import SequenceMatcher import jieba from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity - +from typing import List, Dict, Any logger = get_logger("relation") @@ -28,8 +28,7 @@ class RelationshipManager: async def is_known_some_one(platform, user_id): """判断是否认识某人""" person_info_manager = get_person_info_manager() - is_known = await person_info_manager.is_person_known(platform, user_id) - return is_known + return await person_info_manager.is_person_known(platform, user_id) @staticmethod async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str): @@ -110,7 +109,7 @@ class RelationshipManager: return relation_prompt - async def update_person_impression(self, person_id, timestamp, bot_engaged_messages=None): + async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]): """更新用户印象 Args: @@ -123,7 +122,7 @@ class RelationshipManager: person_info_manager = get_person_info_manager() person_name = await person_info_manager.get_value(person_id, "person_name") nickname = await person_info_manager.get_value(person_id, "nickname") - know_times = await person_info_manager.get_value(person_id, "know_times") or 0 + know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore alias_str = ", ".join(global_config.bot.alias_names) # personality_block =get_individuality().get_personality_prompt(x_person=2, level=2) @@ -142,13 +141,13 @@ class RelationshipManager: # 遍历消息,构建映射 for msg in user_messages: await person_info_manager.get_or_create_person( - platform=msg.get("chat_info_platform"), - user_id=msg.get("user_id"), - nickname=msg.get("user_nickname"), - user_cardname=msg.get("user_cardname"), + platform=msg.get("chat_info_platform"), # type: ignore + user_id=msg.get("user_id"), # type: ignore + nickname=msg.get("user_nickname"), # type: ignore + user_cardname=msg.get("user_cardname"), # type: ignore ) - replace_user_id = msg.get("user_id") - replace_platform = msg.get("chat_info_platform") + replace_user_id: str = msg.get("user_id") # type: ignore + replace_platform: str = msg.get("chat_info_platform") # type: ignore replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id) replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name") @@ -354,8 +353,8 @@ class RelationshipManager: person_name = await person_info_manager.get_value(person_id, "person_name") nickname = await person_info_manager.get_value(person_id, "nickname") - know_times = await person_info_manager.get_value(person_id, "know_times") or 0 - attitude = await person_info_manager.get_value(person_id, "attitude") or 50 + know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore + attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore # 根据熟悉度,调整印象和简短印象的最大长度 if know_times > 300: @@ -414,16 +413,14 @@ class RelationshipManager: if len(remaining_points) < 10: # 如果还没达到30条,直接保留 remaining_points.append(point) + elif random.random() < keep_probability: + # 保留这个点,随机移除一个已保留的点 + idx_to_remove = random.randrange(len(remaining_points)) + points_to_move.append(remaining_points[idx_to_remove]) + remaining_points[idx_to_remove] = point else: - # 随机决定是否保留 - if random.random() < keep_probability: - # 保留这个点,随机移除一个已保留的点 - idx_to_remove = random.randrange(len(remaining_points)) - points_to_move.append(remaining_points[idx_to_remove]) - remaining_points[idx_to_remove] = point - else: - # 不保留这个点 - points_to_move.append(point) + # 不保留这个点 + points_to_move.append(point) # 更新points和forgotten_points current_points = remaining_points @@ -520,7 +517,7 @@ class RelationshipManager: new_attitude = int(relation_value_json.get("attitude", 50)) # 获取当前的关系值 - old_attitude = await person_info_manager.get_value(person_id, "attitude") or 50 + old_attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore # 更新熟悉度 if new_attitude > 25: diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index c341e5214..6c8cc01da 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -65,9 +65,9 @@ def get_replyer( async def generate_reply( - chat_stream=None, - chat_id: str = None, - action_data: Dict[str, Any] = None, + chat_stream: Optional[ChatStream] = None, + chat_id: Optional[str] = None, + action_data: Optional[Dict[str, Any]] = None, reply_to: str = "", extra_info: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, @@ -78,25 +78,25 @@ async def generate_reply( model_configs: Optional[List[Dict[str, Any]]] = None, request_type: str = "", enable_timeout: bool = False, -) -> Tuple[bool, List[Tuple[str, Any]]]: +) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: """生成回复 Args: chat_stream: 聊天流对象(优先) - action_data: 动作数据 chat_id: 聊天ID(备用) + action_data: 动作数据 enable_splitter: 是否启用消息分割器 enable_chinese_typo: 是否启用错字生成器 return_prompt: 是否返回提示词 Returns: - Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合) + Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词) """ try: # 获取回复器 replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type) if not replyer: logger.error("[GeneratorAPI] 无法获取回复器") - return False, [] + return False, [], None logger.debug("[GeneratorAPI] 开始生成回复") @@ -109,8 +109,9 @@ async def generate_reply( enable_timeout=enable_timeout, enable_tool=enable_tool, ) - - reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo) + reply_set = [] + if content: + reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo) if success: logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项") @@ -118,19 +119,19 @@ async def generate_reply( logger.warning("[GeneratorAPI] 回复生成失败") if return_prompt: - return success, reply_set or [], prompt + return success, reply_set, prompt else: - return success, reply_set or [] + return success, reply_set, None except Exception as e: logger.error(f"[GeneratorAPI] 生成回复时出错: {e}") - return False, [] + return False, [], None async def rewrite_reply( - chat_stream=None, - reply_data: Dict[str, Any] = None, - chat_id: str = None, + chat_stream: Optional[ChatStream] = None, + reply_data: Optional[Dict[str, Any]] = None, + chat_id: Optional[str] = None, enable_splitter: bool = True, enable_chinese_typo: bool = True, model_configs: Optional[List[Dict[str, Any]]] = None, @@ -158,15 +159,16 @@ async def rewrite_reply( # 调用回复器重写回复 success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {}) - - reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo) + reply_set = [] + if content: + reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo) if success: logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项") else: logger.warning("[GeneratorAPI] 重写回复失败") - return success, reply_set or [] + return success, reply_set except Exception as e: logger.error(f"[GeneratorAPI] 重写回复时出错: {e}") diff --git a/src/tools/tool_executor.py b/src/tools/tool_executor.py index 29ee8be1b..403ed554f 100644 --- a/src/tools/tool_executor.py +++ b/src/tools/tool_executor.py @@ -34,7 +34,7 @@ class ToolExecutor: 可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。 """ - def __init__(self, chat_id: str = None, enable_cache: bool = True, cache_ttl: int = 3): + def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3): """初始化工具执行器 Args: @@ -62,8 +62,8 @@ class ToolExecutor: logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}") async def execute_from_chat_message( - self, target_message: str, chat_history: list[str], sender: str, return_details: bool = False - ) -> List[Dict] | Tuple[List[Dict], List[str], str]: + self, target_message: str, chat_history: str, sender: str, return_details: bool = False + ) -> Tuple[List[Dict], List[str], str]: """从聊天消息执行工具 Args: @@ -79,16 +79,14 @@ class ToolExecutor: # 首先检查缓存 cache_key = self._generate_cache_key(target_message, chat_history, sender) - cached_result = self._get_from_cache(cache_key) - - if cached_result: + if cached_result := self._get_from_cache(cache_key): logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行") - if return_details: - # 从缓存结果中提取工具名称 - used_tools = [result.get("tool_name", "unknown") for result in cached_result] - return cached_result, used_tools, "使用缓存结果" - else: - return cached_result + if not return_details: + return cached_result, [], "使用缓存结果" + + # 从缓存结果中提取工具名称 + used_tools = [result.get("tool_name", "unknown") for result in cached_result] + return cached_result, used_tools, "使用缓存结果" # 缓存未命中,执行工具调用 # 获取可用工具 @@ -134,7 +132,7 @@ class ToolExecutor: if return_details: return tool_results, used_tools, prompt else: - return tool_results + return tool_results, [], "" async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]: """执行工具调用 @@ -207,7 +205,7 @@ class ToolExecutor: return tool_results, used_tools - def _generate_cache_key(self, target_message: str, chat_history: list[str], sender: str) -> str: + def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str: """生成缓存键 Args: @@ -267,10 +265,7 @@ class ToolExecutor: return expired_keys = [] - for cache_key, cache_item in self.tool_cache.items(): - if cache_item["ttl"] <= 0: - expired_keys.append(cache_key) - + expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0) for key in expired_keys: del self.tool_cache[key] @@ -355,7 +350,7 @@ class ToolExecutor: "ttl_distribution": ttl_distribution, } - def set_cache_config(self, enable_cache: bool = None, cache_ttl: int = None): + def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1): """动态修改缓存配置 Args: @@ -366,7 +361,7 @@ class ToolExecutor: self.enable_cache = enable_cache logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}") - if cache_ttl is not None and cache_ttl > 0: + if cache_ttl > 0: self.cache_ttl = cache_ttl logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}") @@ -380,7 +375,7 @@ init_tool_executor_prompt() # 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3) executor = ToolExecutor(executor_id="my_executor") -results = await executor.execute_from_chat_message( +results, _, _ = await executor.execute_from_chat_message( talking_message_str="今天天气怎么样?现在几点了?", is_group_chat=False ) From 2d39cefce0a0e96358daeb12f88a90b864f6b9ac Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 12 Jul 2025 16:21:28 +0000 Subject: [PATCH 13/13] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/focus_chat/heartFC_chat.py | 168 +++++++----------- src/chat/focus_chat/hfc_utils.py | 2 - src/chat/focus_chat/priority_manager.py | 2 +- .../heart_flow/heartflow_message_processor.py | 4 +- src/chat/heart_flow/sub_heartflow.py | 11 +- src/chat/message_receive/message.py | 9 +- .../message_receive/uni_message_sender.py | 1 - src/chat/normal_chat/normal_chat.py | 84 +++++---- src/chat/planner_actions/action_modifier.py | 12 +- src/chat/planner_actions/planner.py | 2 +- src/chat/replyer/default_generator.py | 1 - src/chat/utils/chat_message_builder.py | 27 ++- src/common/message_repository.py | 6 +- src/config/official_configs.py | 1 - src/main.py | 1 - src/mood/mood_manager.py | 2 +- src/plugin_system/apis/message_api.py | 24 ++- 17 files changed, 168 insertions(+), 189 deletions(-) diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index a9654449e..9d22a593e 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -17,13 +17,12 @@ from src.chat.focus_chat.hfc_utils import CycleDetail import random from src.chat.focus_chat.hfc_utils import get_recent_message_stats from src.person_info.person_info import get_person_info_manager -from src.plugin_system.apis import generator_api,send_api,message_api +from src.plugin_system.apis import generator_api, send_api, message_api from src.chat.willing.willing_manager import get_willing_manager from .priority_manager import PriorityManager from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat - ERROR_LOOP_INFO = { "loop_plan_info": { "action_result": { @@ -85,7 +84,7 @@ class HeartFChatting: self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id) self.loop_mode = "normal" - + # 新增:消息计数器和疲惫阈值 self._message_count = 0 # 发送的消息计数 # 基于exit_focus_threshold动态计算疲惫阈值 @@ -93,7 +92,6 @@ class HeartFChatting: self._message_threshold = max(10, int(30 * global_config.chat.exit_focus_threshold)) self._fatigue_triggered = False # 是否已触发疲惫退出 - self.action_manager = ActionManager() self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager) self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id) @@ -109,14 +107,12 @@ class HeartFChatting: self.reply_timeout_count = 0 self.plan_timeout_count = 0 - - self.last_read_time = time.time()-1 - - + + self.last_read_time = time.time() - 1 + self.willing_amplifier = 1 self.willing_manager = get_willing_manager() - - + self.reply_mode = self.chat_stream.context.get_priority_mode() if self.reply_mode == "priority": self.priority_manager = PriorityManager( @@ -125,13 +121,11 @@ class HeartFChatting: self.loop_mode = "priority" else: self.priority_manager = None - logger.info( f"{self.log_prefix} HeartFChatting 初始化完成,消息疲惫阈值: {self._message_threshold}条(基于exit_focus_threshold={global_config.chat.exit_focus_threshold}计算,仅在auto模式下生效)" ) - - + self.energy_value = 100 async def start(self): @@ -168,68 +162,69 @@ class HeartFChatting: logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)") except asyncio.CancelledError: logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天") - + def start_cycle(self): self._cycle_counter += 1 self._current_cycle_detail = CycleDetail(self._cycle_counter) self._current_cycle_detail.thinking_id = "tid" + str(round(time.time(), 2)) cycle_timers = {} return cycle_timers, self._current_cycle_detail.thinking_id - - def end_cycle(self,loop_info,cycle_timers): + + def end_cycle(self, loop_info, cycle_timers): self._current_cycle_detail.set_loop_info(loop_info) self.history_loop.append(self._current_cycle_detail) self._current_cycle_detail.timers = cycle_timers self._current_cycle_detail.end_time = time.time() - - def print_cycle_info(self,cycle_timers): - # 记录循环信息和计时器结果 - timer_strings = [] - for name, elapsed in cycle_timers.items(): - formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒" - timer_strings.append(f"{name}: {formatted_time}") - logger.info( - f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考," - f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " - f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}" - + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "") - ) - + def print_cycle_info(self, cycle_timers): + # 记录循环信息和计时器结果 + timer_strings = [] + for name, elapsed in cycle_timers.items(): + formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒" + timer_strings.append(f"{name}: {formatted_time}") + + logger.info( + f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考," + f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " + f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}" + + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "") + ) - async def _loopbody(self): if self.loop_mode == "focus": - - self.energy_value -= 5 * (1/global_config.chat.exit_focus_threshold) + self.energy_value -= 5 * (1 / global_config.chat.exit_focus_threshold) if self.energy_value <= 0: self.loop_mode = "normal" return True - - + return await self._observe() elif self.loop_mode == "normal": new_messages_data = get_raw_msg_by_timestamp_with_chat( - chat_id=self.stream_id, timestamp_start=self.last_read_time, timestamp_end=time.time(),limit=10,limit_mode="earliest",fliter_bot=True + chat_id=self.stream_id, + timestamp_start=self.last_read_time, + timestamp_end=time.time(), + limit=10, + limit_mode="earliest", + fliter_bot=True, ) - + if len(new_messages_data) > 4 * global_config.chat.auto_focus_threshold: self.loop_mode = "focus" self.energy_value = 100 return True - + if new_messages_data: earliest_messages_data = new_messages_data[0] self.last_read_time = earliest_messages_data.get("time") - + await self.normal_response(earliest_messages_data) return True await asyncio.sleep(1) - + return True - - async def build_reply_to_str(self,message_data:dict): + + async def build_reply_to_str(self, message_data: dict): person_info_manager = get_person_info_manager() person_id = person_info_manager.get_person_id( message_data.get("chat_info_platform"), message_data.get("user_id") @@ -238,22 +233,17 @@ class HeartFChatting: reply_to_str = f"{person_name}:{message_data.get('processed_plain_text')}" return reply_to_str - - async def _observe(self,message_data:dict = None): + async def _observe(self, message_data: dict = None): # 创建新的循环信息 cycle_timers, thinking_id = self.start_cycle() - + logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考[模式:{self.loop_mode}]") - - - async with global_prompt_manager.async_message_scope( - self.chat_stream.context.get_template_name() - ): + async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): loop_start_time = time.time() # await self.loop_info.observe() await self.relationship_builder.build_relation() - + # 第一步:动作修改 with Timer("动作修改", cycle_timers): try: @@ -261,18 +251,15 @@ class HeartFChatting: available_actions = self.action_manager.get_using_actions() except Exception as e: logger.error(f"{self.log_prefix} 动作修改失败: {e}") - - #如果normal,开始一个回复生成进程,先准备好回复(其实是和planer同时进行的) + + # 如果normal,开始一个回复生成进程,先准备好回复(其实是和planer同时进行的) if self.loop_mode == "normal": reply_to_str = await self.build_reply_to_str(message_data) - gen_task = asyncio.create_task(self._generate_response(message_data, available_actions,reply_to_str)) - + gen_task = asyncio.create_task(self._generate_response(message_data, available_actions, reply_to_str)) with Timer("规划器", cycle_timers): plan_result = await self.action_planner.plan(mode=self.loop_mode) - - action_result = plan_result.get("action_result", {}) action_type, action_data, reasoning, is_parallel = ( action_result.get("action_type", "error"), @@ -282,7 +269,7 @@ class HeartFChatting: ) action_data["loop_start_time"] = loop_start_time - + if self.loop_mode == "normal": if action_type == "no_action": logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定进行回复") @@ -293,8 +280,6 @@ class HeartFChatting: else: logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定执行{action_type}动作") - - if action_type == "no_action": # 等待回复生成完毕 gather_timeout = global_config.chat.thinking_timeout @@ -307,9 +292,7 @@ class HeartFChatting: content = " ".join([item[1] for item in response_set if item[0] == "text"]) # 模型炸了,没有回复内容生成 - if not response_set or ( - action_type not in ["no_action"] and not is_parallel - ): + if not response_set or (action_type not in ["no_action"] and not is_parallel): if not response_set: logger.warning(f"[{self.log_prefix}] 模型未生成回复内容") elif action_type not in ["no_action"] and not is_parallel: @@ -320,14 +303,11 @@ class HeartFChatting: logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定的回复内容: {content}") - # 发送回复 (不再需要传入 chat) await self._send_response(response_set, reply_to_str, loop_start_time) return True - - - + else: # 动作执行计时 with Timer("动作执行", cycle_timers): @@ -350,18 +330,16 @@ class HeartFChatting: if loop_info["loop_action_info"]["command"] == "stop_focus_chat": logger.info(f"{self.log_prefix} 麦麦决定停止专注聊天") return False - #停止该聊天模式的循环 + # 停止该聊天模式的循环 - self.end_cycle(loop_info,cycle_timers) + self.end_cycle(loop_info, cycle_timers) self.print_cycle_info(cycle_timers) if self.loop_mode == "normal": await self.willing_manager.after_generate_reply_handle(message_data.get("message_id")) return True - - - + async def _main_chat_loop(self): """主循环,持续进行计划并可能回复消息,直到被外部取消。""" try: @@ -370,7 +348,7 @@ class HeartFChatting: await asyncio.sleep(0.1) if not success: break - + logger.info(f"{self.log_prefix} 麦麦已强制离开聊天") except asyncio.CancelledError: # 设置了关闭标志位后被取消是正常流程 @@ -430,7 +408,7 @@ class HeartFChatting: else: success, reply_text = result command = "" - + if reply_text == "timeout": self.reply_timeout_count += 1 if self.reply_timeout_count > 5: @@ -446,8 +424,6 @@ class HeartFChatting: logger.error(f"{self.log_prefix} 处理{action}时出错: {e}") traceback.print_exc() return False, "", "" - - async def shutdown(self): """优雅关闭HeartFChatting实例,取消活动循环任务""" @@ -483,7 +459,6 @@ class HeartFChatting: logger.info(f"{self.log_prefix} HeartFChatting关闭完成") - def adjust_reply_frequency(self): """ 根据预设规则动态调整回复意愿(willing_amplifier)。 @@ -553,18 +528,16 @@ class HeartFChatting: f"[{self.log_prefix}] 调整回复意愿。10分钟内回复: {bot_reply_count_10_min} (目标: {target_replies_in_window:.0f}) -> " f"意愿放大器更新为: {self.willing_amplifier:.2f}" ) - - - + async def normal_response(self, message_data: dict) -> None: """ 处理接收到的消息。 在"兴趣"模式下,判断是否回复并生成内容。 """ - + is_mentioned = message_data.get("is_mentioned", False) interested_rate = message_data.get("interest_rate", 0.0) * self.willing_amplifier - + reply_probability = ( 1.0 if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply else 0.0 ) # 如果被提及,且开启了提及必回复,则基础概率为1,否则需要意愿判断 @@ -587,7 +560,6 @@ class HeartFChatting: if message_data.get("is_emoji") or message_data.get("is_picid"): reply_probability = 0 - # 打印消息信息 mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊" if reply_probability > 0.1: @@ -599,16 +571,15 @@ class HeartFChatting: if random.random() < reply_probability: await self.willing_manager.before_generate_reply_handle(message_data.get("message_id")) - await self._observe(message_data = message_data) + await self._observe(message_data=message_data) # 意愿管理器:注销当前message信息 (无论是否回复,只要处理过就删除) self.willing_manager.delete(message_data.get("message_id")) - + return True - - + async def _generate_response( - self, message_data: dict, available_actions: Optional[list],reply_to:str + self, message_data: dict, available_actions: Optional[list], reply_to: str ) -> Optional[list]: """生成普通回复""" try: @@ -629,29 +600,28 @@ class HeartFChatting: except Exception as e: logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}") return None - - - async def _send_response( - self, reply_set, reply_to, thinking_start_time - ): + + async def _send_response(self, reply_set, reply_to, thinking_start_time): current_time = time.time() new_message_count = message_api.count_new_messages( chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time ) - + need_reply = new_message_count >= random.randint(2, 4) - + logger.info( f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复" ) - + reply_text = "" first_replyed = False for reply_seg in reply_set: data = reply_seg[1] if not first_replyed: if need_reply: - await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, typing=False) + await send_api.text_to_stream( + text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, typing=False + ) first_replyed = True else: await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, typing=False) @@ -659,7 +629,5 @@ class HeartFChatting: else: await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, typing=True) reply_text += data - + return reply_text - - \ No newline at end of file diff --git a/src/chat/focus_chat/hfc_utils.py b/src/chat/focus_chat/hfc_utils.py index 5d8df651d..db5bfea12 100644 --- a/src/chat/focus_chat/hfc_utils.py +++ b/src/chat/focus_chat/hfc_utils.py @@ -6,7 +6,6 @@ from src.config.config import global_config from src.common.message_repository import count_messages - logger = get_logger(__name__) @@ -82,7 +81,6 @@ class CycleDetail: self.loop_action_info = loop_info["loop_action_info"] - def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict: """ Args: diff --git a/src/chat/focus_chat/priority_manager.py b/src/chat/focus_chat/priority_manager.py index a3f379651..9db67bc63 100644 --- a/src/chat/focus_chat/priority_manager.py +++ b/src/chat/focus_chat/priority_manager.py @@ -49,7 +49,7 @@ class PriorityManager: 添加新消息到合适的队列中。 """ user_id = message_data.get("user_id") - + priority_info_raw = message_data.get("priority_info") priority_info = {} if isinstance(priority_info_raw, str): diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index d5d63483d..ec28fb813 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -109,12 +109,12 @@ class HeartFCMessageReceiver: interested_rate, is_mentioned = await _calculate_interest(message) message.interest_value = interested_rate message.is_mentioned = is_mentioned - + await self.storage.store_message(message, chat) subheartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) message.update_chat_stream(chat) - + # subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py index 631b0aaec..8c2e6de20 100644 --- a/src/chat/heart_flow/sub_heartflow.py +++ b/src/chat/heart_flow/sub_heartflow.py @@ -28,26 +28,22 @@ class SubHeartflow: self.subheartflow_id = subheartflow_id self.chat_id = subheartflow_id - self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id) self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id - + # focus模式退出冷却时间管理 self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间 # 随便水群 normal_chat 和 认真水群 focus_chat 实例 # CHAT模式激活 随便水群 FOCUS模式激活 认真水群 self.heart_fc_instance: Optional[HeartFChatting] = HeartFChatting( - chat_id=self.subheartflow_id, - ) # 该sub_heartflow的HeartFChatting实例 + chat_id=self.subheartflow_id, + ) # 该sub_heartflow的HeartFChatting实例 async def initialize(self): """异步初始化方法,创建兴趣流并确定聊天类型""" await self.heart_fc_instance.start() - - - async def _stop_heart_fc_chat(self): """停止并清理 HeartFChatting 实例""" if self.heart_fc_instance.running: @@ -85,7 +81,6 @@ class SubHeartflow: logger.error(f"{self.log_prefix} _start_heart_fc_chat 执行时出错: {e}") logger.error(traceback.format_exc()) return False - def is_in_focus_cooldown(self) -> bool: """检查是否在focus模式的冷却期内 diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 8cc06573c..bc55311c1 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -444,11 +444,8 @@ class MessageSet: def message_recv_from_dict(message_dict: dict) -> MessageRecv: - return MessageRecv( - - message_dict - - ) + return MessageRecv(message_dict) + def message_from_db_dict(db_dict: dict) -> MessageRecv: """从数据库字典创建MessageRecv实例""" @@ -492,4 +489,4 @@ def message_from_db_dict(db_dict: dict) -> MessageRecv: msg.is_emoji = db_dict.get("is_emoji", False) msg.is_picid = db_dict.get("is_picid", False) - return msg \ No newline at end of file + return msg diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 07eaaad97..6bc14b026 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -84,4 +84,3 @@ class HeartFCSender: except Exception as e: logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}") raise e - diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 5a9293dd8..e7b0434a9 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -23,6 +23,7 @@ logger = get_logger("normal_chat") LOOP_INTERVAL = 0.3 + class NormalChat: """ 普通聊天处理类,负责处理非核心对话的聊天逻辑。 @@ -43,7 +44,7 @@ class NormalChat: """ self.chat_stream = chat_stream self.stream_id = chat_stream.stream_id - self.last_read_time = time.time()-1 + self.last_read_time = time.time() - 1 self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id @@ -56,7 +57,7 @@ class NormalChat: # self.mood_manager = mood_manager self.start_time = time.time() - + self.running = False self._initialized = False # Track initialization status @@ -86,7 +87,7 @@ class NormalChat: # 任务管理 self._chat_task: Optional[asyncio.Task] = None - self._priority_chat_task: Optional[asyncio.Task] = None # for priority mode consumer + self._priority_chat_task: Optional[asyncio.Task] = None # for priority mode consumer self._disabled = False # 停用标志 # 新增:回复模式和优先级管理器 @@ -106,11 +107,11 @@ class NormalChat: if self.reply_mode == "priority" and self._priority_chat_task and not self._priority_chat_task.done(): self._priority_chat_task.cancel() logger.info(f"[{self.stream_name}] NormalChat 已停用。") - + # async def _interest_mode_loopbody(self): # try: # await asyncio.sleep(LOOP_INTERVAL) - + # if self._disabled: # return False @@ -118,10 +119,10 @@ class NormalChat: # new_messages_data = get_raw_msg_by_timestamp_with_chat_inclusive( # chat_id=self.stream_id, timestamp_start=self.last_read_time, timestamp_end=now, limit_mode="earliest" # ) - + # if new_messages_data: # self.last_read_time = now - + # for msg_data in new_messages_data: # try: # self.adjust_reply_frequency() @@ -134,44 +135,42 @@ class NormalChat: # except Exception as e: # logger.error(f"[{self.stream_name}] 处理消息时出错: {e} {traceback.format_exc()}") - # except asyncio.CancelledError: # logger.info(f"[{self.stream_name}] 兴趣模式轮询任务被取消") # return False # except Exception: # logger.error(f"[{self.stream_name}] 兴趣模式轮询循环出现错误: {traceback.format_exc()}", exc_info=True) # await asyncio.sleep(10) - + async def _priority_mode_loopbody(self): - try: - await asyncio.sleep(LOOP_INTERVAL) + try: + await asyncio.sleep(LOOP_INTERVAL) - if self._disabled: - return False - - now = time.time() - new_messages_data = get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id=self.stream_id, timestamp_start=self.last_read_time, timestamp_end=now, limit_mode="earliest" - ) - - if new_messages_data: - self.last_read_time = now - - for msg_data in new_messages_data: - try: - if self.priority_manager: - self.priority_manager.add_message(msg_data, msg_data.get("interest_rate", 0.0)) - return True - except Exception as e: - logger.error(f"[{self.stream_name}] 添加消息到优先级队列时出错: {e} {traceback.format_exc()}") - - - except asyncio.CancelledError: - logger.info(f"[{self.stream_name}] 优先级消息生产者任务被取消") + if self._disabled: return False - except Exception: - logger.error(f"[{self.stream_name}] 优先级消息生产者循环出现错误: {traceback.format_exc()}", exc_info=True) - await asyncio.sleep(10) + + now = time.time() + new_messages_data = get_raw_msg_by_timestamp_with_chat_inclusive( + chat_id=self.stream_id, timestamp_start=self.last_read_time, timestamp_end=now, limit_mode="earliest" + ) + + if new_messages_data: + self.last_read_time = now + + for msg_data in new_messages_data: + try: + if self.priority_manager: + self.priority_manager.add_message(msg_data, msg_data.get("interest_rate", 0.0)) + return True + except Exception as e: + logger.error(f"[{self.stream_name}] 添加消息到优先级队列时出错: {e} {traceback.format_exc()}") + + except asyncio.CancelledError: + logger.info(f"[{self.stream_name}] 优先级消息生产者任务被取消") + return False + except Exception: + logger.error(f"[{self.stream_name}] 优先级消息生产者循环出现错误: {traceback.format_exc()}", exc_info=True) + await asyncio.sleep(10) # async def _interest_message_polling_loop(self): # """ @@ -181,16 +180,13 @@ class NormalChat: # try: # while not self._disabled: # success = await self._interest_mode_loopbody() - + # if not success: # break # except asyncio.CancelledError: # logger.info(f"[{self.stream_name}] 兴趣模式消息轮询任务被优雅地取消了") - - - async def _priority_chat_loop(self): """ 使用优先级队列的消息处理循环。 @@ -272,9 +268,8 @@ class NormalChat: # user_nickname=message_data.get("user_nickname"), # platform=message_data.get("chat_info_platform"), # ) - + # reply = message_from_db_dict(message_data) - # mark_head = False # first_bot_msg = None @@ -652,7 +647,9 @@ class NormalChat: # Start consumer loop consumer_task = asyncio.create_task(self._priority_chat_loop()) self._priority_chat_task = consumer_task - self._priority_chat_task.add_done_callback(lambda t: self._handle_task_completion(t, "priority_consumer")) + self._priority_chat_task.add_done_callback( + lambda t: self._handle_task_completion(t, "priority_consumer") + ) else: # Interest mode polling_task = asyncio.create_task(self._interest_message_polling_loop()) self._chat_task = polling_task @@ -712,7 +709,6 @@ class NormalChat: self._chat_task = None self._priority_chat_task = None - # def adjust_reply_frequency(self): # """ # 根据预设规则动态调整回复意愿(willing_amplifier)。 diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index fe0941fd9..8f17f16f2 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -82,9 +82,9 @@ class ActionModifier: # === 第一阶段:传统观察处理 === # if history_loop: - # removals_from_loop = await self.analyze_loop_actions(history_loop) - # if removals_from_loop: - # removals_s1.extend(removals_from_loop) + # removals_from_loop = await self.analyze_loop_actions(history_loop) + # if removals_from_loop: + # removals_s1.extend(removals_from_loop) # 检查动作的关联类型 chat_context = self.chat_stream.context @@ -188,7 +188,7 @@ class ActionModifier: reason = "激活类型为never" deactivated_actions.append((action_name, reason)) logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: 激活类型为never") - + else: logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理") @@ -500,13 +500,13 @@ class ActionModifier: return removals - def get_available_actions_count(self,mode:str = "focus") -> int: + def get_available_actions_count(self, mode: str = "focus") -> int: """获取当前可用动作数量(排除默认的no_action)""" current_actions = self.action_manager.get_using_actions_for_mode(mode) # 排除no_action(如果存在) filtered_actions = {k: v for k, v in current_actions.items() if k != "no_action"} return len(filtered_actions) - + def should_skip_planning_for_no_reply(self) -> bool: """判断是否应该跳过规划过程""" current_actions = self.action_manager.get_using_actions_for_mode("focus") diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 8863c60fa..7d08688a6 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -76,7 +76,7 @@ class ActionPlanner: self.last_obs_time_mark = 0.0 - async def plan(self,mode:str = "focus") -> Dict[str, Any]: + async def plan(self, mode: str = "focus") -> Dict[str, Any]: """ 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 """ diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 1d83d2c29..b2df0dff1 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -506,7 +506,6 @@ class DefaultReplyer: show_actions=True, ) - message_list_before_short = get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index a858abd4d..cdc9ffe86 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -28,7 +28,12 @@ def get_raw_msg_by_timestamp( def get_raw_msg_by_timestamp_with_chat( - chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest", fliter_bot = False + chat_id: str, + timestamp_start: float, + timestamp_end: float, + limit: int = 0, + limit_mode: str = "latest", + fliter_bot=False, ) -> List[Dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 @@ -38,11 +43,18 @@ def get_raw_msg_by_timestamp_with_chat( # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None # 直接将 limit_mode 传递给 find_messages - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, fliter_bot=fliter_bot) + return find_messages( + message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, fliter_bot=fliter_bot + ) def get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest", fliter_bot = False + chat_id: str, + timestamp_start: float, + timestamp_end: float, + limit: int = 0, + limit_mode: str = "latest", + fliter_bot=False, ) -> List[Dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 @@ -52,8 +64,10 @@ def get_raw_msg_by_timestamp_with_chat_inclusive( # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None # 直接将 limit_mode 传递给 find_messages - - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, fliter_bot=fliter_bot) + + return find_messages( + message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, fliter_bot=fliter_bot + ) def get_raw_msg_by_timestamp_with_chat_users( @@ -583,8 +597,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str: action_name = action.get("action_name", "未知动作") if action_name == "no_action" or action_name == "no_reply": continue - - + action_prompt_display = action.get("action_prompt_display", "无具体内容") time_diff_seconds = current_time - action_time diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 716452784..4eb9287a2 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -20,7 +20,7 @@ def find_messages( sort: Optional[List[tuple[str, int]]] = None, limit: int = 0, limit_mode: str = "latest", - fliter_bot = False + fliter_bot=False, ) -> List[dict[str, Any]]: """ 根据提供的过滤器、排序和限制条件查找消息。 @@ -69,10 +69,10 @@ def find_messages( logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") if conditions: query = query.where(*conditions) - + if fliter_bot: query = query.where(Messages.user_id != global_config.bot.qq_account) - + if limit > 0: if limit_mode == "earliest": # 获取时间最早的 limit 条记录,已经是正序 diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 613e447e8..000d4e95f 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -278,7 +278,6 @@ class NormalChatConfig(ConfigBase): """@bot 必然回复""" - @dataclass class FocusChatConfig(ConfigBase): """专注聊天配置类""" diff --git a/src/main.py b/src/main.py index a457f42e4..0e85f6945 100644 --- a/src/main.py +++ b/src/main.py @@ -125,7 +125,6 @@ class MainSystem: logger.info("个体特征初始化成功") try: - init_time = int(1000 * (time.time() - init_start_time)) logger.info(f"初始化完成,神经元放电{init_time}次") except Exception as e: diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index a7b8d7f49..a8f343f3f 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -77,7 +77,7 @@ class ChatMood: if random.random() > update_probability: return - + logger.info(f"更新情绪状态,感兴趣度: {interested_rate}, 更新概率: {update_probability}") message_time = message.message_info.time diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index d3e319595..e3847c55f 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -56,7 +56,12 @@ def get_messages_by_time( def get_messages_by_time_in_chat( - chat_id: str, start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False + chat_id: str, + start_time: float, + end_time: float, + limit: int = 0, + limit_mode: str = "latest", + filter_mai: bool = False, ) -> List[Dict[str, Any]]: """ 获取指定聊天中指定时间范围内的消息 @@ -78,7 +83,12 @@ def get_messages_by_time_in_chat( def get_messages_by_time_in_chat_inclusive( - chat_id: str, start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False + chat_id: str, + start_time: float, + end_time: float, + limit: int = 0, + limit_mode: str = "latest", + filter_mai: bool = False, ) -> List[Dict[str, Any]]: """ 获取指定聊天中指定时间范围内的消息(包含边界) @@ -95,7 +105,9 @@ def get_messages_by_time_in_chat_inclusive( 消息列表 """ if filter_mai: - return filter_mai_messages(get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode)) + return filter_mai_messages( + get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode) + ) return get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode) @@ -181,7 +193,9 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool return get_raw_msg_before_timestamp(timestamp, limit) -def get_messages_before_time_in_chat(chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]: +def get_messages_before_time_in_chat( + chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False +) -> List[Dict[str, Any]]: """ 获取指定聊天中指定时间戳之前的消息 @@ -342,10 +356,12 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s """ return await get_person_id_list(messages) + # ============================================================================= # 消息过滤函数 # ============================================================================= + def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ 从消息列表中移除麦麦的消息