api typing check

This commit is contained in:
UnCLASPrommer
2025-07-15 19:09:04 +08:00
parent 2fab069dca
commit 80a1c0bf93
8 changed files with 144 additions and 26 deletions

View File

@@ -182,6 +182,7 @@ class DefaultReplyer:
回复器 (Replier): 核心逻辑,负责生成回复文本。 回复器 (Replier): 核心逻辑,负责生成回复文本。
(已整合原 HeartFCGenerator 的功能) (已整合原 HeartFCGenerator 的功能)
""" """
prompt = None
if available_actions is None: if available_actions is None:
available_actions = {} available_actions = {}
if reply_data is None: if reply_data is None:
@@ -707,6 +708,7 @@ class DefaultReplyer:
) )
target_user_id = "" target_user_id = ""
person_id = ""
if sender: if sender:
# 根据sender通过person_info_manager反向查找person_id再获取user_id # 根据sender通过person_info_manager反向查找person_id再获取user_id
person_id = person_info_manager.get_person_id_by_person_name(sender) person_id = person_info_manager.get_person_id_by_person_name(sender)

View File

@@ -76,7 +76,7 @@ def get_raw_msg_by_timestamp_with_chat_users(
chat_id: str, chat_id: str,
timestamp_start: float, timestamp_start: float,
timestamp_end: float, timestamp_end: float,
person_ids: list, person_ids: List[str],
limit: int = 0, limit: int = 0,
limit_mode: str = "latest", limit_mode: str = "latest",
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:

View File

@@ -131,6 +131,9 @@ async def generate_reply(
else: else:
return success, reply_set, None return success, reply_set, None
except ValueError as ve:
raise ve
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}") logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
return False, [], None return False, [], None
@@ -178,6 +181,9 @@ async def rewrite_reply(
return success, reply_set return success, reply_set
except ValueError as ve:
raise ve
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}") logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
return False, [] return False, []
@@ -191,12 +197,14 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese
enable_splitter: 是否启用消息分割器 enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器 enable_chinese_typo: 是否启用错字生成器
""" """
if not isinstance(content, str):
raise ValueError("content 必须是字符串类型")
try: try:
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo) processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
reply_set = [] reply_set = []
for str in processed_response: for text in processed_response:
reply_seg = ("text", str) reply_seg = ("text", text)
reply_set.append(reply_seg) reply_set.append(reply_seg)
return reply_set return reply_set

View File

@@ -48,8 +48,15 @@ def get_messages_by_time(
filter_mai: 是否过滤麦麦自身的消息默认为False filter_mai: 是否过滤麦麦自身的消息默认为False
Returns: Returns:
消息列表 List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
""" """
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai: if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)) return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode) return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
@@ -75,8 +82,19 @@ def get_messages_by_time_in_chat(
filter_mai: 是否过滤麦麦自身的消息默认为False filter_mai: 是否过滤麦麦自身的消息默认为False
Returns: Returns:
消息列表 List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
""" """
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
if filter_mai: if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode)) return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode) return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode)
@@ -102,8 +120,19 @@ def get_messages_by_time_in_chat_inclusive(
filter_mai: 是否过滤麦麦自身的消息默认为False filter_mai: 是否过滤麦麦自身的消息默认为False
Returns: Returns:
消息列表 List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
""" """
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
if filter_mai: if filter_mai:
return filter_mai_messages( return filter_mai_messages(
get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode) get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode)
@@ -115,7 +144,7 @@ def get_messages_by_time_in_chat_for_users(
chat_id: str, chat_id: str,
start_time: float, start_time: float,
end_time: float, end_time: float,
person_ids: list, person_ids: List[str],
limit: int = 0, limit: int = 0,
limit_mode: str = "latest", limit_mode: str = "latest",
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
@@ -131,8 +160,19 @@ def get_messages_by_time_in_chat_for_users(
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录 limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
Returns: Returns:
消息列表 List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
""" """
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode) return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode)
@@ -150,8 +190,15 @@ def get_random_chat_messages(
filter_mai: 是否过滤麦麦自身的消息默认为False filter_mai: 是否过滤麦麦自身的消息默认为False
Returns: Returns:
消息列表 List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
""" """
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai: if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)) return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode) return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
@@ -171,8 +218,15 @@ def get_messages_by_time_for_users(
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录 limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
Returns: Returns:
消息列表 List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
""" """
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode) return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
@@ -186,8 +240,15 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool
filter_mai: 是否过滤麦麦自身的消息默认为False filter_mai: 是否过滤麦麦自身的消息默认为False
Returns: Returns:
消息列表 List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
""" """
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai: if filter_mai:
return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit)) return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit))
return get_raw_msg_before_timestamp(timestamp, limit) return get_raw_msg_before_timestamp(timestamp, limit)
@@ -206,8 +267,19 @@ def get_messages_before_time_in_chat(
filter_mai: 是否过滤麦麦自身的消息默认为False filter_mai: 是否过滤麦麦自身的消息默认为False
Returns: Returns:
消息列表 List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
""" """
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
if filter_mai: if filter_mai:
return filter_mai_messages(get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)) return filter_mai_messages(get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit))
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit) return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
@@ -223,8 +295,15 @@ def get_messages_before_time_for_users(timestamp: float, person_ids: list, limit
limit: 限制返回的消息数量0为不限制 limit: 限制返回的消息数量0为不限制
Returns: Returns:
消息列表 List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
""" """
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit) return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit)
@@ -242,8 +321,19 @@ def get_recent_messages(
filter_mai: 是否过滤麦麦自身的消息默认为False filter_mai: 是否过滤麦麦自身的消息默认为False
Returns: Returns:
消息列表 List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法s
""" """
if not isinstance(hours, (int, float)) or hours < 0:
raise ValueError("hours 不能是负数")
if not isinstance(limit, int) or limit < 0:
raise ValueError("limit 必须是非负整数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
now = time.time() now = time.time()
start_time = now - hours * 3600 start_time = now - hours * 3600
if filter_mai: if filter_mai:
@@ -266,8 +356,17 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
end_time: 结束时间戳如果为None则使用当前时间 end_time: 结束时间戳如果为None则使用当前时间
Returns: Returns:
新消息数量 int: 新消息数量
Raises:
ValueError: 如果参数不合法
""" """
if not isinstance(start_time, (int, float)):
raise ValueError("start_time 必须是数字类型")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since(chat_id, start_time, end_time) return num_new_messages_since(chat_id, start_time, end_time)
@@ -282,8 +381,17 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
person_ids: 用户ID列表 person_ids: 用户ID列表
Returns: Returns:
新消息数量 int: 新消息数量
Raises:
ValueError: 如果参数不合法
""" """
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids) return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids)

