From 80a1c0bf933c516bf861dbf2b40de34d79a2a8e4 Mon Sep 17 00:00:00 2001 From: UnCLASPrommer Date: Tue, 15 Jul 2025 19:09:04 +0800 Subject: [PATCH] api typing check --- src/chat/replyer/default_generator.py | 2 + src/chat/utils/chat_message_builder.py | 2 +- src/plugin_system/apis/generator_api.py | 12 +- src/plugin_system/apis/message_api.py | 134 +++++++++++++++++++-- src/plugin_system/apis/person_api.py | 4 +- src/plugin_system/apis/send_api.py | 2 +- src/plugin_system/apis/utils_api.py | 10 +- src/plugins/built_in/core_actions/emoji.py | 4 +- 8 files changed, 144 insertions(+), 26 deletions(-) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index dddd8e1cc..7da6ebc01 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -182,6 +182,7 @@ class DefaultReplyer: 回复器 (Replier): 核心逻辑,负责生成回复文本。 (已整合原 HeartFCGenerator 的功能) """ + prompt = None if available_actions is None: available_actions = {} if reply_data is None: @@ -707,6 +708,7 @@ class DefaultReplyer: ) target_user_id = "" + person_id = "" if sender: # 根据sender通过person_info_manager反向查找person_id,再获取user_id person_id = person_info_manager.get_person_id_by_person_name(sender) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 2ff537f0c..aaa59c8ec 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -76,7 +76,7 @@ def get_raw_msg_by_timestamp_with_chat_users( chat_id: str, timestamp_start: float, timestamp_end: float, - person_ids: list, + person_ids: List[str], limit: int = 0, limit_mode: str = "latest", ) -> List[Dict[str, Any]]: diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index 4763dbd1b..cbb1336ce 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -131,6 +131,9 @@ async def generate_reply( else: return success, reply_set, None + except ValueError as ve: + raise ve + except Exception as e: logger.error(f"[GeneratorAPI] 生成回复时出错: {e}") return False, [], None @@ -178,6 +181,9 @@ async def rewrite_reply( return success, reply_set + except ValueError as ve: + raise ve + except Exception as e: logger.error(f"[GeneratorAPI] 重写回复时出错: {e}") return False, [] @@ -191,12 +197,14 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese enable_splitter: 是否启用消息分割器 enable_chinese_typo: 是否启用错字生成器 """ + if not isinstance(content, str): + raise ValueError("content 必须是字符串类型") try: processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo) reply_set = [] - for str in processed_response: - reply_seg = ("text", str) + for text in processed_response: + reply_seg = ("text", text) reply_set.append(reply_seg) return reply_set diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index e3847c55f..b720bb23c 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -48,8 +48,15 @@ def get_messages_by_time( filter_mai: 是否过滤麦麦自身的消息,默认为False Returns: - 消息列表 + List[Dict[str, Any]]: 消息列表 + + Raises: + ValueError: 如果参数不合法 """ + if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): + raise ValueError("start_time 和 end_time 必须是数字类型") + if limit < 0: + raise ValueError("limit 不能为负数") if filter_mai: return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)) return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode) @@ -75,8 +82,19 @@ def get_messages_by_time_in_chat( filter_mai: 是否过滤麦麦自身的消息,默认为False Returns: - 消息列表 + List[Dict[str, Any]]: 消息列表 + + Raises: + ValueError: 如果参数不合法 """ + if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): + raise ValueError("start_time 和 end_time 必须是数字类型") + if limit < 0: + raise ValueError("limit 不能为负数") + if not chat_id: + raise ValueError("chat_id 不能为空") + if not isinstance(chat_id, str): + raise ValueError("chat_id 必须是字符串类型") if filter_mai: return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode)) return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode) @@ -102,8 +120,19 @@ def get_messages_by_time_in_chat_inclusive( filter_mai: 是否过滤麦麦自身的消息,默认为False Returns: - 消息列表 + List[Dict[str, Any]]: 消息列表 + + Raises: + ValueError: 如果参数不合法 """ + if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): + raise ValueError("start_time 和 end_time 必须是数字类型") + if limit < 0: + raise ValueError("limit 不能为负数") + if not chat_id: + raise ValueError("chat_id 不能为空") + if not isinstance(chat_id, str): + raise ValueError("chat_id 必须是字符串类型") if filter_mai: return filter_mai_messages( get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode) @@ -115,7 +144,7 @@ def get_messages_by_time_in_chat_for_users( chat_id: str, start_time: float, end_time: float, - person_ids: list, + person_ids: List[str], limit: int = 0, limit_mode: str = "latest", ) -> List[Dict[str, Any]]: @@ -131,8 +160,19 @@ def get_messages_by_time_in_chat_for_users( limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录 Returns: - 消息列表 + List[Dict[str, Any]]: 消息列表 + + Raises: + ValueError: 如果参数不合法 """ + if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): + raise ValueError("start_time 和 end_time 必须是数字类型") + if limit < 0: + raise ValueError("limit 不能为负数") + if not chat_id: + raise ValueError("chat_id 不能为空") + if not isinstance(chat_id, str): + raise ValueError("chat_id 必须是字符串类型") return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode) @@ -150,8 +190,15 @@ def get_random_chat_messages( filter_mai: 是否过滤麦麦自身的消息,默认为False Returns: - 消息列表 + List[Dict[str, Any]]: 消息列表 + + Raises: + ValueError: 如果参数不合法 """ + if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): + raise ValueError("start_time 和 end_time 必须是数字类型") + if limit < 0: + raise ValueError("limit 不能为负数") if filter_mai: return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)) return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode) @@ -171,8 +218,15 @@ def get_messages_by_time_for_users( limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录 Returns: - 消息列表 + List[Dict[str, Any]]: 消息列表 + + Raises: + ValueError: 如果参数不合法 """ + if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): + raise ValueError("start_time 和 end_time 必须是数字类型") + if limit < 0: + raise ValueError("limit 不能为负数") return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode) @@ -186,8 +240,15 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool filter_mai: 是否过滤麦麦自身的消息,默认为False Returns: - 消息列表 + List[Dict[str, Any]]: 消息列表 + + Raises: + ValueError: 如果参数不合法 """ + if not isinstance(timestamp, (int, float)): + raise ValueError("timestamp 必须是数字类型") + if limit < 0: + raise ValueError("limit 不能为负数") if filter_mai: return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit)) return get_raw_msg_before_timestamp(timestamp, limit) @@ -206,8 +267,19 @@ def get_messages_before_time_in_chat( filter_mai: 是否过滤麦麦自身的消息,默认为False Returns: - 消息列表 + List[Dict[str, Any]]: 消息列表 + + Raises: + ValueError: 如果参数不合法 """ + if not isinstance(timestamp, (int, float)): + raise ValueError("timestamp 必须是数字类型") + if limit < 0: + raise ValueError("limit 不能为负数") + if not chat_id: + raise ValueError("chat_id 不能为空") + if not isinstance(chat_id, str): + raise ValueError("chat_id 必须是字符串类型") if filter_mai: return filter_mai_messages(get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)) return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit) @@ -223,8 +295,15 @@ def get_messages_before_time_for_users(timestamp: float, person_ids: list, limit limit: 限制返回的消息数量,0为不限制 Returns: - 消息列表 + List[Dict[str, Any]]: 消息列表 + + Raises: + ValueError: 如果参数不合法 """ + if not isinstance(timestamp, (int, float)): + raise ValueError("timestamp 必须是数字类型") + if limit < 0: + raise ValueError("limit 不能为负数") return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit) @@ -242,8 +321,19 @@ def get_recent_messages( filter_mai: 是否过滤麦麦自身的消息,默认为False Returns: - 消息列表 + List[Dict[str, Any]]: 消息列表 + + Raises: + ValueError: 如果参数不合法s """ + if not isinstance(hours, (int, float)) or hours < 0: + raise ValueError("hours 不能是负数") + if not isinstance(limit, int) or limit < 0: + raise ValueError("limit 必须是非负整数") + if not chat_id: + raise ValueError("chat_id 不能为空") + if not isinstance(chat_id, str): + raise ValueError("chat_id 必须是字符串类型") now = time.time() start_time = now - hours * 3600 if filter_mai: @@ -266,8 +356,17 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional end_time: 结束时间戳,如果为None则使用当前时间 Returns: - 新消息数量 + int: 新消息数量 + + Raises: + ValueError: 如果参数不合法 """ + if not isinstance(start_time, (int, float)): + raise ValueError("start_time 必须是数字类型") + if not chat_id: + raise ValueError("chat_id 不能为空") + if not isinstance(chat_id, str): + raise ValueError("chat_id 必须是字符串类型") return num_new_messages_since(chat_id, start_time, end_time) @@ -282,8 +381,17 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa person_ids: 用户ID列表 Returns: - 新消息数量 + int: 新消息数量 + + Raises: + ValueError: 如果参数不合法 """ + if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): + raise ValueError("start_time 和 end_time 必须是数字类型") + if not chat_id: + raise ValueError("chat_id 不能为空") + if not isinstance(chat_id, str): + raise ValueError("chat_id 必须是字符串类型") return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids) diff --git a/src/plugin_system/apis/person_api.py b/src/plugin_system/apis/person_api.py index ae108211c..a84c5d2bb 100644 --- a/src/plugin_system/apis/person_api.py +++ b/src/plugin_system/apis/person_api.py @@ -7,7 +7,7 @@ value = await person_api.get_person_value(person_id, "nickname") """ -from typing import Any +from typing import Any, Optional from src.common.logger import get_logger from src.person_info.person_info import get_person_info_manager, PersonInfoManager @@ -63,7 +63,7 @@ async def get_person_value(person_id: str, field_name: str, default: Any = None) return default -async def get_person_values(person_id: str, field_names: list, default_dict: dict = None) -> dict: +async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict: """批量获取用户信息字段值 Args: diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 5e0e3e4be..91e3266d5 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -91,7 +91,7 @@ async def _send_to_target( ) # 创建消息段 - message_segment = Seg(type=message_type, data=content) + message_segment = Seg(type=message_type, data=content) # type: ignore # 处理回复消息 anchor_message = None diff --git a/src/plugin_system/apis/utils_api.py b/src/plugin_system/apis/utils_api.py index 1e5858b3f..45996df5c 100644 --- a/src/plugin_system/apis/utils_api.py +++ b/src/plugin_system/apis/utils_api.py @@ -36,9 +36,9 @@ def get_plugin_path(caller_frame=None) -> str: """ try: if caller_frame is None: - caller_frame = inspect.currentframe().f_back + caller_frame = inspect.currentframe().f_back # type: ignore - plugin_module_path = inspect.getfile(caller_frame) + plugin_module_path = inspect.getfile(caller_frame) # type: ignore plugin_dir = os.path.dirname(plugin_module_path) return plugin_dir except Exception as e: @@ -59,7 +59,7 @@ def read_json_file(file_path: str, default: Any = None) -> Any: try: # 如果是相对路径,则相对于调用者的插件目录 if not os.path.isabs(file_path): - caller_frame = inspect.currentframe().f_back + caller_frame = inspect.currentframe().f_back # type: ignore plugin_dir = get_plugin_path(caller_frame) file_path = os.path.join(plugin_dir, file_path) @@ -88,7 +88,7 @@ def write_json_file(file_path: str, data: Any, indent: int = 2) -> bool: try: # 如果是相对路径,则相对于调用者的插件目录 if not os.path.isabs(file_path): - caller_frame = inspect.currentframe().f_back + caller_frame = inspect.currentframe().f_back # type: ignore plugin_dir = get_plugin_path(caller_frame) file_path = os.path.join(plugin_dir, file_path) @@ -117,7 +117,7 @@ def get_timestamp() -> int: return int(time.time()) -def format_time(timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str: +def format_time(timestamp: Optional[int | float] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str: """格式化时间 Args: diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index efd285f9d..95dddf0b1 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -108,8 +108,8 @@ class EmojiAction(BaseAction): models = llm_api.get_available_models() chat_model_config = getattr(models, "utils_small", None) # 默认使用chat模型 if not chat_model_config: - logger.error(f"{self.log_prefix} 未找到'chat'模型配置,无法调用LLM") - return False, "未找到'chat'模型配置" + logger.error(f"{self.log_prefix} 未找到'utils_small'模型配置,无法调用LLM") + return False, "未找到'utils_small'模型配置" success, chosen_emotion, _, _ = await llm_api.generate_with_model( prompt, model_config=chat_model_config, request_type="emoji"