diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py
index 3ad3fe4cf..a9214a9af 100644
--- a/src/chat/replyer/default_generator.py
+++ b/src/chat/replyer/default_generator.py
@@ -4,6 +4,7 @@ import asyncio
import random
import ast
import re
+
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
@@ -161,13 +162,13 @@ class DefaultReplyer:
async def generate_reply_with_context(
self,
- reply_data: Dict[str, Any] = None,
+ reply_data: Optional[Dict[str, Any]] = None,
reply_to: str = "",
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_tool: bool = True,
enable_timeout: bool = False,
- ) -> Tuple[bool, Optional[str]]:
+ ) -> Tuple[bool, Optional[str], Optional[str]]:
"""
回复器 (Replier): 核心逻辑,负责生成回复文本。
(已整合原 HeartFCGenerator 的功能)
@@ -225,14 +226,14 @@ class DefaultReplyer:
except Exception as llm_e:
# 精简报错信息
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
- return False, None # LLM 调用失败则无法生成回复
+ return False, None, prompt # LLM 调用失败则无法生成回复
return True, content, prompt
except Exception as e:
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
traceback.print_exc()
- return False, None
+ return False, None, prompt
async def rewrite_reply_with_context(
self,
@@ -368,7 +369,7 @@ class DefaultReplyer:
memory_str += f"- {running_memory['content']}\n"
return memory_str
- async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True):
+ async def build_tool_info(self, chat_history, reply_data: Optional[Dict], enable_tool: bool = True):
"""构建工具信息块
Args:
@@ -393,7 +394,7 @@ class DefaultReplyer:
try:
# 使用工具执行器获取信息
- tool_results = await self.tool_executor.execute_from_chat_message(
+ tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
sender=sender, target_message=text, chat_history=chat_history, return_details=False
)
@@ -468,7 +469,7 @@ class DefaultReplyer:
async def build_prompt_reply_context(
self,
- reply_data=None,
+ reply_data: Dict[str, Any],
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_timeout: bool = False,
enable_tool: bool = True,
@@ -549,7 +550,7 @@ class DefaultReplyer:
),
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_half, target), "build_memory_block"),
self._time_and_run_task(
- self.build_tool_info(reply_data, chat_talking_prompt_half, enable_tool=enable_tool), "build_tool_info"
+ self.build_tool_info(chat_talking_prompt_half, reply_data, enable_tool=enable_tool), "build_tool_info"
),
)
@@ -806,7 +807,7 @@ class DefaultReplyer:
response_set: List[Tuple[str, str]],
thinking_id: str = "",
display_message: str = "",
- ) -> Optional[MessageSending]:
+ ) -> Optional[List[Tuple[str, bool]]]:
# sourcery skip: assign-if-exp, boolean-if-exp-identity, remove-unnecessary-cast
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
chat = self.chat_stream
@@ -869,7 +870,7 @@ class DefaultReplyer:
try:
if (
bot_message.is_private_message()
- or bot_message.reply.processed_plain_text != "[System Trigger Context]"
+ or bot_message.reply.processed_plain_text != "[System Trigger Context]" # type: ignore
or mark_head
):
set_reply = False
@@ -910,7 +911,7 @@ class DefaultReplyer:
is_emoji: bool,
thinking_start_time: float,
display_message: str,
- anchor_message: MessageRecv = None,
+ anchor_message: Optional[MessageRecv] = None,
) -> MessageSending:
"""构建单个发送消息"""
diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py
index a2a2aaaa0..3f1c731b4 100644
--- a/src/chat/replyer/replyer_manager.py
+++ b/src/chat/replyer/replyer_manager.py
@@ -1,8 +1,8 @@
from typing import Dict, Any, Optional, List
+from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer
-from src.common.logger import get_logger
logger = get_logger("ReplyerManager")
diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py
index 8c579e6d3..6bdf7f58d 100644
--- a/src/chat/utils/chat_message_builder.py
+++ b/src/chat/utils/chat_message_builder.py
@@ -1,6 +1,7 @@
import time # 导入 time 模块以获取当前时间
import random
import re
+
from typing import List, Dict, Any, Tuple, Optional
from rich.traceback import install
@@ -88,8 +89,8 @@ def get_actions_by_timestamp_with_chat(
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
- & (ActionRecords.time > timestamp_start)
- & (ActionRecords.time < timestamp_end)
+ & (ActionRecords.time > timestamp_start) # type: ignore
+ & (ActionRecords.time < timestamp_end) # type: ignore
)
if limit > 0:
@@ -113,8 +114,8 @@ def get_actions_by_timestamp_with_chat_inclusive(
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
- & (ActionRecords.time >= timestamp_start)
- & (ActionRecords.time <= timestamp_end)
+ & (ActionRecords.time >= timestamp_start) # type: ignore
+ & (ActionRecords.time <= timestamp_end) # type: ignore
)
if limit > 0:
@@ -331,7 +332,7 @@ def _build_readable_messages_internal(
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
- person_name = person_info_manager.get_value_sync(person_id, "person_name")
+ person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
if not person_name:
@@ -911,8 +912,8 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
person_ids_set = set() # 使用集合来自动去重
for msg in messages:
- platform = msg.get("user_platform")
- user_id = msg.get("user_id")
+ platform: str = msg.get("user_platform") # type: ignore
+ user_id: str = msg.get("user_id") # type: ignore
# 检查必要信息是否存在 且 不是机器人自己
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
diff --git a/src/chat/utils/json_utils.py b/src/chat/utils/json_utils.py
index 6226e6e96..892deac4f 100644
--- a/src/chat/utils/json_utils.py
+++ b/src/chat/utils/json_utils.py
@@ -1,7 +1,8 @@
+import ast
import json
import logging
-from typing import Any, Dict, TypeVar, List, Union, Tuple
-import ast
+
+from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional
# 定义类型变量用于泛型类型提示
T = TypeVar("T")
@@ -30,18 +31,14 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
# 尝试标准的 JSON 解析
return json.loads(json_str)
except json.JSONDecodeError:
- # 如果标准解析失败,尝试将单引号替换为双引号再解析
- # (注意:这种替换可能不安全,如果字符串内容本身包含引号)
- # 更安全的方式是用 ast.literal_eval
+ # 如果标准解析失败,尝试用 ast.literal_eval 解析
try:
# logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...")
result = ast.literal_eval(json_str)
- # 确保结果是字典(因为我们通常期望参数是字典)
if isinstance(result, dict):
return result
- else:
- logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
- return default_value
+ logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
+ return default_value
except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...")
return default_value
@@ -53,7 +50,9 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
return default_value
-def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]:
+def extract_tool_call_arguments(
+ tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None
+) -> Dict[str, Any]:
"""
从LLM工具调用对象中提取参数
@@ -77,14 +76,12 @@ def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[s
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
return default_result
- # 提取arguments
- arguments_str = function_data.get("arguments", "{}")
- if not arguments_str:
+ if arguments_str := function_data.get("arguments", "{}"):
+ # 解析JSON
+ return safe_json_loads(arguments_str, default_result)
+ else:
return default_result
- # 解析JSON
- return safe_json_loads(arguments_str, default_result)
-
except Exception as e:
logger.error(f"提取工具调用参数时出错: {e}")
return default_result
diff --git a/src/chat/utils/prompt_builder.py b/src/chat/utils/prompt_builder.py
index 26f8ffbad..1b107904c 100644
--- a/src/chat/utils/prompt_builder.py
+++ b/src/chat/utils/prompt_builder.py
@@ -1,12 +1,12 @@
-from typing import Dict, Any, Optional, List, Union
import re
-from contextlib import asynccontextmanager
import asyncio
import contextvars
-from src.common.logger import get_logger
-# import traceback
from rich.traceback import install
+from contextlib import asynccontextmanager
+from typing import Dict, Any, Optional, List, Union
+
+from src.common.logger import get_logger
install(extra_lines=3)
@@ -32,6 +32,7 @@ class PromptContext:
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
+ # sourcery skip: hoist-statement-from-if, use-contextlib-suppress
"""创建一个异步的临时提示模板作用域"""
# 保存当前上下文并设置新上下文
if context_id is not None:
@@ -88,8 +89,7 @@ class PromptContext:
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
- target_context = context_id or self._current_context
- if target_context:
+ if target_context := context_id or self._current_context:
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
@@ -151,7 +151,7 @@ class Prompt(str):
@staticmethod
def _process_escaped_braces(template) -> str:
- """处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记"""
+ """处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记""" # type: ignore
# 如果传入的是列表,将其转换为字符串
if isinstance(template, list):
template = "\n".join(str(item) for item in template)
@@ -195,14 +195,8 @@ class Prompt(str):
obj._kwargs = kwargs
# 修改自动注册逻辑
- if should_register:
- if global_prompt_manager._context._current_context:
- # 如果存在当前上下文,则注册到上下文中
- # asyncio.create_task(global_prompt_manager._context.register_async(obj))
- pass
- else:
- # 否则注册到全局管理器
- global_prompt_manager.register(obj)
+ if should_register and not global_prompt_manager._context._current_context:
+ global_prompt_manager.register(obj)
return obj
@classmethod
@@ -276,15 +270,13 @@ class Prompt(str):
self.name,
args=list(args) if args else self._args,
_should_register=False,
- **kwargs if kwargs else self._kwargs,
+ **kwargs or self._kwargs,
)
# print(f"prompt build result: {ret} name: {ret.name} ")
return str(ret)
def __str__(self) -> str:
- if self._kwargs or self._args:
- return super().__str__()
- return self.template
+ return super().__str__() if self._kwargs or self._args else self.template
def __repr__(self) -> str:
return f"Prompt(template='{self.template}', name='{self.name}')"
diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py
index 25d231c01..4e0edd31f 100644
--- a/src/chat/utils/statistic.py
+++ b/src/chat/utils/statistic.py
@@ -1,18 +1,17 @@
-from collections import defaultdict
-from datetime import datetime, timedelta
-from typing import Any, Dict, Tuple, List
import asyncio
import concurrent.futures
import json
import os
import glob
+from collections import defaultdict
+from datetime import datetime, timedelta
+from typing import Any, Dict, Tuple, List
from src.common.logger import get_logger
+from src.common.database.database import db
+from src.common.database.database_model import OnlineTime, LLMUsage, Messages
from src.manager.async_task_manager import AsyncTask
-
-from ...common.database.database import db # This db is the Peewee database instance
-from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic")
@@ -76,14 +75,14 @@ class OnlineTimeRecordTask(AsyncTask):
with db.atomic(): # Use atomic operations for schema changes
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
- async def run(self):
+ async def run(self): # sourcery skip: use-named-expression
try:
current_time = datetime.now()
extended_end_time = current_time + timedelta(minutes=1)
if self.record_id:
# 如果有记录,则更新结束时间
- query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
+ query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore
updated_rows = query.execute()
if updated_rows == 0:
# Record might have been deleted or ID is stale, try to find/create
@@ -94,7 +93,7 @@ class OnlineTimeRecordTask(AsyncTask):
# Look for a record whose end_timestamp is recent enough to be considered ongoing
recent_record = (
OnlineTime.select()
- .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
+ .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore
.order_by(OnlineTime.end_timestamp.desc())
.first()
)
@@ -123,15 +122,15 @@ def _format_online_time(online_seconds: int) -> str:
:param online_seconds: 在线时间(秒)
:return: 格式化后的在线时间字符串
"""
- total_oneline_time = timedelta(seconds=online_seconds)
+ total_online_time = timedelta(seconds=online_seconds)
- days = total_oneline_time.days
- hours = total_oneline_time.seconds // 3600
- minutes = (total_oneline_time.seconds // 60) % 60
- seconds = total_oneline_time.seconds % 60
+ days = total_online_time.days
+ hours = total_online_time.seconds // 3600
+ minutes = (total_online_time.seconds // 60) % 60
+ seconds = total_online_time.seconds % 60
if days > 0:
# 如果在线时间超过1天,则格式化为"X天X小时X分钟"
- return f"{total_oneline_time.days}天{hours}小时{minutes}分钟{seconds}秒"
+ return f"{total_online_time.days}天{hours}小时{minutes}分钟{seconds}秒"
elif hours > 0:
# 如果在线时间超过1小时,则格式化为"X小时X分钟X秒"
return f"{hours}小时{minutes}分钟{seconds}秒"
@@ -163,7 +162,7 @@ class StatisticOutputTask(AsyncTask):
now = datetime.now()
if "deploy_time" in local_storage:
# 如果存在部署时间,则使用该时间作为全量统计的起始时间
- deploy_time = datetime.fromtimestamp(local_storage["deploy_time"])
+ deploy_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
else:
# 否则,使用最大时间范围,并记录部署时间为当前时间
deploy_time = datetime(2000, 1, 1)
@@ -252,7 +251,7 @@ class StatisticOutputTask(AsyncTask):
# 创建后台任务,不等待完成
collect_task = asyncio.create_task(
- loop.run_in_executor(executor, self._collect_all_statistics, now)
+ loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
)
stats = await collect_task
@@ -260,8 +259,8 @@ class StatisticOutputTask(AsyncTask):
# 创建并发的输出任务
output_tasks = [
- asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),
- asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),
+ asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
+ asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
]
# 等待所有输出任务完成
@@ -320,7 +319,7 @@ class StatisticOutputTask(AsyncTask):
# 以最早的时间戳为起始时间获取记录
# Assuming LLMUsage.timestamp is a DateTimeField
query_start_time = collect_period[-1][1]
- for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
+ for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_timestamp = record.timestamp # This is already a datetime object
for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
@@ -388,7 +387,7 @@ class StatisticOutputTask(AsyncTask):
query_start_time = collect_period[-1][1]
# Assuming OnlineTime.end_timestamp is a DateTimeField
- for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
+ for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore
# record.end_timestamp and record.start_timestamp are datetime objects
record_end_timestamp = record.end_timestamp
record_start_timestamp = record.start_timestamp
@@ -428,7 +427,7 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
- for message in Messages.select().where(Messages.time >= query_start_timestamp):
+ for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time # This is a float timestamp
chat_id = None
@@ -661,7 +660,7 @@ class StatisticOutputTask(AsyncTask):
if "last_full_statistics" in local_storage:
# 如果存在上次完整统计数据,则使用该数据进行增量统计
- last_stat = local_storage["last_full_statistics"] # 上次完整统计数据
+ last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
@@ -727,6 +726,7 @@ class StatisticOutputTask(AsyncTask):
return stat
def _convert_defaultdict_to_dict(self, data):
+ # sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks
"""递归转换defaultdict为普通dict"""
if isinstance(data, defaultdict):
# 转换defaultdict为普通dict
@@ -812,8 +812,7 @@ class StatisticOutputTask(AsyncTask):
# 全局阶段平均时间
if stats[FOCUS_AVG_TIMES_BY_STAGE]:
output.append("全局阶段平均时间:")
- for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items():
- output.append(f" {stage}: {avg_time:.3f}秒")
+ output.extend(f" {stage}: {avg_time:.3f}秒" for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items())
output.append("")
# Action类型比例
@@ -1050,7 +1049,7 @@ class StatisticOutputTask(AsyncTask):
]
tab_content_list.append(
- _format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"]))
+ _format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"])) # type: ignore
)
# 添加Focus统计内容
@@ -1212,6 +1211,7 @@ class StatisticOutputTask(AsyncTask):
f.write(html_template)
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
+ # sourcery skip: for-append-to-extend, list-comprehension, use-any
"""生成Focus统计独立分页的HTML内容"""
# 为每个时间段准备Focus数据
@@ -1313,12 +1313,11 @@ class StatisticOutputTask(AsyncTask):
# 聊天流Action选择比例对比表(横向表格)
focus_chat_action_ratios_rows = ""
if stat_data.get("focus_action_ratios_by_chat"):
- # 获取所有action类型(按全局频率排序)
- all_action_types_for_ratio = sorted(
- stat_data[FOCUS_ACTION_RATIOS].keys(), key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x], reverse=True
- )
-
- if all_action_types_for_ratio:
+ if all_action_types_for_ratio := sorted(
+ stat_data[FOCUS_ACTION_RATIOS].keys(),
+ key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x],
+ reverse=True,
+ ):
# 为每个聊天流生成数据行(按循环数排序)
chat_ratio_rows = []
for chat_id in sorted(
@@ -1379,16 +1378,11 @@ class StatisticOutputTask(AsyncTask):
if period_name == "all_time":
from src.manager.local_store_manager import local_storage
- start_time = datetime.fromtimestamp(local_storage["deploy_time"])
- time_range = (
- f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
- )
+ start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
else:
start_time = datetime.now() - period_delta
- time_range = (
- f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
- )
+ time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# 生成该时间段的Focus统计HTML
section_html = f"""
@@ -1681,16 +1675,10 @@ class StatisticOutputTask(AsyncTask):
if period_name == "all_time":
from src.manager.local_store_manager import local_storage
- start_time = datetime.fromtimestamp(local_storage["deploy_time"])
- time_range = (
- f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
- )
+ start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
else:
start_time = datetime.now() - period_delta
- time_range = (
- f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
- )
-
+ time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# 生成该时间段的版本对比HTML
section_html = f"""
@@ -1865,7 +1853,7 @@ class StatisticOutputTask(AsyncTask):
# 查询LLM使用记录
query_start_time = start_time
- for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
+ for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_time = record.timestamp
# 找到对应的时间间隔索引
@@ -1875,7 +1863,7 @@ class StatisticOutputTask(AsyncTask):
if 0 <= interval_index < len(time_points):
# 累加总花费数据
cost = record.cost or 0.0
- total_cost_data[interval_index] += cost
+ total_cost_data[interval_index] += cost # type: ignore
# 累加按模型分类的花费
model_name = record.model_name or "unknown"
@@ -1892,7 +1880,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
- for message in Messages.select().where(Messages.time >= query_start_timestamp):
+ for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time
# 找到对应的时间间隔索引
@@ -1982,6 +1970,7 @@ class StatisticOutputTask(AsyncTask):
}
def _generate_chart_tab(self, chart_data: dict) -> str:
+ # sourcery skip: extract-duplicate-method, move-assign-in-block
"""生成图表选项卡HTML内容"""
# 生成不同颜色的调色板
@@ -2293,7 +2282,7 @@ class AsyncStatisticOutputTask(AsyncTask):
# 数据收集任务
collect_task = asyncio.create_task(
- loop.run_in_executor(executor, self._collect_all_statistics, now)
+ loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
)
stats = await collect_task
@@ -2301,8 +2290,8 @@ class AsyncStatisticOutputTask(AsyncTask):
# 创建并发的输出任务
output_tasks = [
- asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),
- asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),
+ asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
+ asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
]
# 等待所有输出任务完成
diff --git a/src/chat/utils/timer_calculator.py b/src/chat/utils/timer_calculator.py
index df2b9f778..d9479af16 100644
--- a/src/chat/utils/timer_calculator.py
+++ b/src/chat/utils/timer_calculator.py
@@ -1,7 +1,8 @@
+import asyncio
+
from time import perf_counter
from functools import wraps
from typing import Optional, Dict, Callable
-import asyncio
from rich.traceback import install
install(extra_lines=3)
@@ -88,10 +89,10 @@ class Timer:
self.name = name
self.storage = storage
- self.elapsed = None
+ self.elapsed: float = None # type: ignore
self.auto_unit = auto_unit
- self.start = None
+ self.start: float = None # type: ignore
@staticmethod
def _validate_types(name, storage):
@@ -120,7 +121,7 @@ class Timer:
return None
wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
- wrapper.__timer__ = self # 保留计时器引用
+ wrapper.__timer__ = self # 保留计时器引用 # type: ignore
return wrapper
def __enter__(self):
diff --git a/src/chat/utils/typo_generator.py b/src/chat/utils/typo_generator.py
index 7c373f132..4de219464 100644
--- a/src/chat/utils/typo_generator.py
+++ b/src/chat/utils/typo_generator.py
@@ -7,10 +7,10 @@ import math
import os
import random
import time
+import jieba
+
from collections import defaultdict
from pathlib import Path
-
-import jieba
from pypinyin import Style, pinyin
from src.common.logger import get_logger
@@ -104,7 +104,7 @@ class ChineseTypoGenerator:
try:
return "\u4e00" <= char <= "\u9fff"
except Exception as e:
- logger.debug(e)
+ logger.debug(str(e))
return False
def _get_pinyin(self, sentence):
@@ -138,7 +138,7 @@ class ChineseTypoGenerator:
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
if not py[-1].isdigit():
# 为非数字结尾的拼音添加数字声调1
- return py + "1"
+ return f"{py}1"
base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调
diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py
index f3226b2e1..2fbc69559 100644
--- a/src/chat/utils/utils.py
+++ b/src/chat/utils/utils.py
@@ -1,23 +1,21 @@
import random
import re
import time
-from collections import Counter
-
import jieba
import numpy as np
+
+from collections import Counter
from maim_message import UserInfo
+from typing import Optional, Tuple, Dict
from src.common.logger import get_logger
-
-# from src.mood.mood_manager import mood_manager
-from ..message_receive.message import MessageRecv
-from src.llm_models.utils_model import LLMRequest
-from .typo_generator import ChineseTypoGenerator
-from ...config.config import global_config
-from ...common.message_repository import find_messages, count_messages
-from typing import Optional, Tuple, Dict
+from src.common.message_repository import find_messages, count_messages
+from src.config.config import global_config
+from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
+from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
+from .typo_generator import ChineseTypoGenerator
logger = get_logger("chat_utils")
@@ -31,11 +29,7 @@ def db_message_to_str(message_dict: dict) -> str:
logger.debug(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try:
- name = "[(%s)%s]%s" % (
- message_dict["user_id"],
- message_dict.get("user_nickname", ""),
- message_dict.get("user_cardname", ""),
- )
+ name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
except Exception:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "")
@@ -58,11 +52,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
and message.message_info.additional_config.get("is_mentioned") is not None
):
try:
- reply_probability = float(message.message_info.additional_config.get("is_mentioned"))
+ reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
is_mentioned = True
return is_mentioned, reply_probability
except Exception as e:
- logger.warning(e)
+ logger.warning(str(e))
logger.warning(
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
)
@@ -135,20 +129,17 @@ def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, c
if not recent_messages:
return []
- message_detailed_plain_text = ""
- message_detailed_plain_text_list = []
-
# 反转消息列表,使最新的消息在最后
recent_messages.reverse()
if combine:
- for msg_db_data in recent_messages:
- message_detailed_plain_text += str(msg_db_data["detailed_plain_text"])
- return message_detailed_plain_text
- else:
- for msg_db_data in recent_messages:
- message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
- return message_detailed_plain_text_list
+ return "".join(str(msg_db_data["detailed_plain_text"]) for msg_db_data in recent_messages)
+
+ message_detailed_plain_text_list = []
+
+ for msg_db_data in recent_messages:
+ message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
+ return message_detailed_plain_text_list
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
@@ -204,10 +195,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
len_text = len(text)
if len_text < 3:
- if random.random() < 0.01:
- return list(text) # 如果文本很短且触发随机条件,直接按字符分割
- else:
- return [text]
+ return list(text) if random.random() < 0.01 else [text]
# 定义分隔符
separators = {",", ",", " ", "。", ";"}
@@ -352,10 +340,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
max_length = global_config.response_splitter.max_length * 2
max_sentence_num = global_config.response_splitter.max_sentence_num
# 如果基本上是中文,则进行长度过滤
- if get_western_ratio(cleaned_text) < 0.1:
- if len(cleaned_text) > max_length:
- logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
- return ["懒得说"]
+ if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length:
+ logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
+ return ["懒得说"]
typo_generator = ChineseTypoGenerator(
error_rate=global_config.chinese_typo.error_rate,
@@ -420,7 +407,7 @@ def calculate_typing_time(
# chinese_time *= 1 / typing_speed_multiplier
# english_time *= 1 / typing_speed_multiplier
# 计算中文字符数
- chinese_chars = sum(1 for char in input_string if "\u4e00" <= char <= "\u9fff")
+ chinese_chars = sum("\u4e00" <= char <= "\u9fff" for char in input_string)
# 如果只有一个中文字符,使用3倍时间
if chinese_chars == 1 and len(input_string.strip()) == 1:
@@ -429,11 +416,7 @@ def calculate_typing_time(
# 正常计算所有字符的输入时间
total_time = 0.0
for char in input_string:
- if "\u4e00" <= char <= "\u9fff": # 判断是否为中文字符
- total_time += chinese_time
- else: # 其他字符(如英文)
- total_time += english_time
-
+ total_time += chinese_time if "\u4e00" <= char <= "\u9fff" else english_time
if is_emoji:
total_time = 1
@@ -453,18 +436,14 @@ def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2)
- if norm1 == 0 or norm2 == 0:
- return 0
- return dot_product / (norm1 * norm2)
+ return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
def text_to_vector(text):
"""将文本转换为词频向量"""
# 分词
words = jieba.lcut(text)
- # 统计词频
- word_freq = Counter(words)
- return word_freq
+ return Counter(words)
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
@@ -491,9 +470,7 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
def truncate_message(message: str, max_length=20) -> str:
"""截断消息,使其不超过指定长度"""
- if len(message) > max_length:
- return message[:max_length] + "..."
- return message
+ return f"{message[:max_length]}..." if len(message) > max_length else message
def protect_kaomoji(sentence):
@@ -522,7 +499,7 @@ def protect_kaomoji(sentence):
placeholder_to_kaomoji = {}
for idx, match in enumerate(kaomoji_matches):
- kaomoji = match[0] if match[0] else match[1]
+ kaomoji = match[0] or match[1]
placeholder = f"__KAOMOJI_{idx}__"
sentence = sentence.replace(kaomoji, placeholder, 1)
placeholder_to_kaomoji[placeholder] = kaomoji
@@ -563,7 +540,7 @@ def get_western_ratio(paragraph):
if not alnum_chars:
return 0.0
- western_count = sum(1 for char in alnum_chars if is_english_letter(char))
+ western_count = sum(bool(is_english_letter(char)) for char in alnum_chars)
return western_count / len(alnum_chars)
@@ -610,6 +587,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
+ # sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
"""将时间戳转换为人类可读的时间格式
Args:
@@ -621,7 +599,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
"""
if mode == "normal":
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
- if mode == "normal_no_YMD":
+ elif mode == "normal_no_YMD":
return time.strftime("%H:%M:%S", time.localtime(timestamp))
elif mode == "relative":
now = time.time()
@@ -640,7 +618,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
else:
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":"
else: # mode = "lite" or unknown
- # 只返回时分秒格式,喵~
+ # 只返回时分秒格式
return time.strftime("%H:%M:%S", time.localtime(timestamp))
@@ -670,8 +648,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
elif chat_stream.user_info: # It's a private chat
is_group_chat = False
user_info = chat_stream.user_info
- platform = chat_stream.platform
- user_id = user_info.user_id
+ platform: str = chat_stream.platform # type: ignore
+ user_id: str = user_info.user_id # type: ignore
# Initialize target_info with basic info
target_info = {
diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py
index 5579ccf84..d5fa301bb 100644
--- a/src/chat/utils/utils_image.py
+++ b/src/chat/utils/utils_image.py
@@ -3,21 +3,20 @@ import os
import time
import hashlib
import uuid
+import io
+import asyncio
+import numpy as np
+
from typing import Optional, Tuple
from PIL import Image
-import io
-import numpy as np
-import asyncio
-
+from rich.traceback import install
+from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
-from src.common.logger import get_logger
-from rich.traceback import install
-
install(extra_lines=3)
logger = get_logger("chat_image")
@@ -111,7 +110,7 @@ class ImageManager:
return f"[表情包,含义看起来是:{cached_description}]"
# 调用AI获取描述
- if image_format == "gif" or image_format == "GIF":
+ if image_format in ["gif", "GIF"]:
image_base64_processed = self.transform_gif(image_base64)
if image_base64_processed is None:
logger.warning("GIF转换失败,无法获取描述")
@@ -258,6 +257,7 @@ class ImageManager:
@staticmethod
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
+ # sourcery skip: use-contextlib-suppress
"""将GIF转换为水平拼接的静态图像, 跳过相似的帧
Args:
@@ -351,7 +351,7 @@ class ImageManager:
# 创建拼接图像
total_width = target_width * len(resized_frames)
# 防止总宽度为0
- if total_width == 0 and len(resized_frames) > 0:
+ if total_width == 0 and resized_frames:
logger.warning("计算出的总宽度为0,但有选中帧,可能目标宽度太小")
# 至少给点宽度吧
total_width = len(resized_frames)
@@ -368,10 +368,7 @@ class ImageManager:
# 转换为base64
buffer = io.BytesIO()
combined_image.save(buffer, format="JPEG", quality=85) # 保存为JPEG
- result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-
- return result_base64
-
+ return base64.b64encode(buffer.getvalue()).decode("utf-8")
except MemoryError:
logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多")
return None # 内存不够啦
@@ -380,6 +377,7 @@ class ImageManager:
return None # 其他错误也返回None
async def process_image(self, image_base64: str) -> Tuple[str, str]:
+ # sourcery skip: hoist-if-from-if
"""处理图片并返回图片ID和描述
Args:
@@ -418,17 +416,9 @@ class ImageManager:
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
- existing_image.count += 1
- existing_image.save()
- return existing_image.image_id, f"[picid:{existing_image.image_id}]"
- else:
- # print(f"图片已存在: {existing_image.image_id}")
- # print(f"图片描述: {existing_image.description}")
- # print(f"图片计数: {existing_image.count}")
- # 更新计数
- existing_image.count += 1
- existing_image.save()
- return existing_image.image_id, f"[picid:{existing_image.image_id}]"
+ existing_image.count += 1
+ existing_image.save()
+ return existing_image.image_id, f"[picid:{existing_image.image_id}]"
else:
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())
diff --git a/src/common/database/database.py b/src/common/database/database.py
index 249664155..ca3614816 100644
--- a/src/common/database/database.py
+++ b/src/common/database/database.py
@@ -54,11 +54,11 @@ class DBWrapper:
return getattr(get_db(), name)
def __getitem__(self, key):
- return get_db()[key]
+ return get_db()[key] # type: ignore
# 全局数据库访问点
-memory_db: Database = DBWrapper()
+memory_db: Database = DBWrapper() # type: ignore
# 定义数据库文件路径
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py
index 3485fedeb..b411e1b3a 100644
--- a/src/common/database/database_model.py
+++ b/src/common/database/database_model.py
@@ -406,9 +406,7 @@ def initialize_database():
existing_columns = {row[1] for row in cursor.fetchall()}
model_fields = set(model._meta.fields.keys())
- # 检查并添加缺失字段(原有逻辑)
- missing_fields = model_fields - existing_columns
- if missing_fields:
+ if missing_fields := model_fields - existing_columns:
logger.warning(f"表 '{table_name}' 缺失字段: {missing_fields}")
for field_name, field_obj in model._meta.fields.items():
@@ -424,10 +422,7 @@ def initialize_database():
"DateTimeField": "DATETIME",
}.get(field_type, "TEXT")
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
- if field_obj.null:
- alter_sql += " NULL"
- else:
- alter_sql += " NOT NULL"
+ alter_sql += " NULL" if field_obj.null else " NOT NULL"
if hasattr(field_obj, "default") and field_obj.default is not None:
# 正确处理不同类型的默认值
default_value = field_obj.default
diff --git a/src/common/logger.py b/src/common/logger.py
index 40fd15070..a235cf341 100644
--- a/src/common/logger.py
+++ b/src/common/logger.py
@@ -1,16 +1,16 @@
-import logging
-
# 使用基于时间戳的文件处理器,简单的轮转份数限制
-from pathlib import Path
-from typing import Callable, Optional
+
+import logging
import json
import threading
import time
-from datetime import datetime, timedelta
-
import structlog
import toml
+from pathlib import Path
+from typing import Callable, Optional
+from datetime import datetime, timedelta
+
# 创建logs目录
LOG_DIR = Path("logs")
LOG_DIR.mkdir(exist_ok=True)
@@ -160,7 +160,7 @@ def close_handlers():
_console_handler = None
-def remove_duplicate_handlers():
+def remove_duplicate_handlers(): # sourcery skip: for-append-to-extend, list-comprehension
"""移除重复的handler,特别是文件handler"""
root_logger = logging.getLogger()
@@ -184,7 +184,7 @@ def remove_duplicate_handlers():
# 读取日志配置
-def load_log_config():
+def load_log_config(): # sourcery skip: use-contextlib-suppress
"""从配置文件加载日志设置"""
config_path = Path("config/bot_config.toml")
default_config = {
@@ -365,7 +365,7 @@ MODULE_COLORS = {
"component_registry": "\033[38;5;214m", # 橙黄色
"stream_api": "\033[38;5;220m", # 黄色
"config_api": "\033[38;5;226m", # 亮黄色
- "hearflow_api": "\033[38;5;154m", # 黄绿色
+ "heartflow_api": "\033[38;5;154m", # 黄绿色
"action_apis": "\033[38;5;118m", # 绿色
"independent_apis": "\033[38;5;82m", # 绿色
"llm_api": "\033[38;5;46m", # 亮绿色
@@ -412,6 +412,7 @@ class ModuleColoredConsoleRenderer:
"""自定义控制台渲染器,为不同模块提供不同颜色"""
def __init__(self, colors=True):
+ # sourcery skip: merge-duplicate-blocks, remove-redundant-if
self._colors = colors
self._config = LOG_CONFIG
@@ -443,6 +444,7 @@ class ModuleColoredConsoleRenderer:
self._enable_full_content_colors = False
def __call__(self, logger, method_name, event_dict):
+ # sourcery skip: merge-duplicate-blocks
"""渲染日志消息"""
# 获取基本信息
timestamp = event_dict.get("timestamp", "")
@@ -662,7 +664,7 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
"""获取logger实例,支持按名称绑定"""
if name is None:
return raw_logger
- logger = binds.get(name)
+ logger = binds.get(name) # type: ignore
if logger is None:
logger: structlog.stdlib.BoundLogger = structlog.get_logger(name).bind(logger_name=name)
binds[name] = logger
@@ -671,8 +673,8 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
def configure_logging(
level: str = "INFO",
- console_level: str = None,
- file_level: str = None,
+ console_level: Optional[str] = None,
+ file_level: Optional[str] = None,
max_bytes: int = 5 * 1024 * 1024,
backup_count: int = 30,
log_dir: str = "logs",
@@ -729,14 +731,11 @@ def reload_log_config():
global LOG_CONFIG
LOG_CONFIG = load_log_config()
- # 重新设置handler的日志级别
- file_handler = get_file_handler()
- if file_handler:
+ if file_handler := get_file_handler():
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO))
- console_handler = get_console_handler()
- if console_handler:
+ if console_handler := get_console_handler():
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO))
@@ -780,8 +779,7 @@ def set_console_log_level(level: str):
global LOG_CONFIG
LOG_CONFIG["console_log_level"] = level.upper()
- console_handler = get_console_handler()
- if console_handler:
+ if console_handler := get_console_handler():
console_handler.setLevel(getattr(logging, level.upper(), logging.INFO))
# 重新设置root logger级别
@@ -800,8 +798,7 @@ def set_file_log_level(level: str):
global LOG_CONFIG
LOG_CONFIG["file_log_level"] = level.upper()
- file_handler = get_file_handler()
- if file_handler:
+ if file_handler := get_file_handler():
file_handler.setLevel(getattr(logging, level.upper(), logging.INFO))
# 重新设置root logger级别
@@ -933,13 +930,12 @@ def format_json_for_logging(data, indent=2, ensure_ascii=False):
Returns:
str: 格式化后的JSON字符串
"""
- if isinstance(data, str):
- # 如果是JSON字符串,先解析再格式化
- parsed_data = json.loads(data)
- return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii)
- else:
+ if not isinstance(data, str):
# 如果是对象,直接格式化
return json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)
+ # 如果是JSON字符串,先解析再格式化
+ parsed_data = json.loads(data)
+ return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii)
def cleanup_old_logs():
diff --git a/src/common/message/api.py b/src/common/message/api.py
index 59ba9d1e2..eed85c0a9 100644
--- a/src/common/message/api.py
+++ b/src/common/message/api.py
@@ -8,7 +8,7 @@ from src.config.config import global_config
global_api = None
-def get_global_api() -> MessageServer:
+def get_global_api() -> MessageServer: # sourcery skip: extract-method
"""获取全局MessageServer实例"""
global global_api
if global_api is None:
@@ -36,9 +36,8 @@ def get_global_api() -> MessageServer:
kwargs["custom_logger"] = maim_message_logger
# 添加token认证
- if maim_message_config.auth_token:
- if len(maim_message_config.auth_token) > 0:
- kwargs["enable_token"] = True
+ if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
+ kwargs["enable_token"] = True
if maim_message_config.use_custom:
# 添加WSS模式支持
diff --git a/src/common/message_repository.py b/src/common/message_repository.py
index 107ee1c5e..dc5d8b7df 100644
--- a/src/common/message_repository.py
+++ b/src/common/message_repository.py
@@ -1,9 +1,11 @@
-from src.common.database.database_model import Messages # 更改导入
-from src.common.logger import get_logger
import traceback
+
from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
+from src.common.database.database_model import Messages
+from src.common.logger import get_logger
+
logger = get_logger(__name__)
diff --git a/src/common/remote.py b/src/common/remote.py
index 955e760b0..5380cd01e 100644
--- a/src/common/remote.py
+++ b/src/common/remote.py
@@ -23,7 +23,7 @@ class TelemetryHeartBeatTask(AsyncTask):
self.server_url = TELEMETRY_SERVER_URL
"""遥测服务地址"""
- self.client_uuid = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None
+ self.client_uuid: str | None = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None # type: ignore
"""客户端UUID"""
self.info_dict = self._get_sys_info()
@@ -72,7 +72,7 @@ class TelemetryHeartBeatTask(AsyncTask):
timeout=aiohttp.ClientTimeout(total=5), # 设置超时时间为5秒
) as response:
logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client")
- logger.debug(local_storage["deploy_time"])
+ logger.debug(local_storage["deploy_time"]) # type: ignore
logger.debug(f"Response status: {response.status}")
if response.status == 200:
@@ -93,7 +93,7 @@ class TelemetryHeartBeatTask(AsyncTask):
except Exception as e:
import traceback
- error_msg = str(e) if str(e) else "未知错误"
+ error_msg = str(e) or "未知错误"
logger.warning(
f"请求UUID出错,不过你还是可以正常使用麦麦: {type(e).__name__}: {error_msg}"
) # 可能是网络问题
@@ -114,11 +114,11 @@ class TelemetryHeartBeatTask(AsyncTask):
"""向服务器发送心跳"""
headers = {
"Client-UUID": self.client_uuid,
- "User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}",
+ "User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}", # type: ignore
}
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
- logger.debug(headers)
+ logger.debug(str(headers))
try:
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
@@ -151,7 +151,7 @@ class TelemetryHeartBeatTask(AsyncTask):
except Exception as e:
import traceback
- error_msg = str(e) if str(e) else "未知错误"
+ error_msg = str(e) or "未知错误"
logger.warning(f"(此消息不会影响正常使用)状态未发生: {type(e).__name__}: {error_msg}")
logger.debug(f"完整错误信息: {traceback.format_exc()}")
diff --git a/src/config/auto_update.py b/src/config/auto_update.py
index 2088e3628..139003a84 100644
--- a/src/config/auto_update.py
+++ b/src/config/auto_update.py
@@ -1,5 +1,6 @@
import shutil
import tomlkit
+from tomlkit.items import Table
from pathlib import Path
from datetime import datetime
@@ -45,8 +46,8 @@ def update_config():
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
- old_version = old_config["inner"].get("version")
- new_version = new_config["inner"].get("version")
+ old_version = old_config["inner"].get("version") # type: ignore
+ new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
print(f"检测到版本号相同 (v{old_version}),跳过更新")
# 如果version相同,恢复旧配置文件并返回
@@ -62,7 +63,7 @@ def update_config():
if key == "version":
continue
if key in target:
- if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)):
+ if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
update_dict(target[key], value)
else:
try:
@@ -85,10 +86,7 @@ def update_config():
if value and isinstance(value[0], dict) and "regex" in value[0]:
contains_regex = True
- if contains_regex:
- target[key] = value
- else:
- target[key] = tomlkit.array(value)
+ target[key] = value if contains_regex else tomlkit.array(str(value))
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
diff --git a/src/config/config.py b/src/config/config.py
index de173a520..b61111ec3 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -1,16 +1,14 @@
import os
-from dataclasses import field, dataclass
-
import tomlkit
import shutil
-from datetime import datetime
+from datetime import datetime
from tomlkit import TOMLDocument
from tomlkit.items import Table
-
-from src.common.logger import get_logger
+from dataclasses import field, dataclass
from rich.traceback import install
+from src.common.logger import get_logger
from src.config.config_base import ConfigBase
from src.config.official_configs import (
BotConfig,
@@ -80,8 +78,8 @@ def update_config():
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
- old_version = old_config["inner"].get("version")
- new_version = new_config["inner"].get("version")
+ old_version = old_config["inner"].get("version") # type: ignore
+ new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
return
@@ -103,7 +101,7 @@ def update_config():
shutil.copy2(template_path, new_config_path)
logger.info(f"已创建新配置文件: {new_config_path}")
- def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
+ def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中(如果target中存在相同的键)
"""
@@ -112,8 +110,9 @@ def update_config():
if key == "version":
continue
if key in target:
- if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
- update_dict(target[key], value)
+ target_value = target[key]
+ if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
+ update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理
diff --git a/src/config/config_base.py b/src/config/config_base.py
index 129f5a1c0..5fb398190 100644
--- a/src/config/config_base.py
+++ b/src/config/config_base.py
@@ -43,7 +43,7 @@ class ConfigBase:
field_type = f.type
try:
- init_args[field_name] = cls._convert_field(value, field_type)
+ init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
except TypeError as e:
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
except Exception as e:
diff --git a/src/config/official_configs.py b/src/config/official_configs.py
index 7e2efbeba..6838df1d1 100644
--- a/src/config/official_configs.py
+++ b/src/config/official_configs.py
@@ -1,7 +1,8 @@
-from dataclasses import dataclass, field
-from typing import Any, Literal
import re
+from dataclasses import dataclass, field
+from typing import Any, Literal, Optional
+
from src.config.config_base import ConfigBase
"""
@@ -113,7 +114,7 @@ class ChatConfig(ConfigBase):
exit_focus_threshold: float = 1.0
"""自动退出专注聊天的阈值,越低越容易退出专注聊天"""
- def get_current_talk_frequency(self, chat_stream_id: str = None) -> float:
+ def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
"""
根据当前时间和聊天流获取对应的 talk_frequency
@@ -138,7 +139,7 @@ class ChatConfig(ConfigBase):
# 如果都没有匹配,返回默认值
return self.talk_frequency
- def _get_time_based_frequency(self, time_freq_list: list[str]) -> float:
+ def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
"""
根据时间配置列表获取当前时段的频率
@@ -186,7 +187,7 @@ class ChatConfig(ConfigBase):
return current_frequency
- def _get_stream_specific_frequency(self, chat_stream_id: str) -> float:
+ def _get_stream_specific_frequency(self, chat_stream_id: str):
"""
获取特定聊天流在当前时间的频率
@@ -217,7 +218,7 @@ class ChatConfig(ConfigBase):
return None
- def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> str:
+ def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
"""
解析流配置字符串并生成对应的 chat_id
diff --git a/src/individuality/identity.py b/src/individuality/identity.py
index bb3125985..730615e3d 100644
--- a/src/individuality/identity.py
+++ b/src/individuality/identity.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import List
+from typing import List, Optional
@dataclass
@@ -8,7 +8,7 @@ class Identity:
identity_detail: List[str] # 身份细节描述
- def __init__(self, identity_detail: List[str] = None):
+ def __init__(self, identity_detail: Optional[List[str]] = None):
"""初始化身份特征
Args:
diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py
index 8365c0888..532b203fd 100644
--- a/src/individuality/individuality.py
+++ b/src/individuality/individuality.py
@@ -1,17 +1,18 @@
-from typing import Optional
import ast
-
-from src.llm_models.utils_model import LLMRequest
-from .personality import Personality
-from .identity import Identity
import random
import json
import os
import hashlib
+
+from typing import Optional
from rich.traceback import install
+
from src.common.logger import get_logger
-from src.person_info.person_info import get_person_info_manager
from src.config.config import global_config
+from src.llm_models.utils_model import LLMRequest
+from src.person_info.person_info import get_person_info_manager
+from .personality import Personality
+from .identity import Identity
install(extra_lines=3)
@@ -23,7 +24,7 @@ class Individuality:
def __init__(self):
# 正常初始化实例属性
- self.personality: Optional[Personality] = None
+ self.personality: Personality = None # type: ignore
self.identity: Optional[Identity] = None
self.name = ""
@@ -109,7 +110,7 @@ class Individuality:
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
if existing_short_impression:
try:
- existing_data = ast.literal_eval(existing_short_impression)
+ existing_data = ast.literal_eval(existing_short_impression) # type: ignore
if isinstance(existing_data, list) and len(existing_data) >= 1:
personality_result = existing_data[0]
except (json.JSONDecodeError, TypeError, IndexError):
@@ -128,7 +129,7 @@ class Individuality:
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
if existing_short_impression:
try:
- existing_data = ast.literal_eval(existing_short_impression)
+ existing_data = ast.literal_eval(existing_short_impression) # type: ignore
if isinstance(existing_data, list) and len(existing_data) >= 2:
identity_result = existing_data[1]
except (json.JSONDecodeError, TypeError, IndexError):
@@ -204,6 +205,7 @@ class Individuality:
return prompt_personality
def get_identity_prompt(self, level: int, x_person: int = 2) -> str:
+ # sourcery skip: assign-if-exp, merge-else-if-into-elif
"""
获取身份特征的prompt
@@ -240,13 +242,13 @@ class Individuality:
if identity_parts:
details_str = ",".join(identity_parts)
- if x_person in [1, 2]:
+ if x_person in {1, 2}:
return f"{i_pronoun},{details_str}。"
else: # x_person == 0
# 无人称时,直接返回细节,不加代词和开头的逗号
return f"{details_str}。"
else:
- if x_person in [1, 2]:
+ if x_person in {1, 2}:
return f"{i_pronoun}的身份信息不完整。"
else: # x_person == 0
return "身份信息不完整。"
@@ -441,14 +443,15 @@ class Individuality:
if info_list_json:
try:
info_list = json.loads(info_list_json) if isinstance(info_list_json, str) else info_list_json
- for item in info_list:
- if isinstance(item, dict) and "info_type" in item:
- keywords.append(item["info_type"])
+ keywords.extend(
+ item["info_type"] for item in info_list if isinstance(item, dict) and "info_type" in item
+ )
except (json.JSONDecodeError, TypeError):
logger.error(f"解析info_list失败: {info_list_json}")
return keywords
async def _create_personality(self, personality_core: str, personality_sides: list) -> str:
+ # sourcery skip: merge-list-append, move-assign
"""使用LLM创建压缩版本的impression
Args:
diff --git a/src/individuality/personality.py b/src/individuality/personality.py
index 0ee46a3d0..ace719331 100644
--- a/src/individuality/personality.py
+++ b/src/individuality/personality.py
@@ -1,6 +1,7 @@
-from dataclasses import dataclass
-from typing import Dict, List
import json
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional
from pathlib import Path
@@ -24,7 +25,7 @@ class Personality:
cls._instance = super().__new__(cls)
return cls._instance
- def __init__(self, personality_core: str = "", personality_sides: List[str] = None):
+ def __init__(self, personality_core: str = "", personality_sides: Optional[List[str]] = None):
if personality_sides is None:
personality_sides = []
self.personality_core = personality_core
@@ -41,7 +42,7 @@ class Personality:
cls._instance = cls()
return cls._instance
- def _init_big_five_personality(self):
+ def _init_big_five_personality(self): # sourcery skip: extract-method
"""初始化大五人格特质"""
# 构建文件路径
personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per"
@@ -63,7 +64,6 @@ class Personality:
else:
self.extraversion = 0.3
self.neuroticism = 0.5
-
if "认真" in self.personality_core or "负责" in self.personality_sides:
self.conscientiousness = 0.9
else:
diff --git a/src/manager/async_task_manager.py b/src/manager/async_task_manager.py
index 1e1e9132f..0a2c0d215 100644
--- a/src/manager/async_task_manager.py
+++ b/src/manager/async_task_manager.py
@@ -120,12 +120,7 @@ class AsyncTaskManager:
"""
获取所有任务的状态
"""
- tasks_status = {}
- for task_name, task in self.tasks.items():
- tasks_status[task_name] = {
- "status": "running" if not task.done() else "done",
- }
- return tasks_status
+ return {task_name: {"status": "done" if task.done() else "running"} for task_name, task in self.tasks.items()}
async def stop_and_wait_all_tasks(self):
"""
diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py
index ffdf8ff36..e3a66370b 100644
--- a/src/mood/mood_manager.py
+++ b/src/mood/mood_manager.py
@@ -2,12 +2,12 @@ import math
import random
import time
-from src.chat.message_receive.message import MessageRecv
-from src.llm_models.utils_model import LLMRequest
-from ..common.logger import get_logger
-from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
+from src.common.logger import get_logger
from src.config.config import global_config
+from src.chat.message_receive.message import MessageRecv
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
+from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
+from src.llm_models.utils_model import LLMRequest
from src.manager.async_task_manager import AsyncTask, async_task_manager
logger = get_logger("mood")
@@ -55,12 +55,12 @@ class ChatMood:
request_type="mood",
)
- self.last_change_time = 0
+ self.last_change_time: float = 0
async def update_mood_by_message(self, message: MessageRecv, interested_rate: float):
self.regression_count = 0
- during_last_time = message.message_info.time - self.last_change_time
+ during_last_time = message.message_info.time - self.last_change_time # type: ignore
base_probability = 0.05
time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time))
@@ -78,7 +78,7 @@ class ChatMood:
if random.random() > update_probability:
return
- message_time = message.message_info.time
+ message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
@@ -119,7 +119,7 @@ class ChatMood:
self.mood_state = response
- self.last_change_time = message_time
+ self.last_change_time = message_time # type: ignore
async def regress_mood(self):
message_time = time.time()
diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py
index f44a88225..5e5f033f9 100644
--- a/src/person_info/person_info.py
+++ b/src/person_info/person_info.py
@@ -1,17 +1,18 @@
-from src.common.logger import get_logger
-from src.common.database.database import db
-from src.common.database.database_model import PersonInfo # 新增导入
import copy
import hashlib
-from typing import Any, Callable, Dict, Union
import datetime
import asyncio
+import json
+
+from json_repair import repair_json
+from typing import Any, Callable, Dict, Union, Optional
+
+from src.common.logger import get_logger
+from src.common.database.database import db
+from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
-import json # 新增导入
-from json_repair import repair_json
-
"""
PersonInfoManager 类方法功能摘要:
@@ -42,7 +43,7 @@ person_info_default = {
"last_know": None,
# "user_cardname": None, # This field is not in Peewee model PersonInfo
# "user_avatar": None, # This field is not in Peewee model PersonInfo
- "impression": None, # Corrected from persion_impression
+ "impression": None, # Corrected from person_impression
"short_impression": None,
"info_list": None,
"points": None,
@@ -106,27 +107,24 @@ class PersonInfoManager:
logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
return False
- def get_person_id_by_person_name(self, person_name: str):
+ def get_person_id_by_person_name(self, person_name: str) -> str:
"""根据用户名获取用户ID"""
try:
record = PersonInfo.get_or_none(PersonInfo.person_name == person_name)
- if record:
- return record.person_id
- else:
- return ""
+ return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
return ""
@staticmethod
- async def create_person_info(person_id: str, data: dict = None):
+ async def create_person_info(person_id: str, data: Optional[dict] = None):
"""创建一个项"""
if not person_id:
- logger.debug("创建失败,personid不存在")
+ logger.debug("创建失败,person_id不存在")
return
_person_info_default = copy.deepcopy(person_info_default)
- model_fields = PersonInfo._meta.fields.keys()
+ model_fields = PersonInfo._meta.fields.keys() # type: ignore
final_data = {"person_id": person_id}
@@ -163,9 +161,9 @@ class PersonInfoManager:
await asyncio.to_thread(_db_create_sync, final_data)
- async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
+ async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
"""更新某一个字段,会补全"""
- if field_name not in PersonInfo._meta.fields:
+ if field_name not in PersonInfo._meta.fields: # type: ignore
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
return
@@ -228,15 +226,13 @@ class PersonInfoManager:
@staticmethod
async def has_one_field(person_id: str, field_name: str):
"""判断是否存在某一个字段"""
- if field_name not in PersonInfo._meta.fields:
+ if field_name not in PersonInfo._meta.fields: # type: ignore
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
return False
def _db_has_field_sync(p_id: str, f_name: str):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
- if record:
- return True
- return False
+ return bool(record)
try:
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
@@ -435,9 +431,7 @@ class PersonInfoManager:
except Exception as e:
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
# Fallback to default in case of any error during DB access
- if field_name in person_info_default:
- return default_value_for_field
- return None
+ return default_value_for_field if field_name in person_info_default else None
@staticmethod
def get_value_sync(person_id: str, field_name: str):
@@ -446,8 +440,7 @@ class PersonInfoManager:
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = []
- record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
- if record:
+ if record := PersonInfo.get_or_none(PersonInfo.person_id == person_id):
val = getattr(record, field_name, None)
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
@@ -481,7 +474,7 @@ class PersonInfoManager:
record = await asyncio.to_thread(_db_get_record_sync, person_id)
for field_name in field_names:
- if field_name not in PersonInfo._meta.fields:
+ if field_name not in PersonInfo._meta.fields: # type: ignore
if field_name in person_info_default:
result[field_name] = copy.deepcopy(person_info_default[field_name])
logger.debug(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。")
@@ -509,7 +502,7 @@ class PersonInfoManager:
"""
获取满足条件的字段值字典
"""
- if field_name not in PersonInfo._meta.fields:
+ if field_name not in PersonInfo._meta.fields: # type: ignore
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义")
return {}
@@ -531,7 +524,7 @@ class PersonInfoManager:
return {}
async def get_or_create_person(
- self, platform: str, user_id: int, nickname: str = None, user_cardname: str = None, user_avatar: str = None
+ self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None
) -> str:
"""
根据 platform 和 user_id 获取 person_id。
@@ -561,7 +554,7 @@ class PersonInfoManager:
"points": [],
"forgotten_points": [],
}
- model_fields = PersonInfo._meta.fields.keys()
+ model_fields = PersonInfo._meta.fields.keys() # type: ignore
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
await self.create_person_info(person_id, data=filtered_initial_data)
@@ -610,7 +603,9 @@ class PersonInfoManager:
"name_reason",
]
valid_fields_to_get = [
- f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
+ f
+ for f in required_fields
+ if f in PersonInfo._meta.fields or f in person_info_default # type: ignore
]
person_data = await self.get_values(found_person_id, valid_fields_to_get)
diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py
index 0b443850f..7b69b47bb 100644
--- a/src/person_info/relationship_builder.py
+++ b/src/person_info/relationship_builder.py
@@ -3,12 +3,12 @@ import traceback
import os
import pickle
import random
-from typing import List, Dict
+from typing import List, Dict, Any
from src.config.config import global_config
from src.common.logger import get_logger
-from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.relationship_manager import get_relationship_manager
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
+from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
@@ -45,7 +45,7 @@ class RelationshipBuilder:
self.chat_id = chat_id
# 新的消息段缓存结构:
# {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]}
- self.person_engaged_cache: Dict[str, List[Dict[str, any]]] = {}
+ self.person_engaged_cache: Dict[str, List[Dict[str, Any]]] = {}
# 持久化存储文件路径
self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl")
@@ -210,11 +210,7 @@ class RelationshipBuilder:
if person_id not in self.person_engaged_cache:
return 0
- total_count = 0
- for segment in self.person_engaged_cache[person_id]:
- total_count += segment["message_count"]
-
- return total_count
+ return sum(segment["message_count"] for segment in self.person_engaged_cache[person_id])
def _cleanup_old_segments(self) -> bool:
"""清理老旧的消息段"""
@@ -289,7 +285,7 @@ class RelationshipBuilder:
self.last_cleanup_time = current_time
# 保存缓存
- if cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0:
+ if cleanup_stats["segments_removed"] > 0 or users_to_remove:
self._save_cache()
logger.info(
f"{self.log_prefix} 清理完成 - 影响用户: {cleanup_stats['users_cleaned']}, 移除消息段: {cleanup_stats['segments_removed']}, 移除用户: {len(users_to_remove)}"
@@ -313,6 +309,7 @@ class RelationshipBuilder:
return False
def get_cache_status(self) -> str:
+ # sourcery skip: merge-list-append, merge-list-appends-into-extend
"""获取缓存状态信息,用于调试和监控"""
if not self.person_engaged_cache:
return f"{self.log_prefix} 关系缓存为空"
@@ -357,13 +354,12 @@ class RelationshipBuilder:
self._cleanup_old_segments()
current_time = time.time()
- latest_messages = get_raw_msg_by_timestamp_with_chat(
+ if latest_messages := get_raw_msg_by_timestamp_with_chat(
self.chat_id,
self.last_processed_message_time,
current_time,
limit=50, # 获取自上次处理后的消息
- )
- if latest_messages:
+ ):
# 处理所有新的非bot消息
for latest_msg in latest_messages:
user_id = latest_msg.get("user_id")
@@ -414,7 +410,7 @@ class RelationshipBuilder:
# 负责触发关系构建、整合消息段、更新用户印象
# ================================
- async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, any]]):
+ async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]):
"""基于消息段更新用户印象"""
original_segment_count = len(segments)
logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象")
diff --git a/src/person_info/relationship_builder_manager.py b/src/person_info/relationship_builder_manager.py
index 926d67fca..f3bca25d2 100644
--- a/src/person_info/relationship_builder_manager.py
+++ b/src/person_info/relationship_builder_manager.py
@@ -1,4 +1,5 @@
-from typing import Dict, Optional, List
+from typing import Dict, Optional, List, Any
+
from src.common.logger import get_logger
from .relationship_builder import RelationshipBuilder
@@ -63,7 +64,7 @@ class RelationshipBuilderManager:
"""
return list(self.builders.keys())
- def get_status(self) -> Dict[str, any]:
+ def get_status(self) -> Dict[str, Any]:
"""获取管理器状态
Returns:
@@ -94,9 +95,7 @@ class RelationshipBuilderManager:
bool: 是否成功清理
"""
builder = self.get_builder(chat_id)
- if builder:
- return builder.force_cleanup_user_segments(person_id)
- return False
+ return builder.force_cleanup_user_segments(person_id) if builder else False
# 全局管理器实例
diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py
index 65be0b3af..5e369e752 100644
--- a/src/person_info/relationship_fetcher.py
+++ b/src/person_info/relationship_fetcher.py
@@ -1,16 +1,19 @@
-from src.config.config import global_config
-from src.llm_models.utils_model import LLMRequest
import time
import traceback
-from src.common.logger import get_logger
-from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
-from src.person_info.person_info import get_person_info_manager
-from typing import List, Dict
-from json_repair import repair_json
-from src.chat.message_receive.chat_stream import get_chat_manager
import json
import random
+from typing import List, Dict, Any
+from json_repair import repair_json
+
+from src.common.logger import get_logger
+from src.config.config import global_config
+from src.llm_models.utils_model import LLMRequest
+from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
+from src.chat.message_receive.chat_stream import get_chat_manager
+from src.person_info.person_info import get_person_info_manager
+
+
logger = get_logger("relationship_fetcher")
@@ -62,11 +65,11 @@ class RelationshipFetcher:
self.chat_id = chat_id
# 信息获取缓存:记录正在获取的信息请求
- self.info_fetching_cache: List[Dict[str, any]] = []
+ self.info_fetching_cache: List[Dict[str, Any]] = []
# 信息结果缓存:存储已获取的信息结果,带TTL
- self.info_fetched_cache: Dict[str, Dict[str, any]] = {}
- # 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknow": bool}}}
+ self.info_fetched_cache: Dict[str, Dict[str, Any]] = {}
+ # 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}}
# LLM模型配置
self.llm_model = LLMRequest(
@@ -184,7 +187,7 @@ class RelationshipFetcher:
nickname_str = ",".join(global_config.bot.alias_names)
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
person_info_manager = get_person_info_manager()
- person_name = await person_info_manager.get_value(person_id, "person_name")
+ person_name: str = await person_info_manager.get_value(person_id, "person_name") # type: ignore
info_cache_block = self._build_info_cache_block()
@@ -208,8 +211,7 @@ class RelationshipFetcher:
logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息:{content_json.get('none', '')}")
return None
- info_type = content_json.get("info_type")
- if info_type:
+ if info_type := content_json.get("info_type"):
# 记录信息获取请求
self.info_fetching_cache.append(
{
@@ -287,7 +289,7 @@ class RelationshipFetcher:
"ttl": 2,
"start_time": start_time,
"person_name": person_name,
- "unknow": cached_info == "none",
+ "unknown": cached_info == "none",
}
logger.info(f"{self.log_prefix} 记得 {person_name} 的 {info_type}: {cached_info}")
return
@@ -321,7 +323,7 @@ class RelationshipFetcher:
"ttl": 2,
"start_time": start_time,
"person_name": person_name,
- "unknow": True,
+ "unknown": True,
}
logger.info(f"{self.log_prefix} 完全不认识 {person_name}")
await self._save_info_to_cache(person_id, info_type, "none")
@@ -353,15 +355,15 @@ class RelationshipFetcher:
if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {}
self.info_fetched_cache[person_id][info_type] = {
- "info": "unknow" if is_unknown else info_content,
+ "info": "unknown" if is_unknown else info_content,
"ttl": 3,
"start_time": start_time,
"person_name": person_name,
- "unknow": is_unknown,
+ "unknown": is_unknown,
}
# 保存到持久化缓存 (info_list)
- await self._save_info_to_cache(person_id, info_type, info_content if not is_unknown else "none")
+ await self._save_info_to_cache(person_id, info_type, "none" if is_unknown else info_content)
if not is_unknown:
logger.info(f"{self.log_prefix} 思考得到,{person_name} 的 {info_type}: {info_content}")
@@ -393,7 +395,7 @@ class RelationshipFetcher:
for info_type in self.info_fetched_cache[person_id]:
person_name = self.info_fetched_cache[person_id][info_type]["person_name"]
- if not self.info_fetched_cache[person_id][info_type]["unknow"]:
+ if not self.info_fetched_cache[person_id][info_type]["unknown"]:
info_content = self.info_fetched_cache[person_id][info_type]["info"]
person_known_infos.append(f"[{info_type}]:{info_content}")
else:
@@ -430,6 +432,7 @@ class RelationshipFetcher:
return persons_infos_str
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
+ # sourcery skip: use-next
"""将提取到的信息保存到 person_info 的 info_list 字段中
Args:
diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py
index 039197250..2c544fe46 100644
--- a/src/person_info/relationship_manager.py
+++ b/src/person_info/relationship_manager.py
@@ -1,5 +1,5 @@
from src.common.logger import get_logger
-from src.person_info.person_info import PersonInfoManager, get_person_info_manager
+from .person_info import PersonInfoManager, get_person_info_manager
import time
import random
from src.llm_models.utils_model import LLMRequest
@@ -12,7 +12,7 @@ from difflib import SequenceMatcher
import jieba
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
-
+from typing import List, Dict, Any
logger = get_logger("relation")
@@ -28,8 +28,7 @@ class RelationshipManager:
async def is_known_some_one(platform, user_id):
"""判断是否认识某人"""
person_info_manager = get_person_info_manager()
- is_known = await person_info_manager.is_person_known(platform, user_id)
- return is_known
+ return await person_info_manager.is_person_known(platform, user_id)
@staticmethod
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
@@ -110,7 +109,7 @@ class RelationshipManager:
return relation_prompt
- async def update_person_impression(self, person_id, timestamp, bot_engaged_messages=None):
+ async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
"""更新用户印象
Args:
@@ -123,7 +122,7 @@ class RelationshipManager:
person_info_manager = get_person_info_manager()
person_name = await person_info_manager.get_value(person_id, "person_name")
nickname = await person_info_manager.get_value(person_id, "nickname")
- know_times = await person_info_manager.get_value(person_id, "know_times") or 0
+ know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore
alias_str = ", ".join(global_config.bot.alias_names)
# personality_block =get_individuality().get_personality_prompt(x_person=2, level=2)
@@ -142,13 +141,13 @@ class RelationshipManager:
# 遍历消息,构建映射
for msg in user_messages:
await person_info_manager.get_or_create_person(
- platform=msg.get("chat_info_platform"),
- user_id=msg.get("user_id"),
- nickname=msg.get("user_nickname"),
- user_cardname=msg.get("user_cardname"),
+ platform=msg.get("chat_info_platform"), # type: ignore
+ user_id=msg.get("user_id"), # type: ignore
+ nickname=msg.get("user_nickname"), # type: ignore
+ user_cardname=msg.get("user_cardname"), # type: ignore
)
- replace_user_id = msg.get("user_id")
- replace_platform = msg.get("chat_info_platform")
+ replace_user_id: str = msg.get("user_id") # type: ignore
+ replace_platform: str = msg.get("chat_info_platform") # type: ignore
replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id)
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
@@ -354,8 +353,8 @@ class RelationshipManager:
person_name = await person_info_manager.get_value(person_id, "person_name")
nickname = await person_info_manager.get_value(person_id, "nickname")
- know_times = await person_info_manager.get_value(person_id, "know_times") or 0
- attitude = await person_info_manager.get_value(person_id, "attitude") or 50
+ know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore
+ attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore
# 根据熟悉度,调整印象和简短印象的最大长度
if know_times > 300:
@@ -414,16 +413,14 @@ class RelationshipManager:
if len(remaining_points) < 10:
# 如果还没达到30条,直接保留
remaining_points.append(point)
+ elif random.random() < keep_probability:
+ # 保留这个点,随机移除一个已保留的点
+ idx_to_remove = random.randrange(len(remaining_points))
+ points_to_move.append(remaining_points[idx_to_remove])
+ remaining_points[idx_to_remove] = point
else:
- # 随机决定是否保留
- if random.random() < keep_probability:
- # 保留这个点,随机移除一个已保留的点
- idx_to_remove = random.randrange(len(remaining_points))
- points_to_move.append(remaining_points[idx_to_remove])
- remaining_points[idx_to_remove] = point
- else:
- # 不保留这个点
- points_to_move.append(point)
+ # 不保留这个点
+ points_to_move.append(point)
# 更新points和forgotten_points
current_points = remaining_points
@@ -520,7 +517,7 @@ class RelationshipManager:
new_attitude = int(relation_value_json.get("attitude", 50))
# 获取当前的关系值
- old_attitude = await person_info_manager.get_value(person_id, "attitude") or 50
+ old_attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore
# 更新熟悉度
if new_attitude > 25:
diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py
index c341e5214..6c8cc01da 100644
--- a/src/plugin_system/apis/generator_api.py
+++ b/src/plugin_system/apis/generator_api.py
@@ -65,9 +65,9 @@ def get_replyer(
async def generate_reply(
- chat_stream=None,
- chat_id: str = None,
- action_data: Dict[str, Any] = None,
+ chat_stream: Optional[ChatStream] = None,
+ chat_id: Optional[str] = None,
+ action_data: Optional[Dict[str, Any]] = None,
reply_to: str = "",
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
@@ -78,25 +78,25 @@ async def generate_reply(
model_configs: Optional[List[Dict[str, Any]]] = None,
request_type: str = "",
enable_timeout: bool = False,
-) -> Tuple[bool, List[Tuple[str, Any]]]:
+) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""生成回复
Args:
chat_stream: 聊天流对象(优先)
- action_data: 动作数据
chat_id: 聊天ID(备用)
+ action_data: 动作数据
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
return_prompt: 是否返回提示词
Returns:
- Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
+ Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
- return False, []
+ return False, [], None
logger.debug("[GeneratorAPI] 开始生成回复")
@@ -109,8 +109,9 @@ async def generate_reply(
enable_timeout=enable_timeout,
enable_tool=enable_tool,
)
-
- reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
+ reply_set = []
+ if content:
+ reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
if success:
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
@@ -118,19 +119,19 @@ async def generate_reply(
logger.warning("[GeneratorAPI] 回复生成失败")
if return_prompt:
- return success, reply_set or [], prompt
+ return success, reply_set, prompt
else:
- return success, reply_set or []
+ return success, reply_set, None
except Exception as e:
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
- return False, []
+ return False, [], None
async def rewrite_reply(
- chat_stream=None,
- reply_data: Dict[str, Any] = None,
- chat_id: str = None,
+ chat_stream: Optional[ChatStream] = None,
+ reply_data: Optional[Dict[str, Any]] = None,
+ chat_id: Optional[str] = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
model_configs: Optional[List[Dict[str, Any]]] = None,
@@ -158,15 +159,16 @@ async def rewrite_reply(
# 调用回复器重写回复
success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {})
-
- reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
+ reply_set = []
+ if content:
+ reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
if success:
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项")
else:
logger.warning("[GeneratorAPI] 重写回复失败")
- return success, reply_set or []
+ return success, reply_set
except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
diff --git a/src/tools/tool_executor.py b/src/tools/tool_executor.py
index 29ee8be1b..403ed554f 100644
--- a/src/tools/tool_executor.py
+++ b/src/tools/tool_executor.py
@@ -34,7 +34,7 @@ class ToolExecutor:
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
"""
- def __init__(self, chat_id: str = None, enable_cache: bool = True, cache_ttl: int = 3):
+ def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
"""初始化工具执行器
Args:
@@ -62,8 +62,8 @@ class ToolExecutor:
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}")
async def execute_from_chat_message(
- self, target_message: str, chat_history: list[str], sender: str, return_details: bool = False
- ) -> List[Dict] | Tuple[List[Dict], List[str], str]:
+ self, target_message: str, chat_history: str, sender: str, return_details: bool = False
+ ) -> Tuple[List[Dict], List[str], str]:
"""从聊天消息执行工具
Args:
@@ -79,16 +79,14 @@ class ToolExecutor:
# 首先检查缓存
cache_key = self._generate_cache_key(target_message, chat_history, sender)
- cached_result = self._get_from_cache(cache_key)
-
- if cached_result:
+ if cached_result := self._get_from_cache(cache_key):
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
- if return_details:
- # 从缓存结果中提取工具名称
- used_tools = [result.get("tool_name", "unknown") for result in cached_result]
- return cached_result, used_tools, "使用缓存结果"
- else:
- return cached_result
+ if not return_details:
+ return cached_result, [], "使用缓存结果"
+
+ # 从缓存结果中提取工具名称
+ used_tools = [result.get("tool_name", "unknown") for result in cached_result]
+ return cached_result, used_tools, "使用缓存结果"
# 缓存未命中,执行工具调用
# 获取可用工具
@@ -134,7 +132,7 @@ class ToolExecutor:
if return_details:
return tool_results, used_tools, prompt
else:
- return tool_results
+ return tool_results, [], ""
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
"""执行工具调用
@@ -207,7 +205,7 @@ class ToolExecutor:
return tool_results, used_tools
- def _generate_cache_key(self, target_message: str, chat_history: list[str], sender: str) -> str:
+ def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键
Args:
@@ -267,10 +265,7 @@ class ToolExecutor:
return
expired_keys = []
- for cache_key, cache_item in self.tool_cache.items():
- if cache_item["ttl"] <= 0:
- expired_keys.append(cache_key)
-
+ expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
for key in expired_keys:
del self.tool_cache[key]
@@ -355,7 +350,7 @@ class ToolExecutor:
"ttl_distribution": ttl_distribution,
}
- def set_cache_config(self, enable_cache: bool = None, cache_ttl: int = None):
+ def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
"""动态修改缓存配置
Args:
@@ -366,7 +361,7 @@ class ToolExecutor:
self.enable_cache = enable_cache
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
- if cache_ttl is not None and cache_ttl > 0:
+ if cache_ttl > 0:
self.cache_ttl = cache_ttl
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
@@ -380,7 +375,7 @@ init_tool_executor_prompt()
# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3)
executor = ToolExecutor(executor_id="my_executor")
-results = await executor.execute_from_chat_message(
+results, _, _ = await executor.execute_from_chat_message(
talking_message_str="今天天气怎么样?现在几点了?",
is_group_chat=False
)