View File

@@ -7,7 +7,7 @@
value = await person_api.get_person_value(person_id, "nickname") value = await person_api.get_person_value(person_id, "nickname")
""" """
from typing import Any from typing import Any, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.person_info.person_info import get_person_info_manager, PersonInfoManager from src.person_info.person_info import get_person_info_manager, PersonInfoManager
@@ -63,7 +63,7 @@ async def get_person_value(person_id: str, field_name: str, default: Any = None)
return default return default
async def get_person_values(person_id: str, field_names: list, default_dict: dict = None) -> dict: async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict:
"""批量获取用户信息字段值 """批量获取用户信息字段值
Args: Args:

View File

@@ -91,7 +91,7 @@ async def _send_to_target(
) )
# 创建消息段 # 创建消息段
message_segment = Seg(type=message_type, data=content) message_segment = Seg(type=message_type, data=content) # type: ignore
# 处理回复消息 # 处理回复消息
anchor_message = None anchor_message = None

View File

@@ -36,9 +36,9 @@ def get_plugin_path(caller_frame=None) -> str:
""" """
try: try:
if caller_frame is None: if caller_frame is None:
caller_frame = inspect.currentframe().f_back caller_frame = inspect.currentframe().f_back # type: ignore
plugin_module_path = inspect.getfile(caller_frame) plugin_module_path = inspect.getfile(caller_frame) # type: ignore
plugin_dir = os.path.dirname(plugin_module_path) plugin_dir = os.path.dirname(plugin_module_path)
return plugin_dir return plugin_dir
except Exception as e: except Exception as e:
@@ -59,7 +59,7 @@ def read_json_file(file_path: str, default: Any = None) -> Any:
try: try:
# 如果是相对路径,则相对于调用者的插件目录 # 如果是相对路径,则相对于调用者的插件目录
if not os.path.isabs(file_path): if not os.path.isabs(file_path):
caller_frame = inspect.currentframe().f_back caller_frame = inspect.currentframe().f_back # type: ignore
plugin_dir = get_plugin_path(caller_frame) plugin_dir = get_plugin_path(caller_frame)
file_path = os.path.join(plugin_dir, file_path) file_path = os.path.join(plugin_dir, file_path)
@@ -88,7 +88,7 @@ def write_json_file(file_path: str, data: Any, indent: int = 2) -> bool:
try: try:
# 如果是相对路径,则相对于调用者的插件目录 # 如果是相对路径,则相对于调用者的插件目录
if not os.path.isabs(file_path): if not os.path.isabs(file_path):
caller_frame = inspect.currentframe().f_back caller_frame = inspect.currentframe().f_back # type: ignore
plugin_dir = get_plugin_path(caller_frame) plugin_dir = get_plugin_path(caller_frame)
file_path = os.path.join(plugin_dir, file_path) file_path = os.path.join(plugin_dir, file_path)
@@ -117,7 +117,7 @@ def get_timestamp() -> int:
return int(time.time()) return int(time.time())
def format_time(timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str: def format_time(timestamp: Optional[int | float] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
"""格式化时间 """格式化时间
Args: Args:

View File

@@ -108,8 +108,8 @@ class EmojiAction(BaseAction):
models = llm_api.get_available_models() models = llm_api.get_available_models()
chat_model_config = getattr(models, "utils_small", None) # 默认使用chat模型 chat_model_config = getattr(models, "utils_small", None) # 默认使用chat模型
if not chat_model_config: if not chat_model_config:
logger.error(f"{self.log_prefix} 未找到'chat'模型配置无法调用LLM") logger.error(f"{self.log_prefix} 未找到'utils_small'模型配置无法调用LLM")
return False, "未找到'chat'模型配置" return False, "未找到'utils_small'模型配置"
success, chosen_emotion, _, _ = await llm_api.generate_with_model( success, chosen_emotion, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="emoji" prompt, model_config=chat_model_config, request_type="emoji"