diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 3ad3fe4cf..a9214a9af 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -4,6 +4,7 @@ import asyncio import random import ast import re + from typing import List, Optional, Dict, Any, Tuple from datetime import datetime @@ -161,13 +162,13 @@ class DefaultReplyer: async def generate_reply_with_context( self, - reply_data: Dict[str, Any] = None, + reply_data: Optional[Dict[str, Any]] = None, reply_to: str = "", extra_info: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, enable_tool: bool = True, enable_timeout: bool = False, - ) -> Tuple[bool, Optional[str]]: + ) -> Tuple[bool, Optional[str], Optional[str]]: """ 回复器 (Replier): 核心逻辑,负责生成回复文本。 (已整合原 HeartFCGenerator 的功能) @@ -225,14 +226,14 @@ class DefaultReplyer: except Exception as llm_e: # 精简报错信息 logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}") - return False, None # LLM 调用失败则无法生成回复 + return False, None, prompt # LLM 调用失败则无法生成回复 return True, content, prompt except Exception as e: logger.error(f"{self.log_prefix}回复生成意外失败: {e}") traceback.print_exc() - return False, None + return False, None, prompt async def rewrite_reply_with_context( self, @@ -368,7 +369,7 @@ class DefaultReplyer: memory_str += f"- {running_memory['content']}\n" return memory_str - async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True): + async def build_tool_info(self, chat_history, reply_data: Optional[Dict], enable_tool: bool = True): """构建工具信息块 Args: @@ -393,7 +394,7 @@ class DefaultReplyer: try: # 使用工具执行器获取信息 - tool_results = await self.tool_executor.execute_from_chat_message( + tool_results, _, _ = await self.tool_executor.execute_from_chat_message( sender=sender, target_message=text, chat_history=chat_history, return_details=False ) @@ -468,7 +469,7 @@ class DefaultReplyer: async def build_prompt_reply_context( self, - reply_data=None, + reply_data: Dict[str, Any], available_actions: Optional[Dict[str, ActionInfo]] = None, enable_timeout: bool = False, enable_tool: bool = True, @@ -549,7 +550,7 @@ class DefaultReplyer: ), self._time_and_run_task(self.build_memory_block(chat_talking_prompt_half, target), "build_memory_block"), self._time_and_run_task( - self.build_tool_info(reply_data, chat_talking_prompt_half, enable_tool=enable_tool), "build_tool_info" + self.build_tool_info(chat_talking_prompt_half, reply_data, enable_tool=enable_tool), "build_tool_info" ), ) @@ -806,7 +807,7 @@ class DefaultReplyer: response_set: List[Tuple[str, str]], thinking_id: str = "", display_message: str = "", - ) -> Optional[MessageSending]: + ) -> Optional[List[Tuple[str, bool]]]: # sourcery skip: assign-if-exp, boolean-if-exp-identity, remove-unnecessary-cast """发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender""" chat = self.chat_stream @@ -869,7 +870,7 @@ class DefaultReplyer: try: if ( bot_message.is_private_message() - or bot_message.reply.processed_plain_text != "[System Trigger Context]" + or bot_message.reply.processed_plain_text != "[System Trigger Context]" # type: ignore or mark_head ): set_reply = False @@ -910,7 +911,7 @@ class DefaultReplyer: is_emoji: bool, thinking_start_time: float, display_message: str, - anchor_message: MessageRecv = None, + anchor_message: Optional[MessageRecv] = None, ) -> MessageSending: """构建单个发送消息""" diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py index a2a2aaaa0..3f1c731b4 100644 --- a/src/chat/replyer/replyer_manager.py +++ b/src/chat/replyer/replyer_manager.py @@ -1,8 +1,8 @@ from typing import Dict, Any, Optional, List +from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.replyer.default_generator import DefaultReplyer -from src.common.logger import get_logger logger = get_logger("ReplyerManager") diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 8c579e6d3..6bdf7f58d 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -1,6 +1,7 @@ import time # 导入 time 模块以获取当前时间 import random import re + from typing import List, Dict, Any, Tuple, Optional from rich.traceback import install @@ -88,8 +89,8 @@ def get_actions_by_timestamp_with_chat( """获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表""" query = ActionRecords.select().where( (ActionRecords.chat_id == chat_id) - & (ActionRecords.time > timestamp_start) - & (ActionRecords.time < timestamp_end) + & (ActionRecords.time > timestamp_start) # type: ignore + & (ActionRecords.time < timestamp_end) # type: ignore ) if limit > 0: @@ -113,8 +114,8 @@ def get_actions_by_timestamp_with_chat_inclusive( """获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表""" query = ActionRecords.select().where( (ActionRecords.chat_id == chat_id) - & (ActionRecords.time >= timestamp_start) - & (ActionRecords.time <= timestamp_end) + & (ActionRecords.time >= timestamp_start) # type: ignore + & (ActionRecords.time <= timestamp_end) # type: ignore ) if limit > 0: @@ -331,7 +332,7 @@ def _build_readable_messages_internal( if replace_bot_name and user_id == global_config.bot.qq_account: person_name = f"{global_config.bot.nickname}(你)" else: - person_name = person_info_manager.get_value_sync(person_id, "person_name") + person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore # 如果 person_name 未设置,则使用消息中的 nickname 或默认名称 if not person_name: @@ -911,8 +912,8 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: person_ids_set = set() # 使用集合来自动去重 for msg in messages: - platform = msg.get("user_platform") - user_id = msg.get("user_id") + platform: str = msg.get("user_platform") # type: ignore + user_id: str = msg.get("user_id") # type: ignore # 检查必要信息是否存在 且 不是机器人自己 if not all([platform, user_id]) or user_id == global_config.bot.qq_account: diff --git a/src/chat/utils/json_utils.py b/src/chat/utils/json_utils.py index 6226e6e96..892deac4f 100644 --- a/src/chat/utils/json_utils.py +++ b/src/chat/utils/json_utils.py @@ -1,7 +1,8 @@ +import ast import json import logging -from typing import Any, Dict, TypeVar, List, Union, Tuple -import ast + +from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional # 定义类型变量用于泛型类型提示 T = TypeVar("T") @@ -30,18 +31,14 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]: # 尝试标准的 JSON 解析 return json.loads(json_str) except json.JSONDecodeError: - # 如果标准解析失败,尝试将单引号替换为双引号再解析 - # (注意:这种替换可能不安全,如果字符串内容本身包含引号) - # 更安全的方式是用 ast.literal_eval + # 如果标准解析失败,尝试用 ast.literal_eval 解析 try: # logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...") result = ast.literal_eval(json_str) - # 确保结果是字典(因为我们通常期望参数是字典) if isinstance(result, dict): return result - else: - logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}") - return default_value + logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}") + return default_value except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e: logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...") return default_value @@ -53,7 +50,9 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]: return default_value -def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]: +def extract_tool_call_arguments( + tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: """ 从LLM工具调用对象中提取参数 @@ -77,14 +76,12 @@ def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[s logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}") return default_result - # 提取arguments - arguments_str = function_data.get("arguments", "{}") - if not arguments_str: + if arguments_str := function_data.get("arguments", "{}"): + # 解析JSON + return safe_json_loads(arguments_str, default_result) + else: return default_result - # 解析JSON - return safe_json_loads(arguments_str, default_result) - except Exception as e: logger.error(f"提取工具调用参数时出错: {e}") return default_result diff --git a/src/chat/utils/prompt_builder.py b/src/chat/utils/prompt_builder.py index 26f8ffbad..1b107904c 100644 --- a/src/chat/utils/prompt_builder.py +++ b/src/chat/utils/prompt_builder.py @@ -1,12 +1,12 @@ -from typing import Dict, Any, Optional, List, Union import re -from contextlib import asynccontextmanager import asyncio import contextvars -from src.common.logger import get_logger -# import traceback from rich.traceback import install +from contextlib import asynccontextmanager +from typing import Dict, Any, Optional, List, Union + +from src.common.logger import get_logger install(extra_lines=3) @@ -32,6 +32,7 @@ class PromptContext: @asynccontextmanager async def async_scope(self, context_id: Optional[str] = None): + # sourcery skip: hoist-statement-from-if, use-contextlib-suppress """创建一个异步的临时提示模板作用域""" # 保存当前上下文并设置新上下文 if context_id is not None: @@ -88,8 +89,7 @@ class PromptContext: async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None: """异步注册提示模板到指定作用域""" async with self._context_lock: - target_context = context_id or self._current_context - if target_context: + if target_context := context_id or self._current_context: self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt @@ -151,7 +151,7 @@ class Prompt(str): @staticmethod def _process_escaped_braces(template) -> str: - """处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记""" + """处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记""" # type: ignore # 如果传入的是列表,将其转换为字符串 if isinstance(template, list): template = "\n".join(str(item) for item in template) @@ -195,14 +195,8 @@ class Prompt(str): obj._kwargs = kwargs # 修改自动注册逻辑 - if should_register: - if global_prompt_manager._context._current_context: - # 如果存在当前上下文,则注册到上下文中 - # asyncio.create_task(global_prompt_manager._context.register_async(obj)) - pass - else: - # 否则注册到全局管理器 - global_prompt_manager.register(obj) + if should_register and not global_prompt_manager._context._current_context: + global_prompt_manager.register(obj) return obj @classmethod @@ -276,15 +270,13 @@ class Prompt(str): self.name, args=list(args) if args else self._args, _should_register=False, - **kwargs if kwargs else self._kwargs, + **kwargs or self._kwargs, ) # print(f"prompt build result: {ret} name: {ret.name} ") return str(ret) def __str__(self) -> str: - if self._kwargs or self._args: - return super().__str__() - return self.template + return super().__str__() if self._kwargs or self._args else self.template def __repr__(self) -> str: return f"Prompt(template='{self.template}', name='{self.name}')" diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 25d231c01..4e0edd31f 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -1,18 +1,17 @@ -from collections import defaultdict -from datetime import datetime, timedelta -from typing import Any, Dict, Tuple, List import asyncio import concurrent.futures import json import os import glob +from collections import defaultdict +from datetime import datetime, timedelta +from typing import Any, Dict, Tuple, List from src.common.logger import get_logger +from src.common.database.database import db +from src.common.database.database_model import OnlineTime, LLMUsage, Messages from src.manager.async_task_manager import AsyncTask - -from ...common.database.database import db # This db is the Peewee database instance -from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model from src.manager.local_store_manager import local_storage logger = get_logger("maibot_statistic") @@ -76,14 +75,14 @@ class OnlineTimeRecordTask(AsyncTask): with db.atomic(): # Use atomic operations for schema changes OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model - async def run(self): + async def run(self): # sourcery skip: use-named-expression try: current_time = datetime.now() extended_end_time = current_time + timedelta(minutes=1) if self.record_id: # 如果有记录,则更新结束时间 - query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) + query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore updated_rows = query.execute() if updated_rows == 0: # Record might have been deleted or ID is stale, try to find/create @@ -94,7 +93,7 @@ class OnlineTimeRecordTask(AsyncTask): # Look for a record whose end_timestamp is recent enough to be considered ongoing recent_record = ( OnlineTime.select() - .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) + .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore .order_by(OnlineTime.end_timestamp.desc()) .first() ) @@ -123,15 +122,15 @@ def _format_online_time(online_seconds: int) -> str: :param online_seconds: 在线时间(秒) :return: 格式化后的在线时间字符串 """ - total_oneline_time = timedelta(seconds=online_seconds) + total_online_time = timedelta(seconds=online_seconds) - days = total_oneline_time.days - hours = total_oneline_time.seconds // 3600 - minutes = (total_oneline_time.seconds // 60) % 60 - seconds = total_oneline_time.seconds % 60 + days = total_online_time.days + hours = total_online_time.seconds // 3600 + minutes = (total_online_time.seconds // 60) % 60 + seconds = total_online_time.seconds % 60 if days > 0: # 如果在线时间超过1天,则格式化为"X天X小时X分钟" - return f"{total_oneline_time.days}天{hours}小时{minutes}分钟{seconds}秒" + return f"{total_online_time.days}天{hours}小时{minutes}分钟{seconds}秒" elif hours > 0: # 如果在线时间超过1小时,则格式化为"X小时X分钟X秒" return f"{hours}小时{minutes}分钟{seconds}秒" @@ -163,7 +162,7 @@ class StatisticOutputTask(AsyncTask): now = datetime.now() if "deploy_time" in local_storage: # 如果存在部署时间,则使用该时间作为全量统计的起始时间 - deploy_time = datetime.fromtimestamp(local_storage["deploy_time"]) + deploy_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore else: # 否则,使用最大时间范围,并记录部署时间为当前时间 deploy_time = datetime(2000, 1, 1) @@ -252,7 +251,7 @@ class StatisticOutputTask(AsyncTask): # 创建后台任务,不等待完成 collect_task = asyncio.create_task( - loop.run_in_executor(executor, self._collect_all_statistics, now) + loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore ) stats = await collect_task @@ -260,8 +259,8 @@ class StatisticOutputTask(AsyncTask): # 创建并发的输出任务 output_tasks = [ - asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), - asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), + asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore + asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore ] # 等待所有输出任务完成 @@ -320,7 +319,7 @@ class StatisticOutputTask(AsyncTask): # 以最早的时间戳为起始时间获取记录 # Assuming LLMUsage.timestamp is a DateTimeField query_start_time = collect_period[-1][1] - for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): + for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore record_timestamp = record.timestamp # This is already a datetime object for idx, (_, period_start) in enumerate(collect_period): if record_timestamp >= period_start: @@ -388,7 +387,7 @@ class StatisticOutputTask(AsyncTask): query_start_time = collect_period[-1][1] # Assuming OnlineTime.end_timestamp is a DateTimeField - for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): + for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore # record.end_timestamp and record.start_timestamp are datetime objects record_end_timestamp = record.end_timestamp record_start_timestamp = record.start_timestamp @@ -428,7 +427,7 @@ class StatisticOutputTask(AsyncTask): } query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) - for message in Messages.select().where(Messages.time >= query_start_timestamp): + for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore message_time_ts = message.time # This is a float timestamp chat_id = None @@ -661,7 +660,7 @@ class StatisticOutputTask(AsyncTask): if "last_full_statistics" in local_storage: # 如果存在上次完整统计数据,则使用该数据进行增量统计 - last_stat = local_storage["last_full_statistics"] # 上次完整统计数据 + last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射 last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据 @@ -727,6 +726,7 @@ class StatisticOutputTask(AsyncTask): return stat def _convert_defaultdict_to_dict(self, data): + # sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks """递归转换defaultdict为普通dict""" if isinstance(data, defaultdict): # 转换defaultdict为普通dict @@ -812,8 +812,7 @@ class StatisticOutputTask(AsyncTask): # 全局阶段平均时间 if stats[FOCUS_AVG_TIMES_BY_STAGE]: output.append("全局阶段平均时间:") - for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items(): - output.append(f" {stage}: {avg_time:.3f}秒") + output.extend(f" {stage}: {avg_time:.3f}秒" for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items()) output.append("") # Action类型比例 @@ -1050,7 +1049,7 @@ class StatisticOutputTask(AsyncTask): ] tab_content_list.append( - _format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"])) + _format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"])) # type: ignore ) # 添加Focus统计内容 @@ -1212,6 +1211,7 @@ class StatisticOutputTask(AsyncTask): f.write(html_template) def _generate_focus_tab(self, stat: dict[str, Any]) -> str: + # sourcery skip: for-append-to-extend, list-comprehension, use-any """生成Focus统计独立分页的HTML内容""" # 为每个时间段准备Focus数据 @@ -1313,12 +1313,11 @@ class StatisticOutputTask(AsyncTask): # 聊天流Action选择比例对比表(横向表格) focus_chat_action_ratios_rows = "" if stat_data.get("focus_action_ratios_by_chat"): - # 获取所有action类型(按全局频率排序) - all_action_types_for_ratio = sorted( - stat_data[FOCUS_ACTION_RATIOS].keys(), key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x], reverse=True - ) - - if all_action_types_for_ratio: + if all_action_types_for_ratio := sorted( + stat_data[FOCUS_ACTION_RATIOS].keys(), + key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x], + reverse=True, + ): # 为每个聊天流生成数据行(按循环数排序) chat_ratio_rows = [] for chat_id in sorted( @@ -1379,16 +1378,11 @@ class StatisticOutputTask(AsyncTask): if period_name == "all_time": from src.manager.local_store_manager import local_storage - start_time = datetime.fromtimestamp(local_storage["deploy_time"]) - time_range = ( - f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - ) + start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore else: start_time = datetime.now() - period_delta - time_range = ( - f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - ) + time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" # 生成该时间段的Focus统计HTML section_html = f"""
@@ -1681,16 +1675,10 @@ class StatisticOutputTask(AsyncTask): if period_name == "all_time": from src.manager.local_store_manager import local_storage - start_time = datetime.fromtimestamp(local_storage["deploy_time"]) - time_range = ( - f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - ) + start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore else: start_time = datetime.now() - period_delta - time_range = ( - f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - ) - + time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" # 生成该时间段的版本对比HTML section_html = f"""
@@ -1865,7 +1853,7 @@ class StatisticOutputTask(AsyncTask): # 查询LLM使用记录 query_start_time = start_time - for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): + for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore record_time = record.timestamp # 找到对应的时间间隔索引 @@ -1875,7 +1863,7 @@ class StatisticOutputTask(AsyncTask): if 0 <= interval_index < len(time_points): # 累加总花费数据 cost = record.cost or 0.0 - total_cost_data[interval_index] += cost + total_cost_data[interval_index] += cost # type: ignore # 累加按模型分类的花费 model_name = record.model_name or "unknown" @@ -1892,7 +1880,7 @@ class StatisticOutputTask(AsyncTask): # 查询消息记录 query_start_timestamp = start_time.timestamp() - for message in Messages.select().where(Messages.time >= query_start_timestamp): + for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore message_time_ts = message.time # 找到对应的时间间隔索引 @@ -1982,6 +1970,7 @@ class StatisticOutputTask(AsyncTask): } def _generate_chart_tab(self, chart_data: dict) -> str: + # sourcery skip: extract-duplicate-method, move-assign-in-block """生成图表选项卡HTML内容""" # 生成不同颜色的调色板 @@ -2293,7 +2282,7 @@ class AsyncStatisticOutputTask(AsyncTask): # 数据收集任务 collect_task = asyncio.create_task( - loop.run_in_executor(executor, self._collect_all_statistics, now) + loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore ) stats = await collect_task @@ -2301,8 +2290,8 @@ class AsyncStatisticOutputTask(AsyncTask): # 创建并发的输出任务 output_tasks = [ - asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), - asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), + asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore + asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore ] # 等待所有输出任务完成 diff --git a/src/chat/utils/timer_calculator.py b/src/chat/utils/timer_calculator.py index df2b9f778..d9479af16 100644 --- a/src/chat/utils/timer_calculator.py +++ b/src/chat/utils/timer_calculator.py @@ -1,7 +1,8 @@ +import asyncio + from time import perf_counter from functools import wraps from typing import Optional, Dict, Callable -import asyncio from rich.traceback import install install(extra_lines=3) @@ -88,10 +89,10 @@ class Timer: self.name = name self.storage = storage - self.elapsed = None + self.elapsed: float = None # type: ignore self.auto_unit = auto_unit - self.start = None + self.start: float = None # type: ignore @staticmethod def _validate_types(name, storage): @@ -120,7 +121,7 @@ class Timer: return None wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper - wrapper.__timer__ = self # 保留计时器引用 + wrapper.__timer__ = self # 保留计时器引用 # type: ignore return wrapper def __enter__(self): diff --git a/src/chat/utils/typo_generator.py b/src/chat/utils/typo_generator.py index 7c373f132..4de219464 100644 --- a/src/chat/utils/typo_generator.py +++ b/src/chat/utils/typo_generator.py @@ -7,10 +7,10 @@ import math import os import random import time +import jieba + from collections import defaultdict from pathlib import Path - -import jieba from pypinyin import Style, pinyin from src.common.logger import get_logger @@ -104,7 +104,7 @@ class ChineseTypoGenerator: try: return "\u4e00" <= char <= "\u9fff" except Exception as e: - logger.debug(e) + logger.debug(str(e)) return False def _get_pinyin(self, sentence): @@ -138,7 +138,7 @@ class ChineseTypoGenerator: # 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况 if not py[-1].isdigit(): # 为非数字结尾的拼音添加数字声调1 - return py + "1" + return f"{py}1" base = py[:-1] # 去掉声调 tone = int(py[-1]) # 获取声调 diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index f3226b2e1..2fbc69559 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -1,23 +1,21 @@ import random import re import time -from collections import Counter - import jieba import numpy as np + +from collections import Counter from maim_message import UserInfo +from typing import Optional, Tuple, Dict from src.common.logger import get_logger - -# from src.mood.mood_manager import mood_manager -from ..message_receive.message import MessageRecv -from src.llm_models.utils_model import LLMRequest -from .typo_generator import ChineseTypoGenerator -from ...config.config import global_config -from ...common.message_repository import find_messages, count_messages -from typing import Optional, Tuple, Dict +from src.common.message_repository import find_messages, count_messages +from src.config.config import global_config +from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.chat_stream import get_chat_manager +from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import PersonInfoManager, get_person_info_manager +from .typo_generator import ChineseTypoGenerator logger = get_logger("chat_utils") @@ -31,11 +29,7 @@ def db_message_to_str(message_dict: dict) -> str: logger.debug(f"message_dict: {message_dict}") time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"])) try: - name = "[(%s)%s]%s" % ( - message_dict["user_id"], - message_dict.get("user_nickname", ""), - message_dict.get("user_cardname", ""), - ) + name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}" except Exception: name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}" content = message_dict.get("processed_plain_text", "") @@ -58,11 +52,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: and message.message_info.additional_config.get("is_mentioned") is not None ): try: - reply_probability = float(message.message_info.additional_config.get("is_mentioned")) + reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore is_mentioned = True return is_mentioned, reply_probability except Exception as e: - logger.warning(e) + logger.warning(str(e)) logger.warning( f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}" ) @@ -135,20 +129,17 @@ def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, c if not recent_messages: return [] - message_detailed_plain_text = "" - message_detailed_plain_text_list = [] - # 反转消息列表,使最新的消息在最后 recent_messages.reverse() if combine: - for msg_db_data in recent_messages: - message_detailed_plain_text += str(msg_db_data["detailed_plain_text"]) - return message_detailed_plain_text - else: - for msg_db_data in recent_messages: - message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"]) - return message_detailed_plain_text_list + return "".join(str(msg_db_data["detailed_plain_text"]) for msg_db_data in recent_messages) + + message_detailed_plain_text_list = [] + + for msg_db_data in recent_messages: + message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"]) + return message_detailed_plain_text_list def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list: @@ -204,10 +195,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]: len_text = len(text) if len_text < 3: - if random.random() < 0.01: - return list(text) # 如果文本很短且触发随机条件,直接按字符分割 - else: - return [text] + return list(text) if random.random() < 0.01 else [text] # 定义分隔符 separators = {",", ",", " ", "。", ";"} @@ -352,10 +340,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese max_length = global_config.response_splitter.max_length * 2 max_sentence_num = global_config.response_splitter.max_sentence_num # 如果基本上是中文,则进行长度过滤 - if get_western_ratio(cleaned_text) < 0.1: - if len(cleaned_text) > max_length: - logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复") - return ["懒得说"] + if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length: + logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复") + return ["懒得说"] typo_generator = ChineseTypoGenerator( error_rate=global_config.chinese_typo.error_rate, @@ -420,7 +407,7 @@ def calculate_typing_time( # chinese_time *= 1 / typing_speed_multiplier # english_time *= 1 / typing_speed_multiplier # 计算中文字符数 - chinese_chars = sum(1 for char in input_string if "\u4e00" <= char <= "\u9fff") + chinese_chars = sum("\u4e00" <= char <= "\u9fff" for char in input_string) # 如果只有一个中文字符,使用3倍时间 if chinese_chars == 1 and len(input_string.strip()) == 1: @@ -429,11 +416,7 @@ def calculate_typing_time( # 正常计算所有字符的输入时间 total_time = 0.0 for char in input_string: - if "\u4e00" <= char <= "\u9fff": # 判断是否为中文字符 - total_time += chinese_time - else: # 其他字符(如英文) - total_time += english_time - + total_time += chinese_time if "\u4e00" <= char <= "\u9fff" else english_time if is_emoji: total_time = 1 @@ -453,18 +436,14 @@ def cosine_similarity(v1, v2): dot_product = np.dot(v1, v2) norm1 = np.linalg.norm(v1) norm2 = np.linalg.norm(v2) - if norm1 == 0 or norm2 == 0: - return 0 - return dot_product / (norm1 * norm2) + return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2) def text_to_vector(text): """将文本转换为词频向量""" # 分词 words = jieba.lcut(text) - # 统计词频 - word_freq = Counter(words) - return word_freq + return Counter(words) def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list: @@ -491,9 +470,7 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list: def truncate_message(message: str, max_length=20) -> str: """截断消息,使其不超过指定长度""" - if len(message) > max_length: - return message[:max_length] + "..." - return message + return f"{message[:max_length]}..." if len(message) > max_length else message def protect_kaomoji(sentence): @@ -522,7 +499,7 @@ def protect_kaomoji(sentence): placeholder_to_kaomoji = {} for idx, match in enumerate(kaomoji_matches): - kaomoji = match[0] if match[0] else match[1] + kaomoji = match[0] or match[1] placeholder = f"__KAOMOJI_{idx}__" sentence = sentence.replace(kaomoji, placeholder, 1) placeholder_to_kaomoji[placeholder] = kaomoji @@ -563,7 +540,7 @@ def get_western_ratio(paragraph): if not alnum_chars: return 0.0 - western_count = sum(1 for char in alnum_chars if is_english_letter(char)) + western_count = sum(bool(is_english_letter(char)) for char in alnum_chars) return western_count / len(alnum_chars) @@ -610,6 +587,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) - def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str: + # sourcery skip: merge-comparisons, merge-duplicate-blocks, switch """将时间戳转换为人类可读的时间格式 Args: @@ -621,7 +599,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal" """ if mode == "normal": return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) - if mode == "normal_no_YMD": + elif mode == "normal_no_YMD": return time.strftime("%H:%M:%S", time.localtime(timestamp)) elif mode == "relative": now = time.time() @@ -640,7 +618,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal" else: return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":" else: # mode = "lite" or unknown - # 只返回时分秒格式,喵~ + # 只返回时分秒格式 return time.strftime("%H:%M:%S", time.localtime(timestamp)) @@ -670,8 +648,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: elif chat_stream.user_info: # It's a private chat is_group_chat = False user_info = chat_stream.user_info - platform = chat_stream.platform - user_id = user_info.user_id + platform: str = chat_stream.platform # type: ignore + user_id: str = user_info.user_id # type: ignore # Initialize target_info with basic info target_info = { diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 5579ccf84..d5fa301bb 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -3,21 +3,20 @@ import os import time import hashlib import uuid +import io +import asyncio +import numpy as np + from typing import Optional, Tuple from PIL import Image -import io -import numpy as np -import asyncio - +from rich.traceback import install +from src.common.logger import get_logger from src.common.database.database import db from src.common.database.database_model import Images, ImageDescriptions from src.config.config import global_config from src.llm_models.utils_model import LLMRequest -from src.common.logger import get_logger -from rich.traceback import install - install(extra_lines=3) logger = get_logger("chat_image") @@ -111,7 +110,7 @@ class ImageManager: return f"[表情包,含义看起来是:{cached_description}]" # 调用AI获取描述 - if image_format == "gif" or image_format == "GIF": + if image_format in ["gif", "GIF"]: image_base64_processed = self.transform_gif(image_base64) if image_base64_processed is None: logger.warning("GIF转换失败,无法获取描述") @@ -258,6 +257,7 @@ class ImageManager: @staticmethod def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]: + # sourcery skip: use-contextlib-suppress """将GIF转换为水平拼接的静态图像, 跳过相似的帧 Args: @@ -351,7 +351,7 @@ class ImageManager: # 创建拼接图像 total_width = target_width * len(resized_frames) # 防止总宽度为0 - if total_width == 0 and len(resized_frames) > 0: + if total_width == 0 and resized_frames: logger.warning("计算出的总宽度为0,但有选中帧,可能目标宽度太小") # 至少给点宽度吧 total_width = len(resized_frames) @@ -368,10 +368,7 @@ class ImageManager: # 转换为base64 buffer = io.BytesIO() combined_image.save(buffer, format="JPEG", quality=85) # 保存为JPEG - result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") - - return result_base64 - + return base64.b64encode(buffer.getvalue()).decode("utf-8") except MemoryError: logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多") return None # 内存不够啦 @@ -380,6 +377,7 @@ class ImageManager: return None # 其他错误也返回None async def process_image(self, image_base64: str) -> Tuple[str, str]: + # sourcery skip: hoist-if-from-if """处理图片并返回图片ID和描述 Args: @@ -418,17 +416,9 @@ class ImageManager: if existing_image.vlm_processed is None: existing_image.vlm_processed = False - existing_image.count += 1 - existing_image.save() - return existing_image.image_id, f"[picid:{existing_image.image_id}]" - else: - # print(f"图片已存在: {existing_image.image_id}") - # print(f"图片描述: {existing_image.description}") - # print(f"图片计数: {existing_image.count}") - # 更新计数 - existing_image.count += 1 - existing_image.save() - return existing_image.image_id, f"[picid:{existing_image.image_id}]" + existing_image.count += 1 + existing_image.save() + return existing_image.image_id, f"[picid:{existing_image.image_id}]" else: # print(f"图片不存在: {image_hash}") image_id = str(uuid.uuid4()) diff --git a/src/common/database/database.py b/src/common/database/database.py index 249664155..ca3614816 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -54,11 +54,11 @@ class DBWrapper: return getattr(get_db(), name) def __getitem__(self, key): - return get_db()[key] + return get_db()[key] # type: ignore # 全局数据库访问点 -memory_db: Database = DBWrapper() +memory_db: Database = DBWrapper() # type: ignore # 定义数据库文件路径 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 3485fedeb..b411e1b3a 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -406,9 +406,7 @@ def initialize_database(): existing_columns = {row[1] for row in cursor.fetchall()} model_fields = set(model._meta.fields.keys()) - # 检查并添加缺失字段(原有逻辑) - missing_fields = model_fields - existing_columns - if missing_fields: + if missing_fields := model_fields - existing_columns: logger.warning(f"表 '{table_name}' 缺失字段: {missing_fields}") for field_name, field_obj in model._meta.fields.items(): @@ -424,10 +422,7 @@ def initialize_database(): "DateTimeField": "DATETIME", }.get(field_type, "TEXT") alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}" - if field_obj.null: - alter_sql += " NULL" - else: - alter_sql += " NOT NULL" + alter_sql += " NULL" if field_obj.null else " NOT NULL" if hasattr(field_obj, "default") and field_obj.default is not None: # 正确处理不同类型的默认值 default_value = field_obj.default diff --git a/src/common/logger.py b/src/common/logger.py index 40fd15070..a235cf341 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -1,16 +1,16 @@ -import logging - # 使用基于时间戳的文件处理器,简单的轮转份数限制 -from pathlib import Path -from typing import Callable, Optional + +import logging import json import threading import time -from datetime import datetime, timedelta - import structlog import toml +from pathlib import Path +from typing import Callable, Optional +from datetime import datetime, timedelta + # 创建logs目录 LOG_DIR = Path("logs") LOG_DIR.mkdir(exist_ok=True) @@ -160,7 +160,7 @@ def close_handlers(): _console_handler = None -def remove_duplicate_handlers(): +def remove_duplicate_handlers(): # sourcery skip: for-append-to-extend, list-comprehension """移除重复的handler,特别是文件handler""" root_logger = logging.getLogger() @@ -184,7 +184,7 @@ def remove_duplicate_handlers(): # 读取日志配置 -def load_log_config(): +def load_log_config(): # sourcery skip: use-contextlib-suppress """从配置文件加载日志设置""" config_path = Path("config/bot_config.toml") default_config = { @@ -365,7 +365,7 @@ MODULE_COLORS = { "component_registry": "\033[38;5;214m", # 橙黄色 "stream_api": "\033[38;5;220m", # 黄色 "config_api": "\033[38;5;226m", # 亮黄色 - "hearflow_api": "\033[38;5;154m", # 黄绿色 + "heartflow_api": "\033[38;5;154m", # 黄绿色 "action_apis": "\033[38;5;118m", # 绿色 "independent_apis": "\033[38;5;82m", # 绿色 "llm_api": "\033[38;5;46m", # 亮绿色 @@ -412,6 +412,7 @@ class ModuleColoredConsoleRenderer: """自定义控制台渲染器,为不同模块提供不同颜色""" def __init__(self, colors=True): + # sourcery skip: merge-duplicate-blocks, remove-redundant-if self._colors = colors self._config = LOG_CONFIG @@ -443,6 +444,7 @@ class ModuleColoredConsoleRenderer: self._enable_full_content_colors = False def __call__(self, logger, method_name, event_dict): + # sourcery skip: merge-duplicate-blocks """渲染日志消息""" # 获取基本信息 timestamp = event_dict.get("timestamp", "") @@ -662,7 +664,7 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger: """获取logger实例,支持按名称绑定""" if name is None: return raw_logger - logger = binds.get(name) + logger = binds.get(name) # type: ignore if logger is None: logger: structlog.stdlib.BoundLogger = structlog.get_logger(name).bind(logger_name=name) binds[name] = logger @@ -671,8 +673,8 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger: def configure_logging( level: str = "INFO", - console_level: str = None, - file_level: str = None, + console_level: Optional[str] = None, + file_level: Optional[str] = None, max_bytes: int = 5 * 1024 * 1024, backup_count: int = 30, log_dir: str = "logs", @@ -729,14 +731,11 @@ def reload_log_config(): global LOG_CONFIG LOG_CONFIG = load_log_config() - # 重新设置handler的日志级别 - file_handler = get_file_handler() - if file_handler: + if file_handler := get_file_handler(): file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO")) file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO)) - console_handler = get_console_handler() - if console_handler: + if console_handler := get_console_handler(): console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO")) console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO)) @@ -780,8 +779,7 @@ def set_console_log_level(level: str): global LOG_CONFIG LOG_CONFIG["console_log_level"] = level.upper() - console_handler = get_console_handler() - if console_handler: + if console_handler := get_console_handler(): console_handler.setLevel(getattr(logging, level.upper(), logging.INFO)) # 重新设置root logger级别 @@ -800,8 +798,7 @@ def set_file_log_level(level: str): global LOG_CONFIG LOG_CONFIG["file_log_level"] = level.upper() - file_handler = get_file_handler() - if file_handler: + if file_handler := get_file_handler(): file_handler.setLevel(getattr(logging, level.upper(), logging.INFO)) # 重新设置root logger级别 @@ -933,13 +930,12 @@ def format_json_for_logging(data, indent=2, ensure_ascii=False): Returns: str: 格式化后的JSON字符串 """ - if isinstance(data, str): - # 如果是JSON字符串,先解析再格式化 - parsed_data = json.loads(data) - return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii) - else: + if not isinstance(data, str): # 如果是对象,直接格式化 return json.dumps(data, indent=indent, ensure_ascii=ensure_ascii) + # 如果是JSON字符串,先解析再格式化 + parsed_data = json.loads(data) + return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii) def cleanup_old_logs(): diff --git a/src/common/message/api.py b/src/common/message/api.py index 59ba9d1e2..eed85c0a9 100644 --- a/src/common/message/api.py +++ b/src/common/message/api.py @@ -8,7 +8,7 @@ from src.config.config import global_config global_api = None -def get_global_api() -> MessageServer: +def get_global_api() -> MessageServer: # sourcery skip: extract-method """获取全局MessageServer实例""" global global_api if global_api is None: @@ -36,9 +36,8 @@ def get_global_api() -> MessageServer: kwargs["custom_logger"] = maim_message_logger # 添加token认证 - if maim_message_config.auth_token: - if len(maim_message_config.auth_token) > 0: - kwargs["enable_token"] = True + if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0: + kwargs["enable_token"] = True if maim_message_config.use_custom: # 添加WSS模式支持 diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 107ee1c5e..dc5d8b7df 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,9 +1,11 @@ -from src.common.database.database_model import Messages # 更改导入 -from src.common.logger import get_logger import traceback + from typing import List, Any, Optional from peewee import Model # 添加 Peewee Model 导入 +from src.common.database.database_model import Messages +from src.common.logger import get_logger + logger = get_logger(__name__) diff --git a/src/common/remote.py b/src/common/remote.py index 955e760b0..5380cd01e 100644 --- a/src/common/remote.py +++ b/src/common/remote.py @@ -23,7 +23,7 @@ class TelemetryHeartBeatTask(AsyncTask): self.server_url = TELEMETRY_SERVER_URL """遥测服务地址""" - self.client_uuid = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None + self.client_uuid: str | None = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None # type: ignore """客户端UUID""" self.info_dict = self._get_sys_info() @@ -72,7 +72,7 @@ class TelemetryHeartBeatTask(AsyncTask): timeout=aiohttp.ClientTimeout(total=5), # 设置超时时间为5秒 ) as response: logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client") - logger.debug(local_storage["deploy_time"]) + logger.debug(local_storage["deploy_time"]) # type: ignore logger.debug(f"Response status: {response.status}") if response.status == 200: @@ -93,7 +93,7 @@ class TelemetryHeartBeatTask(AsyncTask): except Exception as e: import traceback - error_msg = str(e) if str(e) else "未知错误" + error_msg = str(e) or "未知错误" logger.warning( f"请求UUID出错,不过你还是可以正常使用麦麦: {type(e).__name__}: {error_msg}" ) # 可能是网络问题 @@ -114,11 +114,11 @@ class TelemetryHeartBeatTask(AsyncTask): """向服务器发送心跳""" headers = { "Client-UUID": self.client_uuid, - "User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}", + "User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}", # type: ignore } logger.debug(f"正在发送心跳到服务器: {self.server_url}") - logger.debug(headers) + logger.debug(str(headers)) try: async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session: @@ -151,7 +151,7 @@ class TelemetryHeartBeatTask(AsyncTask): except Exception as e: import traceback - error_msg = str(e) if str(e) else "未知错误" + error_msg = str(e) or "未知错误" logger.warning(f"(此消息不会影响正常使用)状态未发生: {type(e).__name__}: {error_msg}") logger.debug(f"完整错误信息: {traceback.format_exc()}") diff --git a/src/config/auto_update.py b/src/config/auto_update.py index 2088e3628..139003a84 100644 --- a/src/config/auto_update.py +++ b/src/config/auto_update.py @@ -1,5 +1,6 @@ import shutil import tomlkit +from tomlkit.items import Table from pathlib import Path from datetime import datetime @@ -45,8 +46,8 @@ def update_config(): # 检查version是否相同 if old_config and "inner" in old_config and "inner" in new_config: - old_version = old_config["inner"].get("version") - new_version = new_config["inner"].get("version") + old_version = old_config["inner"].get("version") # type: ignore + new_version = new_config["inner"].get("version") # type: ignore if old_version and new_version and old_version == new_version: print(f"检测到版本号相同 (v{old_version}),跳过更新") # 如果version相同,恢复旧配置文件并返回 @@ -62,7 +63,7 @@ def update_config(): if key == "version": continue if key in target: - if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)): + if isinstance(value, dict) and isinstance(target[key], (dict, Table)): update_dict(target[key], value) else: try: @@ -85,10 +86,7 @@ def update_config(): if value and isinstance(value[0], dict) and "regex" in value[0]: contains_regex = True - if contains_regex: - target[key] = value - else: - target[key] = tomlkit.array(value) + target[key] = value if contains_regex else tomlkit.array(str(value)) else: # 其他类型使用item方法创建新值 target[key] = tomlkit.item(value) diff --git a/src/config/config.py b/src/config/config.py index de173a520..b61111ec3 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,16 +1,14 @@ import os -from dataclasses import field, dataclass - import tomlkit import shutil -from datetime import datetime +from datetime import datetime from tomlkit import TOMLDocument from tomlkit.items import Table - -from src.common.logger import get_logger +from dataclasses import field, dataclass from rich.traceback import install +from src.common.logger import get_logger from src.config.config_base import ConfigBase from src.config.official_configs import ( BotConfig, @@ -80,8 +78,8 @@ def update_config(): # 检查version是否相同 if old_config and "inner" in old_config and "inner" in new_config: - old_version = old_config["inner"].get("version") - new_version = new_config["inner"].get("version") + old_version = old_config["inner"].get("version") # type: ignore + new_version = new_config["inner"].get("version") # type: ignore if old_version and new_version and old_version == new_version: logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") return @@ -103,7 +101,7 @@ def update_config(): shutil.copy2(template_path, new_config_path) logger.info(f"已创建新配置文件: {new_config_path}") - def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict): + def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): """ 将source字典的值更新到target字典中(如果target中存在相同的键) """ @@ -112,8 +110,9 @@ def update_config(): if key == "version": continue if key in target: - if isinstance(value, dict) and isinstance(target[key], (dict, Table)): - update_dict(target[key], value) + target_value = target[key] + if isinstance(value, dict) and isinstance(target_value, (dict, Table)): + update_dict(target_value, value) else: try: # 对数组类型进行特殊处理 diff --git a/src/config/config_base.py b/src/config/config_base.py index 129f5a1c0..5fb398190 100644 --- a/src/config/config_base.py +++ b/src/config/config_base.py @@ -43,7 +43,7 @@ class ConfigBase: field_type = f.type try: - init_args[field_name] = cls._convert_field(value, field_type) + init_args[field_name] = cls._convert_field(value, field_type) # type: ignore except TypeError as e: raise TypeError(f"Field '{field_name}' has a type error: {e}") from e except Exception as e: diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 7e2efbeba..6838df1d1 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1,7 +1,8 @@ -from dataclasses import dataclass, field -from typing import Any, Literal import re +from dataclasses import dataclass, field +from typing import Any, Literal, Optional + from src.config.config_base import ConfigBase """ @@ -113,7 +114,7 @@ class ChatConfig(ConfigBase): exit_focus_threshold: float = 1.0 """自动退出专注聊天的阈值,越低越容易退出专注聊天""" - def get_current_talk_frequency(self, chat_stream_id: str = None) -> float: + def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float: """ 根据当前时间和聊天流获取对应的 talk_frequency @@ -138,7 +139,7 @@ class ChatConfig(ConfigBase): # 如果都没有匹配,返回默认值 return self.talk_frequency - def _get_time_based_frequency(self, time_freq_list: list[str]) -> float: + def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]: """ 根据时间配置列表获取当前时段的频率 @@ -186,7 +187,7 @@ class ChatConfig(ConfigBase): return current_frequency - def _get_stream_specific_frequency(self, chat_stream_id: str) -> float: + def _get_stream_specific_frequency(self, chat_stream_id: str): """ 获取特定聊天流在当前时间的频率 @@ -217,7 +218,7 @@ class ChatConfig(ConfigBase): return None - def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> str: + def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]: """ 解析流配置字符串并生成对应的 chat_id diff --git a/src/individuality/identity.py b/src/individuality/identity.py index bb3125985..730615e3d 100644 --- a/src/individuality/identity.py +++ b/src/individuality/identity.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List +from typing import List, Optional @dataclass @@ -8,7 +8,7 @@ class Identity: identity_detail: List[str] # 身份细节描述 - def __init__(self, identity_detail: List[str] = None): + def __init__(self, identity_detail: Optional[List[str]] = None): """初始化身份特征 Args: diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index 8365c0888..532b203fd 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -1,17 +1,18 @@ -from typing import Optional import ast - -from src.llm_models.utils_model import LLMRequest -from .personality import Personality -from .identity import Identity import random import json import os import hashlib + +from typing import Optional from rich.traceback import install + from src.common.logger import get_logger -from src.person_info.person_info import get_person_info_manager from src.config.config import global_config +from src.llm_models.utils_model import LLMRequest +from src.person_info.person_info import get_person_info_manager +from .personality import Personality +from .identity import Identity install(extra_lines=3) @@ -23,7 +24,7 @@ class Individuality: def __init__(self): # 正常初始化实例属性 - self.personality: Optional[Personality] = None + self.personality: Personality = None # type: ignore self.identity: Optional[Identity] = None self.name = "" @@ -109,7 +110,7 @@ class Individuality: existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression") if existing_short_impression: try: - existing_data = ast.literal_eval(existing_short_impression) + existing_data = ast.literal_eval(existing_short_impression) # type: ignore if isinstance(existing_data, list) and len(existing_data) >= 1: personality_result = existing_data[0] except (json.JSONDecodeError, TypeError, IndexError): @@ -128,7 +129,7 @@ class Individuality: existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression") if existing_short_impression: try: - existing_data = ast.literal_eval(existing_short_impression) + existing_data = ast.literal_eval(existing_short_impression) # type: ignore if isinstance(existing_data, list) and len(existing_data) >= 2: identity_result = existing_data[1] except (json.JSONDecodeError, TypeError, IndexError): @@ -204,6 +205,7 @@ class Individuality: return prompt_personality def get_identity_prompt(self, level: int, x_person: int = 2) -> str: + # sourcery skip: assign-if-exp, merge-else-if-into-elif """ 获取身份特征的prompt @@ -240,13 +242,13 @@ class Individuality: if identity_parts: details_str = ",".join(identity_parts) - if x_person in [1, 2]: + if x_person in {1, 2}: return f"{i_pronoun},{details_str}。" else: # x_person == 0 # 无人称时,直接返回细节,不加代词和开头的逗号 return f"{details_str}。" else: - if x_person in [1, 2]: + if x_person in {1, 2}: return f"{i_pronoun}的身份信息不完整。" else: # x_person == 0 return "身份信息不完整。" @@ -441,14 +443,15 @@ class Individuality: if info_list_json: try: info_list = json.loads(info_list_json) if isinstance(info_list_json, str) else info_list_json - for item in info_list: - if isinstance(item, dict) and "info_type" in item: - keywords.append(item["info_type"]) + keywords.extend( + item["info_type"] for item in info_list if isinstance(item, dict) and "info_type" in item + ) except (json.JSONDecodeError, TypeError): logger.error(f"解析info_list失败: {info_list_json}") return keywords async def _create_personality(self, personality_core: str, personality_sides: list) -> str: + # sourcery skip: merge-list-append, move-assign """使用LLM创建压缩版本的impression Args: diff --git a/src/individuality/personality.py b/src/individuality/personality.py index 0ee46a3d0..ace719331 100644 --- a/src/individuality/personality.py +++ b/src/individuality/personality.py @@ -1,6 +1,7 @@ -from dataclasses import dataclass -from typing import Dict, List import json + +from dataclasses import dataclass +from typing import Dict, List, Optional from pathlib import Path @@ -24,7 +25,7 @@ class Personality: cls._instance = super().__new__(cls) return cls._instance - def __init__(self, personality_core: str = "", personality_sides: List[str] = None): + def __init__(self, personality_core: str = "", personality_sides: Optional[List[str]] = None): if personality_sides is None: personality_sides = [] self.personality_core = personality_core @@ -41,7 +42,7 @@ class Personality: cls._instance = cls() return cls._instance - def _init_big_five_personality(self): + def _init_big_five_personality(self): # sourcery skip: extract-method """初始化大五人格特质""" # 构建文件路径 personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per" @@ -63,7 +64,6 @@ class Personality: else: self.extraversion = 0.3 self.neuroticism = 0.5 - if "认真" in self.personality_core or "负责" in self.personality_sides: self.conscientiousness = 0.9 else: diff --git a/src/manager/async_task_manager.py b/src/manager/async_task_manager.py index 1e1e9132f..0a2c0d215 100644 --- a/src/manager/async_task_manager.py +++ b/src/manager/async_task_manager.py @@ -120,12 +120,7 @@ class AsyncTaskManager: """ 获取所有任务的状态 """ - tasks_status = {} - for task_name, task in self.tasks.items(): - tasks_status[task_name] = { - "status": "running" if not task.done() else "done", - } - return tasks_status + return {task_name: {"status": "done" if task.done() else "running"} for task_name, task in self.tasks.items()} async def stop_and_wait_all_tasks(self): """ diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index ffdf8ff36..e3a66370b 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -2,12 +2,12 @@ import math import random import time -from src.chat.message_receive.message import MessageRecv -from src.llm_models.utils_model import LLMRequest -from ..common.logger import get_logger -from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive +from src.common.logger import get_logger from src.config.config import global_config +from src.chat.message_receive.message import MessageRecv from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive +from src.llm_models.utils_model import LLMRequest from src.manager.async_task_manager import AsyncTask, async_task_manager logger = get_logger("mood") @@ -55,12 +55,12 @@ class ChatMood: request_type="mood", ) - self.last_change_time = 0 + self.last_change_time: float = 0 async def update_mood_by_message(self, message: MessageRecv, interested_rate: float): self.regression_count = 0 - during_last_time = message.message_info.time - self.last_change_time + during_last_time = message.message_info.time - self.last_change_time # type: ignore base_probability = 0.05 time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time)) @@ -78,7 +78,7 @@ class ChatMood: if random.random() > update_probability: return - message_time = message.message_info.time + message_time: float = message.message_info.time # type: ignore message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, @@ -119,7 +119,7 @@ class ChatMood: self.mood_state = response - self.last_change_time = message_time + self.last_change_time = message_time # type: ignore async def regress_mood(self): message_time = time.time() diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index f44a88225..5e5f033f9 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -1,17 +1,18 @@ -from src.common.logger import get_logger -from src.common.database.database import db -from src.common.database.database_model import PersonInfo # 新增导入 import copy import hashlib -from typing import Any, Callable, Dict, Union import datetime import asyncio +import json + +from json_repair import repair_json +from typing import Any, Callable, Dict, Union, Optional + +from src.common.logger import get_logger +from src.common.database.database import db +from src.common.database.database_model import PersonInfo from src.llm_models.utils_model import LLMRequest from src.config.config import global_config -import json # 新增导入 -from json_repair import repair_json - """ PersonInfoManager 类方法功能摘要: @@ -42,7 +43,7 @@ person_info_default = { "last_know": None, # "user_cardname": None, # This field is not in Peewee model PersonInfo # "user_avatar": None, # This field is not in Peewee model PersonInfo - "impression": None, # Corrected from persion_impression + "impression": None, # Corrected from person_impression "short_impression": None, "info_list": None, "points": None, @@ -106,27 +107,24 @@ class PersonInfoManager: logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}") return False - def get_person_id_by_person_name(self, person_name: str): + def get_person_id_by_person_name(self, person_name: str) -> str: """根据用户名获取用户ID""" try: record = PersonInfo.get_or_none(PersonInfo.person_name == person_name) - if record: - return record.person_id - else: - return "" + return record.person_id if record else "" except Exception as e: logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}") return "" @staticmethod - async def create_person_info(person_id: str, data: dict = None): + async def create_person_info(person_id: str, data: Optional[dict] = None): """创建一个项""" if not person_id: - logger.debug("创建失败,personid不存在") + logger.debug("创建失败,person_id不存在") return _person_info_default = copy.deepcopy(person_info_default) - model_fields = PersonInfo._meta.fields.keys() + model_fields = PersonInfo._meta.fields.keys() # type: ignore final_data = {"person_id": person_id} @@ -163,9 +161,9 @@ class PersonInfoManager: await asyncio.to_thread(_db_create_sync, final_data) - async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None): + async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None): """更新某一个字段,会补全""" - if field_name not in PersonInfo._meta.fields: + if field_name not in PersonInfo._meta.fields: # type: ignore logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。") return @@ -228,15 +226,13 @@ class PersonInfoManager: @staticmethod async def has_one_field(person_id: str, field_name: str): """判断是否存在某一个字段""" - if field_name not in PersonInfo._meta.fields: + if field_name not in PersonInfo._meta.fields: # type: ignore logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。") return False def _db_has_field_sync(p_id: str, f_name: str): record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) - if record: - return True - return False + return bool(record) try: return await asyncio.to_thread(_db_has_field_sync, person_id, field_name) @@ -435,9 +431,7 @@ class PersonInfoManager: except Exception as e: logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}") # Fallback to default in case of any error during DB access - if field_name in person_info_default: - return default_value_for_field - return None + return default_value_for_field if field_name in person_info_default else None @staticmethod def get_value_sync(person_id: str, field_name: str): @@ -446,8 +440,7 @@ class PersonInfoManager: if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None: default_value_for_field = [] - record = PersonInfo.get_or_none(PersonInfo.person_id == person_id) - if record: + if record := PersonInfo.get_or_none(PersonInfo.person_id == person_id): val = getattr(record, field_name, None) if field_name in JSON_SERIALIZED_FIELDS: if isinstance(val, str): @@ -481,7 +474,7 @@ class PersonInfoManager: record = await asyncio.to_thread(_db_get_record_sync, person_id) for field_name in field_names: - if field_name not in PersonInfo._meta.fields: + if field_name not in PersonInfo._meta.fields: # type: ignore if field_name in person_info_default: result[field_name] = copy.deepcopy(person_info_default[field_name]) logger.debug(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。") @@ -509,7 +502,7 @@ class PersonInfoManager: """ 获取满足条件的字段值字典 """ - if field_name not in PersonInfo._meta.fields: + if field_name not in PersonInfo._meta.fields: # type: ignore logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义") return {} @@ -531,7 +524,7 @@ class PersonInfoManager: return {} async def get_or_create_person( - self, platform: str, user_id: int, nickname: str = None, user_cardname: str = None, user_avatar: str = None + self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None ) -> str: """ 根据 platform 和 user_id 获取 person_id。 @@ -561,7 +554,7 @@ class PersonInfoManager: "points": [], "forgotten_points": [], } - model_fields = PersonInfo._meta.fields.keys() + model_fields = PersonInfo._meta.fields.keys() # type: ignore filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} await self.create_person_info(person_id, data=filtered_initial_data) @@ -610,7 +603,9 @@ class PersonInfoManager: "name_reason", ] valid_fields_to_get = [ - f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default + f + for f in required_fields + if f in PersonInfo._meta.fields or f in person_info_default # type: ignore ] person_data = await self.get_values(found_person_id, valid_fields_to_get) diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index 0b443850f..7b69b47bb 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -3,12 +3,12 @@ import traceback import os import pickle import random -from typing import List, Dict +from typing import List, Dict, Any from src.config.config import global_config from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import get_chat_manager from src.person_info.relationship_manager import get_relationship_manager from src.person_info.person_info import get_person_info_manager, PersonInfoManager +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import ( get_raw_msg_by_timestamp_with_chat, get_raw_msg_by_timestamp_with_chat_inclusive, @@ -45,7 +45,7 @@ class RelationshipBuilder: self.chat_id = chat_id # 新的消息段缓存结构: # {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]} - self.person_engaged_cache: Dict[str, List[Dict[str, any]]] = {} + self.person_engaged_cache: Dict[str, List[Dict[str, Any]]] = {} # 持久化存储文件路径 self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl") @@ -210,11 +210,7 @@ class RelationshipBuilder: if person_id not in self.person_engaged_cache: return 0 - total_count = 0 - for segment in self.person_engaged_cache[person_id]: - total_count += segment["message_count"] - - return total_count + return sum(segment["message_count"] for segment in self.person_engaged_cache[person_id]) def _cleanup_old_segments(self) -> bool: """清理老旧的消息段""" @@ -289,7 +285,7 @@ class RelationshipBuilder: self.last_cleanup_time = current_time # 保存缓存 - if cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0: + if cleanup_stats["segments_removed"] > 0 or users_to_remove: self._save_cache() logger.info( f"{self.log_prefix} 清理完成 - 影响用户: {cleanup_stats['users_cleaned']}, 移除消息段: {cleanup_stats['segments_removed']}, 移除用户: {len(users_to_remove)}" @@ -313,6 +309,7 @@ class RelationshipBuilder: return False def get_cache_status(self) -> str: + # sourcery skip: merge-list-append, merge-list-appends-into-extend """获取缓存状态信息,用于调试和监控""" if not self.person_engaged_cache: return f"{self.log_prefix} 关系缓存为空" @@ -357,13 +354,12 @@ class RelationshipBuilder: self._cleanup_old_segments() current_time = time.time() - latest_messages = get_raw_msg_by_timestamp_with_chat( + if latest_messages := get_raw_msg_by_timestamp_with_chat( self.chat_id, self.last_processed_message_time, current_time, limit=50, # 获取自上次处理后的消息 - ) - if latest_messages: + ): # 处理所有新的非bot消息 for latest_msg in latest_messages: user_id = latest_msg.get("user_id") @@ -414,7 +410,7 @@ class RelationshipBuilder: # 负责触发关系构建、整合消息段、更新用户印象 # ================================ - async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, any]]): + async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]): """基于消息段更新用户印象""" original_segment_count = len(segments) logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象") diff --git a/src/person_info/relationship_builder_manager.py b/src/person_info/relationship_builder_manager.py index 926d67fca..f3bca25d2 100644 --- a/src/person_info/relationship_builder_manager.py +++ b/src/person_info/relationship_builder_manager.py @@ -1,4 +1,5 @@ -from typing import Dict, Optional, List +from typing import Dict, Optional, List, Any + from src.common.logger import get_logger from .relationship_builder import RelationshipBuilder @@ -63,7 +64,7 @@ class RelationshipBuilderManager: """ return list(self.builders.keys()) - def get_status(self) -> Dict[str, any]: + def get_status(self) -> Dict[str, Any]: """获取管理器状态 Returns: @@ -94,9 +95,7 @@ class RelationshipBuilderManager: bool: 是否成功清理 """ builder = self.get_builder(chat_id) - if builder: - return builder.force_cleanup_user_segments(person_id) - return False + return builder.force_cleanup_user_segments(person_id) if builder else False # 全局管理器实例 diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 65be0b3af..5e369e752 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -1,16 +1,19 @@ -from src.config.config import global_config -from src.llm_models.utils_model import LLMRequest import time import traceback -from src.common.logger import get_logger -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.person_info.person_info import get_person_info_manager -from typing import List, Dict -from json_repair import repair_json -from src.chat.message_receive.chat_stream import get_chat_manager import json import random +from typing import List, Dict, Any +from json_repair import repair_json + +from src.common.logger import get_logger +from src.config.config import global_config +from src.llm_models.utils_model import LLMRequest +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.message_receive.chat_stream import get_chat_manager +from src.person_info.person_info import get_person_info_manager + + logger = get_logger("relationship_fetcher") @@ -62,11 +65,11 @@ class RelationshipFetcher: self.chat_id = chat_id # 信息获取缓存:记录正在获取的信息请求 - self.info_fetching_cache: List[Dict[str, any]] = [] + self.info_fetching_cache: List[Dict[str, Any]] = [] # 信息结果缓存:存储已获取的信息结果,带TTL - self.info_fetched_cache: Dict[str, Dict[str, any]] = {} - # 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknow": bool}}} + self.info_fetched_cache: Dict[str, Dict[str, Any]] = {} + # 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}} # LLM模型配置 self.llm_model = LLMRequest( @@ -184,7 +187,7 @@ class RelationshipFetcher: nickname_str = ",".join(global_config.bot.alias_names) name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。" person_info_manager = get_person_info_manager() - person_name = await person_info_manager.get_value(person_id, "person_name") + person_name: str = await person_info_manager.get_value(person_id, "person_name") # type: ignore info_cache_block = self._build_info_cache_block() @@ -208,8 +211,7 @@ class RelationshipFetcher: logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息:{content_json.get('none', '')}") return None - info_type = content_json.get("info_type") - if info_type: + if info_type := content_json.get("info_type"): # 记录信息获取请求 self.info_fetching_cache.append( { @@ -287,7 +289,7 @@ class RelationshipFetcher: "ttl": 2, "start_time": start_time, "person_name": person_name, - "unknow": cached_info == "none", + "unknown": cached_info == "none", } logger.info(f"{self.log_prefix} 记得 {person_name} 的 {info_type}: {cached_info}") return @@ -321,7 +323,7 @@ class RelationshipFetcher: "ttl": 2, "start_time": start_time, "person_name": person_name, - "unknow": True, + "unknown": True, } logger.info(f"{self.log_prefix} 完全不认识 {person_name}") await self._save_info_to_cache(person_id, info_type, "none") @@ -353,15 +355,15 @@ class RelationshipFetcher: if person_id not in self.info_fetched_cache: self.info_fetched_cache[person_id] = {} self.info_fetched_cache[person_id][info_type] = { - "info": "unknow" if is_unknown else info_content, + "info": "unknown" if is_unknown else info_content, "ttl": 3, "start_time": start_time, "person_name": person_name, - "unknow": is_unknown, + "unknown": is_unknown, } # 保存到持久化缓存 (info_list) - await self._save_info_to_cache(person_id, info_type, info_content if not is_unknown else "none") + await self._save_info_to_cache(person_id, info_type, "none" if is_unknown else info_content) if not is_unknown: logger.info(f"{self.log_prefix} 思考得到,{person_name} 的 {info_type}: {info_content}") @@ -393,7 +395,7 @@ class RelationshipFetcher: for info_type in self.info_fetched_cache[person_id]: person_name = self.info_fetched_cache[person_id][info_type]["person_name"] - if not self.info_fetched_cache[person_id][info_type]["unknow"]: + if not self.info_fetched_cache[person_id][info_type]["unknown"]: info_content = self.info_fetched_cache[person_id][info_type]["info"] person_known_infos.append(f"[{info_type}]:{info_content}") else: @@ -430,6 +432,7 @@ class RelationshipFetcher: return persons_infos_str async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str): + # sourcery skip: use-next """将提取到的信息保存到 person_info 的 info_list 字段中 Args: diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index 039197250..2c544fe46 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -1,5 +1,5 @@ from src.common.logger import get_logger -from src.person_info.person_info import PersonInfoManager, get_person_info_manager +from .person_info import PersonInfoManager, get_person_info_manager import time import random from src.llm_models.utils_model import LLMRequest @@ -12,7 +12,7 @@ from difflib import SequenceMatcher import jieba from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity - +from typing import List, Dict, Any logger = get_logger("relation") @@ -28,8 +28,7 @@ class RelationshipManager: async def is_known_some_one(platform, user_id): """判断是否认识某人""" person_info_manager = get_person_info_manager() - is_known = await person_info_manager.is_person_known(platform, user_id) - return is_known + return await person_info_manager.is_person_known(platform, user_id) @staticmethod async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str): @@ -110,7 +109,7 @@ class RelationshipManager: return relation_prompt - async def update_person_impression(self, person_id, timestamp, bot_engaged_messages=None): + async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]): """更新用户印象 Args: @@ -123,7 +122,7 @@ class RelationshipManager: person_info_manager = get_person_info_manager() person_name = await person_info_manager.get_value(person_id, "person_name") nickname = await person_info_manager.get_value(person_id, "nickname") - know_times = await person_info_manager.get_value(person_id, "know_times") or 0 + know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore alias_str = ", ".join(global_config.bot.alias_names) # personality_block =get_individuality().get_personality_prompt(x_person=2, level=2) @@ -142,13 +141,13 @@ class RelationshipManager: # 遍历消息,构建映射 for msg in user_messages: await person_info_manager.get_or_create_person( - platform=msg.get("chat_info_platform"), - user_id=msg.get("user_id"), - nickname=msg.get("user_nickname"), - user_cardname=msg.get("user_cardname"), + platform=msg.get("chat_info_platform"), # type: ignore + user_id=msg.get("user_id"), # type: ignore + nickname=msg.get("user_nickname"), # type: ignore + user_cardname=msg.get("user_cardname"), # type: ignore ) - replace_user_id = msg.get("user_id") - replace_platform = msg.get("chat_info_platform") + replace_user_id: str = msg.get("user_id") # type: ignore + replace_platform: str = msg.get("chat_info_platform") # type: ignore replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id) replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name") @@ -354,8 +353,8 @@ class RelationshipManager: person_name = await person_info_manager.get_value(person_id, "person_name") nickname = await person_info_manager.get_value(person_id, "nickname") - know_times = await person_info_manager.get_value(person_id, "know_times") or 0 - attitude = await person_info_manager.get_value(person_id, "attitude") or 50 + know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore + attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore # 根据熟悉度,调整印象和简短印象的最大长度 if know_times > 300: @@ -414,16 +413,14 @@ class RelationshipManager: if len(remaining_points) < 10: # 如果还没达到30条,直接保留 remaining_points.append(point) + elif random.random() < keep_probability: + # 保留这个点,随机移除一个已保留的点 + idx_to_remove = random.randrange(len(remaining_points)) + points_to_move.append(remaining_points[idx_to_remove]) + remaining_points[idx_to_remove] = point else: - # 随机决定是否保留 - if random.random() < keep_probability: - # 保留这个点,随机移除一个已保留的点 - idx_to_remove = random.randrange(len(remaining_points)) - points_to_move.append(remaining_points[idx_to_remove]) - remaining_points[idx_to_remove] = point - else: - # 不保留这个点 - points_to_move.append(point) + # 不保留这个点 + points_to_move.append(point) # 更新points和forgotten_points current_points = remaining_points @@ -520,7 +517,7 @@ class RelationshipManager: new_attitude = int(relation_value_json.get("attitude", 50)) # 获取当前的关系值 - old_attitude = await person_info_manager.get_value(person_id, "attitude") or 50 + old_attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore # 更新熟悉度 if new_attitude > 25: diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index c341e5214..6c8cc01da 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -65,9 +65,9 @@ def get_replyer( async def generate_reply( - chat_stream=None, - chat_id: str = None, - action_data: Dict[str, Any] = None, + chat_stream: Optional[ChatStream] = None, + chat_id: Optional[str] = None, + action_data: Optional[Dict[str, Any]] = None, reply_to: str = "", extra_info: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, @@ -78,25 +78,25 @@ async def generate_reply( model_configs: Optional[List[Dict[str, Any]]] = None, request_type: str = "", enable_timeout: bool = False, -) -> Tuple[bool, List[Tuple[str, Any]]]: +) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: """生成回复 Args: chat_stream: 聊天流对象(优先) - action_data: 动作数据 chat_id: 聊天ID(备用) + action_data: 动作数据 enable_splitter: 是否启用消息分割器 enable_chinese_typo: 是否启用错字生成器 return_prompt: 是否返回提示词 Returns: - Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合) + Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词) """ try: # 获取回复器 replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type) if not replyer: logger.error("[GeneratorAPI] 无法获取回复器") - return False, [] + return False, [], None logger.debug("[GeneratorAPI] 开始生成回复") @@ -109,8 +109,9 @@ async def generate_reply( enable_timeout=enable_timeout, enable_tool=enable_tool, ) - - reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo) + reply_set = [] + if content: + reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo) if success: logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项") @@ -118,19 +119,19 @@ async def generate_reply( logger.warning("[GeneratorAPI] 回复生成失败") if return_prompt: - return success, reply_set or [], prompt + return success, reply_set, prompt else: - return success, reply_set or [] + return success, reply_set, None except Exception as e: logger.error(f"[GeneratorAPI] 生成回复时出错: {e}") - return False, [] + return False, [], None async def rewrite_reply( - chat_stream=None, - reply_data: Dict[str, Any] = None, - chat_id: str = None, + chat_stream: Optional[ChatStream] = None, + reply_data: Optional[Dict[str, Any]] = None, + chat_id: Optional[str] = None, enable_splitter: bool = True, enable_chinese_typo: bool = True, model_configs: Optional[List[Dict[str, Any]]] = None, @@ -158,15 +159,16 @@ async def rewrite_reply( # 调用回复器重写回复 success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {}) - - reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo) + reply_set = [] + if content: + reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo) if success: logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项") else: logger.warning("[GeneratorAPI] 重写回复失败") - return success, reply_set or [] + return success, reply_set except Exception as e: logger.error(f"[GeneratorAPI] 重写回复时出错: {e}") diff --git a/src/tools/tool_executor.py b/src/tools/tool_executor.py index 29ee8be1b..403ed554f 100644 --- a/src/tools/tool_executor.py +++ b/src/tools/tool_executor.py @@ -34,7 +34,7 @@ class ToolExecutor: 可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。 """ - def __init__(self, chat_id: str = None, enable_cache: bool = True, cache_ttl: int = 3): + def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3): """初始化工具执行器 Args: @@ -62,8 +62,8 @@ class ToolExecutor: logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}") async def execute_from_chat_message( - self, target_message: str, chat_history: list[str], sender: str, return_details: bool = False - ) -> List[Dict] | Tuple[List[Dict], List[str], str]: + self, target_message: str, chat_history: str, sender: str, return_details: bool = False + ) -> Tuple[List[Dict], List[str], str]: """从聊天消息执行工具 Args: @@ -79,16 +79,14 @@ class ToolExecutor: # 首先检查缓存 cache_key = self._generate_cache_key(target_message, chat_history, sender) - cached_result = self._get_from_cache(cache_key) - - if cached_result: + if cached_result := self._get_from_cache(cache_key): logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行") - if return_details: - # 从缓存结果中提取工具名称 - used_tools = [result.get("tool_name", "unknown") for result in cached_result] - return cached_result, used_tools, "使用缓存结果" - else: - return cached_result + if not return_details: + return cached_result, [], "使用缓存结果" + + # 从缓存结果中提取工具名称 + used_tools = [result.get("tool_name", "unknown") for result in cached_result] + return cached_result, used_tools, "使用缓存结果" # 缓存未命中,执行工具调用 # 获取可用工具 @@ -134,7 +132,7 @@ class ToolExecutor: if return_details: return tool_results, used_tools, prompt else: - return tool_results + return tool_results, [], "" async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]: """执行工具调用 @@ -207,7 +205,7 @@ class ToolExecutor: return tool_results, used_tools - def _generate_cache_key(self, target_message: str, chat_history: list[str], sender: str) -> str: + def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str: """生成缓存键 Args: @@ -267,10 +265,7 @@ class ToolExecutor: return expired_keys = [] - for cache_key, cache_item in self.tool_cache.items(): - if cache_item["ttl"] <= 0: - expired_keys.append(cache_key) - + expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0) for key in expired_keys: del self.tool_cache[key] @@ -355,7 +350,7 @@ class ToolExecutor: "ttl_distribution": ttl_distribution, } - def set_cache_config(self, enable_cache: bool = None, cache_ttl: int = None): + def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1): """动态修改缓存配置 Args: @@ -366,7 +361,7 @@ class ToolExecutor: self.enable_cache = enable_cache logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}") - if cache_ttl is not None and cache_ttl > 0: + if cache_ttl > 0: self.cache_ttl = cache_ttl logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}") @@ -380,7 +375,7 @@ init_tool_executor_prompt() # 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3) executor = ToolExecutor(executor_id="my_executor") -results = await executor.execute_from_chat_message( +results, _, _ = await executor.execute_from_chat_message( talking_message_str="今天天气怎么样?现在几点了?", is_group_chat=False )