From af02f2ab573b4cc2adf9734fb04f584180de797a Mon Sep 17 00:00:00 2001 From: UnCLASPrommer Date: Tue, 15 Jul 2025 00:57:43 +0800 Subject: [PATCH] fix typing, api change --- changes.md | 6 +- src/plugin_system/apis/chat_api.py | 77 ++++++++++++++---------- src/plugin_system/apis/config_api.py | 39 ++++++------ src/plugin_system/apis/database_api.py | 54 +++++++++-------- src/plugin_system/core/plugin_manager.py | 11 ++-- 5 files changed, 110 insertions(+), 77 deletions(-) diff --git a/changes.md b/changes.md index 85760965f..7ec499b43 100644 --- a/changes.md +++ b/changes.md @@ -15,4 +15,8 @@ # 插件系统修改 1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)** 2. 修复了一下显示插件信息不显示的问题。同时精简了一下显示内容 -3. 修复了插件系统混用了`plugin_name`和`display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。**(可能有遗漏)** \ No newline at end of file +3. 修复了插件系统混用了`plugin_name`和`display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。**(可能有遗漏)** +3. 部分API的参数类型和返回值进行了调整 + - `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。 + - `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。 + - `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。 diff --git a/src/plugin_system/apis/chat_api.py b/src/plugin_system/apis/chat_api.py index b56142a47..f436c4ab5 100644 --- a/src/plugin_system/apis/chat_api.py +++ b/src/plugin_system/apis/chat_api.py @@ -13,23 +13,29 @@ """ from typing import List, Dict, Any, Optional -from src.common.logger import get_logger +from enum import Enum -# 导入依赖 +from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager logger = get_logger("chat_api") +class SpecialTypes(Enum): + """特殊枚举类型""" + + ALL_PLATFORMS = "all_platforms" + + class ChatManager: """聊天管理器 - 专门负责聊天信息的查询和管理""" @staticmethod - def get_all_streams(platform: str = "qq") -> List[ChatStream]: + def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: """获取所有聊天流 Args: - platform: 平台筛选,默认为"qq" + platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 Returns: List[ChatStream]: 聊天流列表 @@ -37,7 +43,7 @@ class ChatManager: streams = [] try: for _, stream in get_chat_manager().streams.items(): - if stream.platform == platform: + if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform: streams.append(stream) logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的聊天流") except Exception as e: @@ -45,11 +51,11 @@ class ChatManager: return streams @staticmethod - def get_group_streams(platform: str = "qq") -> List[ChatStream]: + def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: """获取所有群聊聊天流 Args: - platform: 平台筛选,默认为"qq" + platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 Returns: List[ChatStream]: 群聊聊天流列表 @@ -57,7 +63,7 @@ class ChatManager: streams = [] try: for _, stream in get_chat_manager().streams.items(): - if stream.platform == platform and stream.group_info: + if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info: streams.append(stream) logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的群聊流") except Exception as e: @@ -65,11 +71,11 @@ class ChatManager: return streams @staticmethod - def get_private_streams(platform: str = "qq") -> List[ChatStream]: + def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: """获取所有私聊聊天流 Args: - platform: 平台筛选,默认为"qq" + platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 Returns: List[ChatStream]: 私聊聊天流列表 @@ -77,7 +83,7 @@ class ChatManager: streams = [] try: for _, stream in get_chat_manager().streams.items(): - if stream.platform == platform and not stream.group_info: + if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info: streams.append(stream) logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的私聊流") except Exception as e: @@ -85,12 +91,14 @@ class ChatManager: return streams @staticmethod - def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]: + def get_group_stream_by_group_id( + group_id: str, platform: Optional[str] | SpecialTypes = "qq" + ) -> Optional[ChatStream]: """根据群ID获取聊天流 Args: group_id: 群聊ID - platform: 平台,默认为"qq" + platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 Returns: Optional[ChatStream]: 聊天流对象,如果未找到返回None @@ -110,12 +118,14 @@ class ChatManager: return None @staticmethod - def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]: + def get_private_stream_by_user_id( + user_id: str, platform: Optional[str] | SpecialTypes = "qq" + ) -> Optional[ChatStream]: """根据用户ID获取私聊流 Args: user_id: 用户ID - platform: 平台,默认为"qq" + platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 Returns: Optional[ChatStream]: 聊天流对象,如果未找到返回None @@ -145,7 +155,7 @@ class ChatManager: str: 聊天类型 ("group", "private", "unknown") """ if not chat_stream: - return "unknown" + raise ValueError("chat_stream cannot be None") if hasattr(chat_stream, "group_info"): return "group" if chat_stream.group_info else "private" @@ -165,7 +175,7 @@ class ChatManager: return {} try: - info = { + info: Dict[str, Any] = { "stream_id": chat_stream.stream_id, "platform": chat_stream.platform, "type": ChatManager.get_stream_type(chat_stream), @@ -200,9 +210,9 @@ class ChatManager: Dict[str, int]: 包含各种统计信息的字典 """ try: - all_streams = ChatManager.get_all_streams() - group_streams = ChatManager.get_group_streams() - private_streams = ChatManager.get_private_streams() + all_streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS) + group_streams = ChatManager.get_group_streams(SpecialTypes.ALL_PLATFORMS) + private_streams = ChatManager.get_private_streams(SpecialTypes.ALL_PLATFORMS) summary = { "total_streams": len(all_streams), @@ -215,7 +225,12 @@ class ChatManager: return summary except Exception as e: logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}") - return {"total_streams": 0, "group_streams": 0, "private_streams": 0, "qq_streams": 0} + return { + "total_streams": 0, + "group_streams": 0, + "private_streams": 0, + "qq_streams": 0, + } # ============================================================================= @@ -223,41 +238,41 @@ class ChatManager: # ============================================================================= -def get_all_streams(platform: str = "qq") -> List[ChatStream]: +def get_all_streams(platform: Optional[str] | SpecialTypes = "qq"): """获取所有聊天流的便捷函数""" return ChatManager.get_all_streams(platform) -def get_group_streams(platform: str = "qq") -> List[ChatStream]: +def get_group_streams(platform: Optional[str] | SpecialTypes = "qq"): """获取群聊聊天流的便捷函数""" return ChatManager.get_group_streams(platform) -def get_private_streams(platform: str = "qq") -> List[ChatStream]: +def get_private_streams(platform: Optional[str] | SpecialTypes = "qq"): """获取私聊聊天流的便捷函数""" return ChatManager.get_private_streams(platform) -def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]: +def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq"): """根据群ID获取聊天流的便捷函数""" - return ChatManager.get_stream_by_group_id(group_id, platform) + return ChatManager.get_group_stream_by_group_id(group_id, platform) -def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]: +def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq"): """根据用户ID获取私聊流的便捷函数""" - return ChatManager.get_stream_by_user_id(user_id, platform) + return ChatManager.get_private_stream_by_user_id(user_id, platform) -def get_stream_type(chat_stream: ChatStream) -> str: +def get_stream_type(chat_stream: ChatStream): """获取聊天流类型的便捷函数""" return ChatManager.get_stream_type(chat_stream) -def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]: +def get_stream_info(chat_stream: ChatStream): """获取聊天流信息的便捷函数""" return ChatManager.get_stream_info(chat_stream) -def get_streams_summary() -> Dict[str, int]: +def get_streams_summary(): """获取聊天流统计摘要的便捷函数""" return ChatManager.get_streams_summary() diff --git a/src/plugin_system/apis/config_api.py b/src/plugin_system/apis/config_api.py index 80b9d2645..6ec492caf 100644 --- a/src/plugin_system/apis/config_api.py +++ b/src/plugin_system/apis/config_api.py @@ -26,7 +26,7 @@ def get_global_config(key: str, default: Any = None) -> Any: 插件应使用此方法读取全局配置,以保证只读和隔离性。 Args: - key: 配置键名,支持嵌套访问如 "section.subsection.key" + key: 命名空间式配置键名,支持嵌套访问,如 "section.subsection.key",大小写敏感 default: 如果配置不存在时返回的默认值 Returns: @@ -41,7 +41,7 @@ def get_global_config(key: str, default: Any = None) -> Any: if hasattr(current, k): current = getattr(current, k) else: - return default + raise KeyError(f"配置中不存在子空间或键 '{k}'") return current except Exception as e: logger.warning(f"[ConfigAPI] 获取全局配置 {key} 失败: {e}") @@ -54,26 +54,28 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any Args: plugin_config: 插件配置字典 - key: 配置键名,支持嵌套访问如 "section.subsection.key" + key: 配置键名,支持嵌套访问如 "section.subsection.key",大小写敏感 default: 如果配置不存在时返回的默认值 Returns: Any: 配置值或默认值 """ - if not plugin_config: - return default - # 支持嵌套键访问 keys = key.split(".") current = plugin_config - for k in keys: - if isinstance(current, dict) and k in current: - current = current[k] - else: - return default - - return current + try: + for k in keys: + if isinstance(current, dict) and k in current: + current = current[k] + elif hasattr(current, k): + current = getattr(current, k) + else: + raise KeyError(f"配置中不存在子空间或键 '{k}'") + return current + except Exception as e: + logger.warning(f"[ConfigAPI] 获取插件配置 {key} 失败: {e}") + return default # ============================================================================= @@ -82,7 +84,7 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any async def get_user_id_by_person_name(person_name: str) -> tuple[str, str]: - """根据用户名获取用户ID + """根据内部用户名获取用户ID Args: person_name: 用户名 @@ -93,8 +95,8 @@ async def get_user_id_by_person_name(person_name: str) -> tuple[str, str]: try: person_info_manager = get_person_info_manager() person_id = person_info_manager.get_person_id_by_person_name(person_name) - user_id = await person_info_manager.get_value(person_id, "user_id") - platform = await person_info_manager.get_value(person_id, "platform") + user_id: str = await person_info_manager.get_value(person_id, "user_id") # type: ignore + platform: str = await person_info_manager.get_value(person_id, "platform") # type: ignore return platform, user_id except Exception as e: logger.error(f"[ConfigAPI] 根据用户名获取用户ID失败: {e}") @@ -114,7 +116,10 @@ async def get_person_info(person_id: str, key: str, default: Any = None) -> Any: """ try: person_info_manager = get_person_info_manager() - return await person_info_manager.get_value(person_id, key, default) + response = await person_info_manager.get_value(person_id, key) + if not response: + raise ValueError(f"[ConfigAPI] 获取用户 {person_id} 的信息 '{key}' 失败,返回默认值") + return response except Exception as e: logger.error(f"[ConfigAPI] 获取用户信息失败: {e}") return default diff --git a/src/plugin_system/apis/database_api.py b/src/plugin_system/apis/database_api.py index 085df997f..d46bfba39 100644 --- a/src/plugin_system/apis/database_api.py +++ b/src/plugin_system/apis/database_api.py @@ -8,7 +8,7 @@ """ import traceback -from typing import Dict, List, Any, Union, Type +from typing import Dict, List, Any, Union, Type, Optional from src.common.logger import get_logger from peewee import Model, DoesNotExist @@ -21,12 +21,12 @@ logger = get_logger("database_api") async def db_query( model_class: Type[Model], - query_type: str = "get", - filters: Dict[str, Any] = None, - data: Dict[str, Any] = None, - limit: int = None, - order_by: List[str] = None, - single_result: bool = False, + data: Optional[Dict[str, Any]] = None, + query_type: Optional[str] = "get", + filters: Optional[Dict[str, Any]] = None, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + single_result: Optional[bool] = False, ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: """执行数据库查询操作 @@ -34,11 +34,11 @@ async def db_query( Args: model_class: Peewee 模型类,例如 ActionRecords, Messages 等 + data: 用于创建或更新的数据字典 query_type: 查询类型,可选值: "get", "create", "update", "delete", "count" filters: 过滤条件字典,键为字段名,值为要匹配的值 - data: 用于创建或更新的数据字典 limit: 限制结果数量 - order_by: 排序字段列表,使用字段名,前缀'-'表示降序 + order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序 single_result: 是否只返回单个结果 Returns: @@ -48,7 +48,8 @@ async def db_query( - "update": 返回受影响的行数 - "delete": 返回受影响的行数 - "count": 返回记录数量 - + """ + """ 示例: # 查询最近10条消息 messages = await database_api.db_query( @@ -62,16 +63,16 @@ async def db_query( # 创建一条记录 new_record = await database_api.db_query( ActionRecords, + data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}, query_type="create", - data={"action_id": "123", "time": time.time(), "action_name": "TestAction"} ) # 更新记录 updated_count = await database_api.db_query( ActionRecords, + data={"action_done": True}, query_type="update", filters={"action_id": "123"}, - data={"action_done": True} ) # 删除记录 @@ -129,7 +130,7 @@ async def db_query( # 创建记录 record = model_class.create(**data) # 返回创建的记录 - return model_class.select().where(model_class.id == record.id).dicts().get() + return model_class.select().where(model_class.id == record.id).dicts().get() # type: ignore elif query_type == "update": if not data: @@ -168,7 +169,7 @@ async def db_query( async def db_save( - model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None + model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None ) -> Union[Dict[str, Any], None]: """保存数据到数据库(创建或更新) @@ -213,14 +214,14 @@ async def db_save( existing_record.save() # 返回更新后的记录 - updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() + updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() # type: ignore return updated_record # 如果没有找到现有记录或未提供key_field和key_value,创建新记录 new_record = model_class.create(**data) # 返回创建的记录 - created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() + created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() # type: ignore return created_record except Exception as e: @@ -230,7 +231,11 @@ async def db_save( async def db_get( - model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None + model_class: Type[Model], + filters: Optional[Dict[str, Any]] = None, + limit: Optional[int] = None, + order_by: Optional[str] = None, + single_result: Optional[bool] = False, ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: """从数据库获取记录 @@ -239,11 +244,12 @@ async def db_get( Args: model_class: Peewee模型类 filters: 过滤条件,字段名和值的字典 - order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序 - limit: 结果数量限制,如果为1则返回单个记录而不是列表 + order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序 + limit: 结果数量限制 + single_result: 是否只返回单个结果,如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表 Returns: - 如果limit=1,返回单个记录字典或None; + 如果single_result为True,返回单个记录字典或None; 否则返回记录字典列表或空列表。 示例: @@ -258,8 +264,8 @@ async def db_get( records = await database_api.db_get( Messages, filters={"chat_id": chat_stream.stream_id}, + limit=10, order_by="-time", - limit=10 ) """ try: @@ -286,14 +292,14 @@ async def db_get( results = list(query.dicts()) # 返回结果 - if limit == 1: + if single_result: return results[0] if results else None return results except Exception as e: logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}") traceback.print_exc() - return None if limit == 1 else [] + return None if single_result else [] async def store_action_info( @@ -302,7 +308,7 @@ async def store_action_info( action_prompt_display: str = "", action_done: bool = True, thinking_id: str = "", - action_data: dict = None, + action_data: Optional[dict] = None, action_name: str = "", ) -> Union[Dict[str, Any], None]: """存储动作信息到数据库 diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index fd75d8c9d..b428912e6 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -102,7 +102,7 @@ class PluginManager: """ 加载已经注册的插件类 """ - plugin_class: Type[BasePlugin] = self.plugin_classes.get(plugin_name) + plugin_class = self.plugin_classes.get(plugin_name) if not plugin_class: logger.error(f"插件 {plugin_name} 的插件类未注册或不存在") return False, 1 @@ -115,7 +115,10 @@ class PluginManager: plugin_dir = self._find_plugin_directory(plugin_class) if plugin_dir: self.plugin_paths[plugin_name] = plugin_dir # 更新路径 - plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败) + plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败) + if not plugin_instance: + logger.error(f"插件 {plugin_name} 实例化失败") + return False, 1 # 检查插件是否启用 if not plugin_instance.enable_plugin: logger.info(f"插件 {plugin_name} 已禁用,跳过加载") @@ -248,7 +251,7 @@ class PluginManager: "failed_plugin_details": self.failed_plugins.copy(), } - def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, any]: + def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]: """检查所有插件的Python依赖包 Args: @@ -381,7 +384,7 @@ class PluginManager: return loaded_count, failed_count - def _find_plugin_directory(self, plugin_class: str) -> Optional[str]: + def _find_plugin_directory(self, plugin_class: Type[BasePlugin]) -> Optional[str]: """查找插件类对应的目录路径""" try: module = getmodule(plugin_class)