Complete all type annotation fixes (完成所有类型注解的修复)

Author: UnCLAS-Prommer
Date: 2025-07-13 00:19:54 +08:00
Parent: d2ad6ea1d8
Commit: 7ef0bfb7c8
32 changed files with 358 additions and 434 deletions
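The changes below follow one recurring pattern: parameters that silently defaulted to None gain an explicit Optional[...] annotation, return types are widened to match every return path, and `# type: ignore` markers are added where Peewee expressions or local_storage lookups cannot be typed statically. A minimal illustrative sketch of that pattern (hypothetical names, not code from this repository):

    from typing import Any, Dict, Optional, Tuple

    # Before: reply_data: Dict[str, Any] = None  (rejected by a strict type checker)
    # After: the Optional annotation makes the None default explicit, and the
    # return tuple gains a slot for the prompt, mirroring the hunks below.
    def generate_reply(reply_data: Optional[Dict[str, Any]] = None) -> Tuple[bool, Optional[str], Optional[str]]:
        prompt: Optional[str] = None
        if reply_data is None:
            reply_data = {}
        # ... build the prompt and call the model here ...
        return True, "reply text", prompt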

View File

@@ -4,6 +4,7 @@ import asyncio
 import random
 import ast
 import re
 from typing import List, Optional, Dict, Any, Tuple
 from datetime import datetime
@@ -161,13 +162,13 @@ class DefaultReplyer:
     async def generate_reply_with_context(
         self,
-        reply_data: Dict[str, Any] = None,
+        reply_data: Optional[Dict[str, Any]] = None,
         reply_to: str = "",
         extra_info: str = "",
         available_actions: Optional[Dict[str, ActionInfo]] = None,
         enable_tool: bool = True,
         enable_timeout: bool = False,
-    ) -> Tuple[bool, Optional[str]]:
+    ) -> Tuple[bool, Optional[str], Optional[str]]:
         """
         回复器 (Replier): 核心逻辑,负责生成回复文本。
         (已整合原 HeartFCGenerator 的功能)
@@ -225,14 +226,14 @@ class DefaultReplyer:
             except Exception as llm_e:
                 # 精简报错信息
                 logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
-                return False, None  # LLM 调用失败则无法生成回复
+                return False, None, prompt  # LLM 调用失败则无法生成回复
             return True, content, prompt
         except Exception as e:
             logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
             traceback.print_exc()
-            return False, None
+            return False, None, prompt
 
     async def rewrite_reply_with_context(
         self,
@@ -368,7 +369,7 @@ class DefaultReplyer:
                 memory_str += f"- {running_memory['content']}\n"
         return memory_str
 
-    async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True):
+    async def build_tool_info(self, chat_history, reply_data: Optional[Dict], enable_tool: bool = True):
         """构建工具信息块
 
         Args:
@@ -393,7 +394,7 @@ class DefaultReplyer:
         try:
             # 使用工具执行器获取信息
-            tool_results = await self.tool_executor.execute_from_chat_message(
+            tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
                 sender=sender, target_message=text, chat_history=chat_history, return_details=False
             )
@@ -468,7 +469,7 @@ class DefaultReplyer:
     async def build_prompt_reply_context(
         self,
-        reply_data=None,
+        reply_data: Dict[str, Any],
         available_actions: Optional[Dict[str, ActionInfo]] = None,
         enable_timeout: bool = False,
         enable_tool: bool = True,
@@ -549,7 +550,7 @@ class DefaultReplyer:
             ),
             self._time_and_run_task(self.build_memory_block(chat_talking_prompt_half, target), "build_memory_block"),
             self._time_and_run_task(
-                self.build_tool_info(reply_data, chat_talking_prompt_half, enable_tool=enable_tool), "build_tool_info"
+                self.build_tool_info(chat_talking_prompt_half, reply_data, enable_tool=enable_tool), "build_tool_info"
             ),
         )
@@ -806,7 +807,7 @@ class DefaultReplyer:
         response_set: List[Tuple[str, str]],
         thinking_id: str = "",
         display_message: str = "",
-    ) -> Optional[MessageSending]:
+    ) -> Optional[List[Tuple[str, bool]]]:
         # sourcery skip: assign-if-exp, boolean-if-exp-identity, remove-unnecessary-cast
         """发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
         chat = self.chat_stream
@@ -869,7 +870,7 @@ class DefaultReplyer:
             try:
                 if (
                     bot_message.is_private_message()
-                    or bot_message.reply.processed_plain_text != "[System Trigger Context]"
+                    or bot_message.reply.processed_plain_text != "[System Trigger Context]"  # type: ignore
                     or mark_head
                 ):
                     set_reply = False
@@ -910,7 +911,7 @@ class DefaultReplyer:
         is_emoji: bool,
         thinking_start_time: float,
         display_message: str,
-        anchor_message: MessageRecv = None,
+        anchor_message: Optional[MessageRecv] = None,
     ) -> MessageSending:
         """构建单个发送消息"""

View File

@@ -1,8 +1,8 @@
 from typing import Dict, Any, Optional, List
+from src.common.logger import get_logger
 from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
 from src.chat.replyer.default_generator import DefaultReplyer
-from src.common.logger import get_logger
 
 logger = get_logger("ReplyerManager")

View File

@@ -1,6 +1,7 @@
 import time  # 导入 time 模块以获取当前时间
 import random
 import re
 from typing import List, Dict, Any, Tuple, Optional
 from rich.traceback import install
@@ -88,8 +89,8 @@ def get_actions_by_timestamp_with_chat(
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表""" """获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where( query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id) (ActionRecords.chat_id == chat_id)
& (ActionRecords.time > timestamp_start) & (ActionRecords.time > timestamp_start) # type: ignore
& (ActionRecords.time < timestamp_end) & (ActionRecords.time < timestamp_end) # type: ignore
) )
if limit > 0: if limit > 0:
@@ -113,8 +114,8 @@ def get_actions_by_timestamp_with_chat_inclusive(
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表""" """获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where( query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id) (ActionRecords.chat_id == chat_id)
& (ActionRecords.time >= timestamp_start) & (ActionRecords.time >= timestamp_start) # type: ignore
& (ActionRecords.time <= timestamp_end) & (ActionRecords.time <= timestamp_end) # type: ignore
) )
if limit > 0: if limit > 0:
@@ -331,7 +332,7 @@ def _build_readable_messages_internal(
         if replace_bot_name and user_id == global_config.bot.qq_account:
             person_name = f"{global_config.bot.nickname}(你)"
         else:
-            person_name = person_info_manager.get_value_sync(person_id, "person_name")
+            person_name = person_info_manager.get_value_sync(person_id, "person_name")  # type: ignore
 
         # 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
         if not person_name:
@@ -911,8 +912,8 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
     person_ids_set = set()  # 使用集合来自动去重
     for msg in messages:
-        platform = msg.get("user_platform")
-        user_id = msg.get("user_id")
+        platform: str = msg.get("user_platform")  # type: ignore
+        user_id: str = msg.get("user_id")  # type: ignore
         # 检查必要信息是否存在 且 不是机器人自己
         if not all([platform, user_id]) or user_id == global_config.bot.qq_account:

View File

@@ -1,7 +1,8 @@
+import ast
 import json
 import logging
-from typing import Any, Dict, TypeVar, List, Union, Tuple
-import ast
+from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional
 
 # 定义类型变量用于泛型类型提示
 T = TypeVar("T")
@@ -30,16 +31,12 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
         # 尝试标准的 JSON 解析
         return json.loads(json_str)
     except json.JSONDecodeError:
-        # 如果标准解析失败,尝试将单引号替换为双引号再解析
-        # (注意:这种替换可能不安全,如果字符串内容本身包含引号)
-        # 更安全的方式是用 ast.literal_eval
+        # 如果标准解析失败,尝试用 ast.literal_eval 解析
         try:
             # logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...")
             result = ast.literal_eval(json_str)
-            # 确保结果是字典(因为我们通常期望参数是字典)
             if isinstance(result, dict):
                 return result
-            else:
             logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
             return default_value
         except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
@@ -53,7 +50,9 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
     return default_value
 
-def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]:
+def extract_tool_call_arguments(
+    tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None
+) -> Dict[str, Any]:
     """
     从LLM工具调用对象中提取参数
@@ -77,13 +76,11 @@ def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[s
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}") logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
return default_result return default_result
# 提取arguments if arguments_str := function_data.get("arguments", "{}"):
arguments_str = function_data.get("arguments", "{}")
if not arguments_str:
return default_result
# 解析JSON # 解析JSON
return safe_json_loads(arguments_str, default_result) return safe_json_loads(arguments_str, default_result)
else:
return default_result
except Exception as e: except Exception as e:
logger.error(f"提取工具调用参数时出错: {e}") logger.error(f"提取工具调用参数时出错: {e}")

View File

@@ -1,12 +1,12 @@
-from typing import Dict, Any, Optional, List, Union
 import re
-from contextlib import asynccontextmanager
 import asyncio
 import contextvars
-from src.common.logger import get_logger
-# import traceback
 from rich.traceback import install
+from contextlib import asynccontextmanager
+from typing import Dict, Any, Optional, List, Union
+from src.common.logger import get_logger
 
 install(extra_lines=3)
@@ -32,6 +32,7 @@ class PromptContext:
     @asynccontextmanager
     async def async_scope(self, context_id: Optional[str] = None):
+        # sourcery skip: hoist-statement-from-if, use-contextlib-suppress
         """创建一个异步的临时提示模板作用域"""
         # 保存当前上下文并设置新上下文
         if context_id is not None:
@@ -88,8 +89,7 @@ class PromptContext:
     async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
         """异步注册提示模板到指定作用域"""
         async with self._context_lock:
-            target_context = context_id or self._current_context
-            if target_context:
+            if target_context := context_id or self._current_context:
                 self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
@@ -151,7 +151,7 @@ class Prompt(str):
     @staticmethod
     def _process_escaped_braces(template) -> str:
-        """处理模板中的转义花括号,将 \{\} 替换为临时标记"""
+        """处理模板中的转义花括号,将 \{\} 替换为临时标记"""  # type: ignore
         # 如果传入的是列表,将其转换为字符串
         if isinstance(template, list):
             template = "\n".join(str(item) for item in template)
@@ -195,13 +195,7 @@ class Prompt(str):
         obj._kwargs = kwargs
 
         # 修改自动注册逻辑
-        if should_register:
-            if global_prompt_manager._context._current_context:
-                # 如果存在当前上下文,则注册到上下文中
-                # asyncio.create_task(global_prompt_manager._context.register_async(obj))
-                pass
-            else:
-                # 否则注册到全局管理器
-                global_prompt_manager.register(obj)
+        if should_register and not global_prompt_manager._context._current_context:
+            global_prompt_manager.register(obj)
         return obj
@@ -276,15 +270,13 @@ class Prompt(str):
             self.name,
             args=list(args) if args else self._args,
             _should_register=False,
-            **kwargs if kwargs else self._kwargs,
+            **kwargs or self._kwargs,
         )
         # print(f"prompt build result: {ret} name: {ret.name} ")
         return str(ret)
 
     def __str__(self) -> str:
-        if self._kwargs or self._args:
-            return super().__str__()
-        return self.template
+        return super().__str__() if self._kwargs or self._args else self.template
 
     def __repr__(self) -> str:
         return f"Prompt(template='{self.template}', name='{self.name}')"

View File

@@ -1,18 +1,17 @@
-from collections import defaultdict
-from datetime import datetime, timedelta
-from typing import Any, Dict, Tuple, List
 import asyncio
 import concurrent.futures
 import json
 import os
 import glob
+from collections import defaultdict
+from datetime import datetime, timedelta
+from typing import Any, Dict, Tuple, List
 
 from src.common.logger import get_logger
+from src.common.database.database import db
+from src.common.database.database_model import OnlineTime, LLMUsage, Messages
 from src.manager.async_task_manager import AsyncTask
-from ...common.database.database import db  # This db is the Peewee database instance
-from ...common.database.database_model import OnlineTime, LLMUsage, Messages  # Import the Peewee model
 from src.manager.local_store_manager import local_storage
 
 logger = get_logger("maibot_statistic")
@@ -76,14 +75,14 @@ class OnlineTimeRecordTask(AsyncTask):
         with db.atomic():  # Use atomic operations for schema changes
             OnlineTime.create_table(safe=True)  # Creates table if it doesn't exist, Peewee handles indexes from model
 
-    async def run(self):
+    async def run(self):  # sourcery skip: use-named-expression
         try:
             current_time = datetime.now()
             extended_end_time = current_time + timedelta(minutes=1)
             if self.record_id:
                 # 如果有记录,则更新结束时间
-                query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
+                query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)  # type: ignore
                 updated_rows = query.execute()
                 if updated_rows == 0:
                     # Record might have been deleted or ID is stale, try to find/create
@@ -94,7 +93,7 @@ class OnlineTimeRecordTask(AsyncTask):
                 # Look for a record whose end_timestamp is recent enough to be considered ongoing
                 recent_record = (
                     OnlineTime.select()
-                    .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
+                    .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))  # type: ignore
                     .order_by(OnlineTime.end_timestamp.desc())
                     .first()
                 )
@@ -123,15 +122,15 @@ def _format_online_time(online_seconds: int) -> str:
     :param online_seconds: 在线时间(秒)
     :return: 格式化后的在线时间字符串
     """
-    total_oneline_time = timedelta(seconds=online_seconds)
-    days = total_oneline_time.days
-    hours = total_oneline_time.seconds // 3600
-    minutes = (total_oneline_time.seconds // 60) % 60
-    seconds = total_oneline_time.seconds % 60
+    total_online_time = timedelta(seconds=online_seconds)
+    days = total_online_time.days
+    hours = total_online_time.seconds // 3600
+    minutes = (total_online_time.seconds // 60) % 60
+    seconds = total_online_time.seconds % 60
     if days > 0:
         # 如果在线时间超过1天,则格式化为"X天X小时X分钟"
-        return f"{total_oneline_time.days}天{hours}小时{minutes}分钟{seconds}秒"
+        return f"{total_online_time.days}天{hours}小时{minutes}分钟{seconds}秒"
     elif hours > 0:
         # 如果在线时间超过1小时,则格式化为"X小时X分钟X秒"
         return f"{hours}小时{minutes}分钟{seconds}秒"
@@ -163,7 +162,7 @@ class StatisticOutputTask(AsyncTask):
         now = datetime.now()
         if "deploy_time" in local_storage:
             # 如果存在部署时间,则使用该时间作为全量统计的起始时间
-            deploy_time = datetime.fromtimestamp(local_storage["deploy_time"])
+            deploy_time = datetime.fromtimestamp(local_storage["deploy_time"])  # type: ignore
         else:
             # 否则,使用最大时间范围,并记录部署时间为当前时间
             deploy_time = datetime(2000, 1, 1)
@@ -252,7 +251,7 @@ class StatisticOutputTask(AsyncTask):
             # 创建后台任务,不等待完成
             collect_task = asyncio.create_task(
-                loop.run_in_executor(executor, self._collect_all_statistics, now)
+                loop.run_in_executor(executor, self._collect_all_statistics, now)  # type: ignore
             )
             stats = await collect_task
@@ -260,8 +259,8 @@ class StatisticOutputTask(AsyncTask):
             # 创建并发的输出任务
             output_tasks = [
-                asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),
-                asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),
+                asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),  # type: ignore
+                asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),  # type: ignore
             ]
 
             # 等待所有输出任务完成
@@ -320,7 +319,7 @@ class StatisticOutputTask(AsyncTask):
         # 以最早的时间戳为起始时间获取记录
         # Assuming LLMUsage.timestamp is a DateTimeField
         query_start_time = collect_period[-1][1]
-        for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
+        for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):  # type: ignore
             record_timestamp = record.timestamp  # This is already a datetime object
             for idx, (_, period_start) in enumerate(collect_period):
                 if record_timestamp >= period_start:
@@ -388,7 +387,7 @@ class StatisticOutputTask(AsyncTask):
         query_start_time = collect_period[-1][1]
         # Assuming OnlineTime.end_timestamp is a DateTimeField
-        for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
+        for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):  # type: ignore
             # record.end_timestamp and record.start_timestamp are datetime objects
             record_end_timestamp = record.end_timestamp
             record_start_timestamp = record.start_timestamp
@@ -428,7 +427,7 @@ class StatisticOutputTask(AsyncTask):
         }
 
         query_start_timestamp = collect_period[-1][1].timestamp()  # Messages.time is a DoubleField (timestamp)
-        for message in Messages.select().where(Messages.time >= query_start_timestamp):
+        for message in Messages.select().where(Messages.time >= query_start_timestamp):  # type: ignore
             message_time_ts = message.time  # This is a float timestamp
 
             chat_id = None
@@ -661,7 +660,7 @@ class StatisticOutputTask(AsyncTask):
if "last_full_statistics" in local_storage: if "last_full_statistics" in local_storage:
# 如果存在上次完整统计数据,则使用该数据进行增量统计 # 如果存在上次完整统计数据,则使用该数据进行增量统计
last_stat = local_storage["last_full_statistics"] # 上次完整统计数据 last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射 self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据 last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
@@ -727,6 +726,7 @@ class StatisticOutputTask(AsyncTask):
         return stat
 
     def _convert_defaultdict_to_dict(self, data):
+        # sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks
         """递归转换defaultdict为普通dict"""
         if isinstance(data, defaultdict):
             # 转换defaultdict为普通dict
@@ -812,8 +812,7 @@ class StatisticOutputTask(AsyncTask):
         # 全局阶段平均时间
         if stats[FOCUS_AVG_TIMES_BY_STAGE]:
             output.append("全局阶段平均时间:")
-            for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items():
-                output.append(f"  {stage}: {avg_time:.3f}秒")
+            output.extend(f"  {stage}: {avg_time:.3f}秒" for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items())
             output.append("")
 
         # Action类型比例
@@ -1050,7 +1049,7 @@ class StatisticOutputTask(AsyncTask):
         ]
 
         tab_content_list.append(
-            _format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"]))
+            _format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"]))  # type: ignore
         )
 
         # 添加Focus统计内容
@@ -1212,6 +1211,7 @@ class StatisticOutputTask(AsyncTask):
             f.write(html_template)
 
     def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
+        # sourcery skip: for-append-to-extend, list-comprehension, use-any
         """生成Focus统计独立分页的HTML内容"""
 
         # 为每个时间段准备Focus数据
@@ -1313,12 +1313,11 @@ class StatisticOutputTask(AsyncTask):
             # 聊天流Action选择比例对比表横向表格
             focus_chat_action_ratios_rows = ""
             if stat_data.get("focus_action_ratios_by_chat"):
-                # 获取所有action类型按全局频率排序
-                all_action_types_for_ratio = sorted(
-                    stat_data[FOCUS_ACTION_RATIOS].keys(), key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x], reverse=True
-                )
-
-                if all_action_types_for_ratio:
+                if all_action_types_for_ratio := sorted(
+                    stat_data[FOCUS_ACTION_RATIOS].keys(),
+                    key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x],
+                    reverse=True,
+                ):
                     # 为每个聊天流生成数据行(按循环数排序)
                     chat_ratio_rows = []
                     for chat_id in sorted(
@@ -1379,16 +1378,11 @@ class StatisticOutputTask(AsyncTask):
             if period_name == "all_time":
                 from src.manager.local_store_manager import local_storage
 
-                start_time = datetime.fromtimestamp(local_storage["deploy_time"])
-                time_range = (
-                    f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
-                )
+                start_time = datetime.fromtimestamp(local_storage["deploy_time"])  # type: ignore
             else:
                 start_time = datetime.now() - period_delta
-                time_range = (
-                    f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
-                )
+            time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
 
             # 生成该时间段的Focus统计HTML
             section_html = f"""
             <div class="focus-period-section">
@@ -1681,16 +1675,10 @@ class StatisticOutputTask(AsyncTask):
             if period_name == "all_time":
                 from src.manager.local_store_manager import local_storage
 
-                start_time = datetime.fromtimestamp(local_storage["deploy_time"])
-                time_range = (
-                    f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
-                )
+                start_time = datetime.fromtimestamp(local_storage["deploy_time"])  # type: ignore
             else:
                 start_time = datetime.now() - period_delta
-                time_range = (
-                    f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
-                )
+            time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
 
             # 生成该时间段的版本对比HTML
             section_html = f"""
             <div class="version-period-section">
@@ -1865,7 +1853,7 @@ class StatisticOutputTask(AsyncTask):
         # 查询LLM使用记录
         query_start_time = start_time
-        for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
+        for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):  # type: ignore
             record_time = record.timestamp
 
             # 找到对应的时间间隔索引
@@ -1875,7 +1863,7 @@ class StatisticOutputTask(AsyncTask):
             if 0 <= interval_index < len(time_points):
                 # 累加总花费数据
                 cost = record.cost or 0.0
-                total_cost_data[interval_index] += cost
+                total_cost_data[interval_index] += cost  # type: ignore
 
                 # 累加按模型分类的花费
                 model_name = record.model_name or "unknown"
@@ -1892,7 +1880,7 @@ class StatisticOutputTask(AsyncTask):
         # 查询消息记录
         query_start_timestamp = start_time.timestamp()
-        for message in Messages.select().where(Messages.time >= query_start_timestamp):
+        for message in Messages.select().where(Messages.time >= query_start_timestamp):  # type: ignore
             message_time_ts = message.time
 
             # 找到对应的时间间隔索引
@@ -1982,6 +1970,7 @@ class StatisticOutputTask(AsyncTask):
         }
 
     def _generate_chart_tab(self, chart_data: dict) -> str:
+        # sourcery skip: extract-duplicate-method, move-assign-in-block
         """生成图表选项卡HTML内容"""
 
         # 生成不同颜色的调色板
@@ -2293,7 +2282,7 @@ class AsyncStatisticOutputTask(AsyncTask):
             # 数据收集任务
             collect_task = asyncio.create_task(
-                loop.run_in_executor(executor, self._collect_all_statistics, now)
+                loop.run_in_executor(executor, self._collect_all_statistics, now)  # type: ignore
             )
             stats = await collect_task
@@ -2301,8 +2290,8 @@ class AsyncStatisticOutputTask(AsyncTask):
             # 创建并发的输出任务
             output_tasks = [
-                asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),
-                asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),
+                asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),  # type: ignore
+                asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),  # type: ignore
             ]
 
             # 等待所有输出任务完成

View File

@@ -1,7 +1,8 @@
+import asyncio
 from time import perf_counter
 from functools import wraps
 from typing import Optional, Dict, Callable
-import asyncio
 
 from rich.traceback import install
 
 install(extra_lines=3)
@@ -88,10 +89,10 @@ class Timer:
         self.name = name
         self.storage = storage
-        self.elapsed = None
+        self.elapsed: float = None  # type: ignore
         self.auto_unit = auto_unit
-        self.start = None
+        self.start: float = None  # type: ignore
 
     @staticmethod
     def _validate_types(name, storage):
@@ -120,7 +121,7 @@ class Timer:
                 return None
 
         wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
-        wrapper.__timer__ = self  # 保留计时器引用
+        wrapper.__timer__ = self  # 保留计时器引用  # type: ignore
         return wrapper
 
     def __enter__(self):

View File

@@ -7,10 +7,10 @@ import math
 import os
 import random
 import time
-import jieba
 from collections import defaultdict
 from pathlib import Path
 
+import jieba
 from pypinyin import Style, pinyin
 
 from src.common.logger import get_logger
@@ -104,7 +104,7 @@ class ChineseTypoGenerator:
         try:
             return "\u4e00" <= char <= "\u9fff"
         except Exception as e:
-            logger.debug(e)
+            logger.debug(str(e))
             return False
 
     def _get_pinyin(self, sentence):
@@ -138,7 +138,7 @@ class ChineseTypoGenerator:
             # 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
             if not py[-1].isdigit():
                 # 为非数字结尾的拼音添加数字声调1
-                return py + "1"
+                return f"{py}1"
             base = py[:-1]  # 去掉声调
             tone = int(py[-1])  # 获取声调

View File

@@ -1,23 +1,21 @@
 import random
 import re
 import time
+from collections import Counter
 
 import jieba
 import numpy as np
-from collections import Counter
 from maim_message import UserInfo
+from typing import Optional, Tuple, Dict
 
 from src.common.logger import get_logger
-
-# from src.mood.mood_manager import mood_manager
-from ..message_receive.message import MessageRecv
-from src.llm_models.utils_model import LLMRequest
-from .typo_generator import ChineseTypoGenerator
-from ...config.config import global_config
-from ...common.message_repository import find_messages, count_messages
-from typing import Optional, Tuple, Dict
+from src.common.message_repository import find_messages, count_messages
+from src.config.config import global_config
+from src.chat.message_receive.message import MessageRecv
 from src.chat.message_receive.chat_stream import get_chat_manager
+from src.llm_models.utils_model import LLMRequest
 from src.person_info.person_info import PersonInfoManager, get_person_info_manager
+from .typo_generator import ChineseTypoGenerator
 
 logger = get_logger("chat_utils")
@@ -31,11 +29,7 @@ def db_message_to_str(message_dict: dict) -> str:
logger.debug(f"message_dict: {message_dict}") logger.debug(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"])) time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try: try:
name = "[(%s)%s]%s" % ( name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
message_dict["user_id"],
message_dict.get("user_nickname", ""),
message_dict.get("user_cardname", ""),
)
except Exception: except Exception:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}" name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "") content = message_dict.get("processed_plain_text", "")
@@ -58,11 +52,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
         and message.message_info.additional_config.get("is_mentioned") is not None
     ):
         try:
-            reply_probability = float(message.message_info.additional_config.get("is_mentioned"))
+            reply_probability = float(message.message_info.additional_config.get("is_mentioned"))  # type: ignore
             is_mentioned = True
             return is_mentioned, reply_probability
         except Exception as e:
-            logger.warning(e)
+            logger.warning(str(e))
             logger.warning(
                 f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
             )
@@ -135,17 +129,14 @@ def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, c
     if not recent_messages:
         return []
 
-    message_detailed_plain_text = ""
-    message_detailed_plain_text_list = []
     # 反转消息列表,使最新的消息在最后
     recent_messages.reverse()
 
     if combine:
-        for msg_db_data in recent_messages:
-            message_detailed_plain_text += str(msg_db_data["detailed_plain_text"])
-        return message_detailed_plain_text
-    else:
+        return "".join(str(msg_db_data["detailed_plain_text"]) for msg_db_data in recent_messages)
+    message_detailed_plain_text_list = []
     for msg_db_data in recent_messages:
         message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
     return message_detailed_plain_text_list
@@ -204,10 +195,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
     len_text = len(text)
 
     if len_text < 3:
-        if random.random() < 0.01:
-            return list(text)  # 如果文本很短且触发随机条件,直接按字符分割
-        else:
-            return [text]
+        return list(text) if random.random() < 0.01 else [text]
 
     # 定义分隔符
     separators = {"", ",", " ", "", ";"}
@@ -352,8 +340,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
     max_length = global_config.response_splitter.max_length * 2
     max_sentence_num = global_config.response_splitter.max_sentence_num
     # 如果基本上是中文,则进行长度过滤
-    if get_western_ratio(cleaned_text) < 0.1:
-        if len(cleaned_text) > max_length:
-            logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
-            return ["懒得说"]
+    if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length:
+        logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
+        return ["懒得说"]
@@ -420,7 +407,7 @@ def calculate_typing_time(
     #     chinese_time *= 1 / typing_speed_multiplier
     #     english_time *= 1 / typing_speed_multiplier
     # 计算中文字符数
-    chinese_chars = sum(1 for char in input_string if "\u4e00" <= char <= "\u9fff")
+    chinese_chars = sum("\u4e00" <= char <= "\u9fff" for char in input_string)
 
     # 如果只有一个中文字符使用3倍时间
     if chinese_chars == 1 and len(input_string.strip()) == 1:
@@ -429,11 +416,7 @@ def calculate_typing_time(
     # 正常计算所有字符的输入时间
     total_time = 0.0
     for char in input_string:
-        if "\u4e00" <= char <= "\u9fff":  # 判断是否为中文字符
-            total_time += chinese_time
-        else:  # 其他字符(如英文)
-            total_time += english_time
+        total_time += chinese_time if "\u4e00" <= char <= "\u9fff" else english_time
 
     if is_emoji:
         total_time = 1
@@ -453,18 +436,14 @@ def cosine_similarity(v1, v2):
     dot_product = np.dot(v1, v2)
     norm1 = np.linalg.norm(v1)
     norm2 = np.linalg.norm(v2)
-    if norm1 == 0 or norm2 == 0:
-        return 0
-    return dot_product / (norm1 * norm2)
+    return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
 
 
 def text_to_vector(text):
     """将文本转换为词频向量"""
     # 分词
     words = jieba.lcut(text)
-    # 统计词频
-    word_freq = Counter(words)
-    return word_freq
+    return Counter(words)
 
 
 def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
@@ -491,9 +470,7 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
 def truncate_message(message: str, max_length=20) -> str:
     """截断消息,使其不超过指定长度"""
-    if len(message) > max_length:
-        return message[:max_length] + "..."
-    return message
+    return f"{message[:max_length]}..." if len(message) > max_length else message
 
 
 def protect_kaomoji(sentence):
@@ -522,7 +499,7 @@ def protect_kaomoji(sentence):
     placeholder_to_kaomoji = {}
 
     for idx, match in enumerate(kaomoji_matches):
-        kaomoji = match[0] if match[0] else match[1]
+        kaomoji = match[0] or match[1]
         placeholder = f"__KAOMOJI_{idx}__"
         sentence = sentence.replace(kaomoji, placeholder, 1)
         placeholder_to_kaomoji[placeholder] = kaomoji
@@ -563,7 +540,7 @@ def get_western_ratio(paragraph):
     if not alnum_chars:
         return 0.0
 
-    western_count = sum(1 for char in alnum_chars if is_english_letter(char))
+    western_count = sum(bool(is_english_letter(char)) for char in alnum_chars)
     return western_count / len(alnum_chars)
@@ -610,6 +587,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
 def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
+    # sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
     """将时间戳转换为人类可读的时间格式
 
     Args:
@@ -621,7 +599,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
""" """
if mode == "normal": if mode == "normal":
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
if mode == "normal_no_YMD": elif mode == "normal_no_YMD":
return time.strftime("%H:%M:%S", time.localtime(timestamp)) return time.strftime("%H:%M:%S", time.localtime(timestamp))
elif mode == "relative": elif mode == "relative":
now = time.time() now = time.time()
@@ -640,7 +618,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
         else:
             return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":"
     else:  # mode = "lite" or unknown
-        # 只返回时分秒格式,喵~
+        # 只返回时分秒格式
         return time.strftime("%H:%M:%S", time.localtime(timestamp))
@@ -670,8 +648,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
     elif chat_stream.user_info:  # It's a private chat
         is_group_chat = False
         user_info = chat_stream.user_info
-        platform = chat_stream.platform
-        user_id = user_info.user_id
+        platform: str = chat_stream.platform  # type: ignore
+        user_id: str = user_info.user_id  # type: ignore
 
         # Initialize target_info with basic info
         target_info = {

View File

@@ -3,21 +3,20 @@ import os
 import time
 import hashlib
 import uuid
+import io
+import asyncio
+import numpy as np
 from typing import Optional, Tuple
 from PIL import Image
-import io
-import numpy as np
-import asyncio
-from src.common.logger import get_logger
+from rich.traceback import install
 from src.common.database.database import db
 from src.common.database.database_model import Images, ImageDescriptions
 from src.config.config import global_config
 from src.llm_models.utils_model import LLMRequest
-from rich.traceback import install
+from src.common.logger import get_logger
 
 install(extra_lines=3)
 
 logger = get_logger("chat_image")
@@ -111,7 +110,7 @@ class ImageManager:
return f"[表情包,含义看起来是:{cached_description}]" return f"[表情包,含义看起来是:{cached_description}]"
# 调用AI获取描述 # 调用AI获取描述
if image_format == "gif" or image_format == "GIF": if image_format in ["gif", "GIF"]:
image_base64_processed = self.transform_gif(image_base64) image_base64_processed = self.transform_gif(image_base64)
if image_base64_processed is None: if image_base64_processed is None:
logger.warning("GIF转换失败无法获取描述") logger.warning("GIF转换失败无法获取描述")
@@ -258,6 +257,7 @@ class ImageManager:
     @staticmethod
     def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
+        # sourcery skip: use-contextlib-suppress
         """将GIF转换为水平拼接的静态图像, 跳过相似的帧
 
         Args:
@@ -351,7 +351,7 @@ class ImageManager:
             # 创建拼接图像
             total_width = target_width * len(resized_frames)
             # 防止总宽度为0
-            if total_width == 0 and len(resized_frames) > 0:
+            if total_width == 0 and resized_frames:
                 logger.warning("计算出的总宽度为0但有选中帧可能目标宽度太小")
                 # 至少给点宽度吧
                 total_width = len(resized_frames)
@@ -368,10 +368,7 @@ class ImageManager:
             # 转换为base64
             buffer = io.BytesIO()
             combined_image.save(buffer, format="JPEG", quality=85)  # 保存为JPEG
-            result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-            return result_base64
+            return base64.b64encode(buffer.getvalue()).decode("utf-8")
 
         except MemoryError:
             logger.error("GIF转换失败: 内存不足可能是GIF太大或帧数太多")
             return None  # 内存不够啦
@@ -380,6 +377,7 @@ class ImageManager:
             return None  # 其他错误也返回None
 
     async def process_image(self, image_base64: str) -> Tuple[str, str]:
+        # sourcery skip: hoist-if-from-if
         """处理图片并返回图片ID和描述
 
         Args:
@@ -422,14 +420,6 @@ class ImageManager:
                     existing_image.save()
                     return existing_image.image_id, f"[picid:{existing_image.image_id}]"
                 else:
-                    # print(f"图片已存在: {existing_image.image_id}")
-                    # print(f"图片描述: {existing_image.description}")
-                    # print(f"图片计数: {existing_image.count}")
-                    # 更新计数
-                    existing_image.count += 1
-                    existing_image.save()
-                    return existing_image.image_id, f"[picid:{existing_image.image_id}]"
-            else:
                     # print(f"图片不存在: {image_hash}")
                     image_id = str(uuid.uuid4())

View File

@@ -54,11 +54,11 @@ class DBWrapper:
         return getattr(get_db(), name)
 
     def __getitem__(self, key):
-        return get_db()[key]
+        return get_db()[key]  # type: ignore
 
 
 # 全局数据库访问点
-memory_db: Database = DBWrapper()
+memory_db: Database = DBWrapper()  # type: ignore
 
 # 定义数据库文件路径
 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))

View File

@@ -406,9 +406,7 @@ def initialize_database():
             existing_columns = {row[1] for row in cursor.fetchall()}
             model_fields = set(model._meta.fields.keys())
 
-            # 检查并添加缺失字段(原有逻辑)
-            missing_fields = model_fields - existing_columns
-            if missing_fields:
+            if missing_fields := model_fields - existing_columns:
                 logger.warning(f"'{table_name}' 缺失字段: {missing_fields}")
 
                 for field_name, field_obj in model._meta.fields.items():
@@ -424,10 +422,7 @@ def initialize_database():
"DateTimeField": "DATETIME", "DateTimeField": "DATETIME",
}.get(field_type, "TEXT") }.get(field_type, "TEXT")
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}" alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
if field_obj.null: alter_sql += " NULL" if field_obj.null else " NOT NULL"
alter_sql += " NULL"
else:
alter_sql += " NOT NULL"
if hasattr(field_obj, "default") and field_obj.default is not None: if hasattr(field_obj, "default") and field_obj.default is not None:
# 正确处理不同类型的默认值 # 正确处理不同类型的默认值
default_value = field_obj.default default_value = field_obj.default

View File

@@ -1,16 +1,16 @@
-import logging
 # 使用基于时间戳的文件处理器,简单的轮转份数限制
-from pathlib import Path
-from typing import Callable, Optional
+import logging
 import json
 import threading
 import time
-from datetime import datetime, timedelta
 
 import structlog
 import toml
+from pathlib import Path
+from typing import Callable, Optional
+from datetime import datetime, timedelta
 
 # 创建logs目录
 LOG_DIR = Path("logs")
 LOG_DIR.mkdir(exist_ok=True)
@@ -160,7 +160,7 @@ def close_handlers():
         _console_handler = None
 
 
-def remove_duplicate_handlers():
+def remove_duplicate_handlers():  # sourcery skip: for-append-to-extend, list-comprehension
     """移除重复的handler特别是文件handler"""
     root_logger = logging.getLogger()
# 读取日志配置 # 读取日志配置
def load_log_config(): def load_log_config(): # sourcery skip: use-contextlib-suppress
"""从配置文件加载日志设置""" """从配置文件加载日志设置"""
config_path = Path("config/bot_config.toml") config_path = Path("config/bot_config.toml")
default_config = { default_config = {
@@ -365,7 +365,7 @@ MODULE_COLORS = {
"component_registry": "\033[38;5;214m", # 橙黄色 "component_registry": "\033[38;5;214m", # 橙黄色
"stream_api": "\033[38;5;220m", # 黄色 "stream_api": "\033[38;5;220m", # 黄色
"config_api": "\033[38;5;226m", # 亮黄色 "config_api": "\033[38;5;226m", # 亮黄色
"hearflow_api": "\033[38;5;154m", # 黄绿色 "heartflow_api": "\033[38;5;154m", # 黄绿色
"action_apis": "\033[38;5;118m", # 绿色 "action_apis": "\033[38;5;118m", # 绿色
"independent_apis": "\033[38;5;82m", # 绿色 "independent_apis": "\033[38;5;82m", # 绿色
"llm_api": "\033[38;5;46m", # 亮绿色 "llm_api": "\033[38;5;46m", # 亮绿色
@@ -412,6 +412,7 @@ class ModuleColoredConsoleRenderer:
"""自定义控制台渲染器,为不同模块提供不同颜色""" """自定义控制台渲染器,为不同模块提供不同颜色"""
def __init__(self, colors=True): def __init__(self, colors=True):
# sourcery skip: merge-duplicate-blocks, remove-redundant-if
self._colors = colors self._colors = colors
self._config = LOG_CONFIG self._config = LOG_CONFIG
@@ -443,6 +444,7 @@ class ModuleColoredConsoleRenderer:
             self._enable_full_content_colors = False
 
     def __call__(self, logger, method_name, event_dict):
+        # sourcery skip: merge-duplicate-blocks
         """渲染日志消息"""
         # 获取基本信息
         timestamp = event_dict.get("timestamp", "")
@@ -662,7 +664,7 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
"""获取logger实例支持按名称绑定""" """获取logger实例支持按名称绑定"""
if name is None: if name is None:
return raw_logger return raw_logger
logger = binds.get(name) logger = binds.get(name) # type: ignore
if logger is None: if logger is None:
logger: structlog.stdlib.BoundLogger = structlog.get_logger(name).bind(logger_name=name) logger: structlog.stdlib.BoundLogger = structlog.get_logger(name).bind(logger_name=name)
binds[name] = logger binds[name] = logger
@@ -671,8 +673,8 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
 def configure_logging(
     level: str = "INFO",
-    console_level: str = None,
-    file_level: str = None,
+    console_level: Optional[str] = None,
+    file_level: Optional[str] = None,
     max_bytes: int = 5 * 1024 * 1024,
     backup_count: int = 30,
     log_dir: str = "logs",
@@ -729,14 +731,11 @@ def reload_log_config():
     global LOG_CONFIG
     LOG_CONFIG = load_log_config()
 
-    # 重新设置handler的日志级别
-    file_handler = get_file_handler()
-    if file_handler:
+    if file_handler := get_file_handler():
         file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
         file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO))
 
-    console_handler = get_console_handler()
-    if console_handler:
+    if console_handler := get_console_handler():
         console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
         console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO))
@@ -780,8 +779,7 @@ def set_console_log_level(level: str):
     global LOG_CONFIG
     LOG_CONFIG["console_log_level"] = level.upper()
 
-    console_handler = get_console_handler()
-    if console_handler:
+    if console_handler := get_console_handler():
         console_handler.setLevel(getattr(logging, level.upper(), logging.INFO))
 
     # 重新设置root logger级别
@@ -800,8 +798,7 @@ def set_file_log_level(level: str):
     global LOG_CONFIG
     LOG_CONFIG["file_log_level"] = level.upper()
 
-    file_handler = get_file_handler()
-    if file_handler:
+    if file_handler := get_file_handler():
         file_handler.setLevel(getattr(logging, level.upper(), logging.INFO))
 
     # 重新设置root logger级别
@@ -933,13 +930,12 @@ def format_json_for_logging(data, indent=2, ensure_ascii=False):
     Returns:
         str: 格式化后的JSON字符串
     """
-    if isinstance(data, str):
+    if not isinstance(data, str):
+        # 如果是对象,直接格式化
+        return json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)
     # 如果是JSON字符串先解析再格式化
     parsed_data = json.loads(data)
     return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii)
-    else:
-        # 如果是对象,直接格式化
-        return json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)
 
 
 def cleanup_old_logs():

View File

@@ -8,7 +8,7 @@ from src.config.config import global_config
 global_api = None
 
 
-def get_global_api() -> MessageServer:
+def get_global_api() -> MessageServer:  # sourcery skip: extract-method
     """获取全局MessageServer实例"""
     global global_api
     if global_api is None:
@@ -36,8 +36,7 @@ def get_global_api() -> MessageServer:
kwargs["custom_logger"] = maim_message_logger kwargs["custom_logger"] = maim_message_logger
# 添加token认证 # 添加token认证
if maim_message_config.auth_token: if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
if len(maim_message_config.auth_token) > 0:
kwargs["enable_token"] = True kwargs["enable_token"] = True
if maim_message_config.use_custom: if maim_message_config.use_custom:

View File

@@ -1,9 +1,11 @@
-from src.common.database.database_model import Messages  # 更改导入
-from src.common.logger import get_logger
 import traceback
 from typing import List, Any, Optional
 from peewee import Model  # 添加 Peewee Model 导入
+from src.common.database.database_model import Messages
+from src.common.logger import get_logger
 
 logger = get_logger(__name__)

View File

@@ -23,7 +23,7 @@ class TelemetryHeartBeatTask(AsyncTask):
         self.server_url = TELEMETRY_SERVER_URL
         """遥测服务地址"""
 
-        self.client_uuid = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None
+        self.client_uuid: str | None = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None  # type: ignore
         """客户端UUID"""
 
         self.info_dict = self._get_sys_info()
@@ -72,7 +72,7 @@ class TelemetryHeartBeatTask(AsyncTask):
                 timeout=aiohttp.ClientTimeout(total=5),  # 设置超时时间为5秒
             ) as response:
                 logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client")
-                logger.debug(local_storage["deploy_time"])
+                logger.debug(local_storage["deploy_time"])  # type: ignore
                 logger.debug(f"Response status: {response.status}")
 
                 if response.status == 200:
@@ -93,7 +93,7 @@ class TelemetryHeartBeatTask(AsyncTask):
         except Exception as e:
             import traceback
 
-            error_msg = str(e) if str(e) else "未知错误"
+            error_msg = str(e) or "未知错误"
             logger.warning(
                 f"请求UUID出错不过你还是可以正常使用麦麦: {type(e).__name__}: {error_msg}"
             )  # 可能是网络问题
@@ -114,11 +114,11 @@ class TelemetryHeartBeatTask(AsyncTask):
"""向服务器发送心跳""" """向服务器发送心跳"""
headers = { headers = {
"Client-UUID": self.client_uuid, "Client-UUID": self.client_uuid,
"User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}", "User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}", # type: ignore
} }
logger.debug(f"正在发送心跳到服务器: {self.server_url}") logger.debug(f"正在发送心跳到服务器: {self.server_url}")
logger.debug(headers) logger.debug(str(headers))
try: try:
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session: async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
@@ -151,7 +151,7 @@ class TelemetryHeartBeatTask(AsyncTask):
         except Exception as e:
             import traceback
 
-            error_msg = str(e) if str(e) else "未知错误"
+            error_msg = str(e) or "未知错误"
             logger.warning(f"(此消息不会影响正常使用)状态未发生: {type(e).__name__}: {error_msg}")
             logger.debug(f"完整错误信息: {traceback.format_exc()}")

View File

@@ -1,5 +1,6 @@
 import shutil
 import tomlkit
+from tomlkit.items import Table
 from pathlib import Path
 from datetime import datetime
@@ -45,8 +46,8 @@ def update_config():
# 检查version是否相同 # 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config: if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version: if old_version and new_version and old_version == new_version:
print(f"检测到版本号相同 (v{old_version}),跳过更新") print(f"检测到版本号相同 (v{old_version}),跳过更新")
# 如果version相同恢复旧配置文件并返回 # 如果version相同恢复旧配置文件并返回
@@ -62,7 +63,7 @@ def update_config():
if key == "version": if key == "version":
continue continue
if key in target: if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)): if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
update_dict(target[key], value) update_dict(target[key], value)
else: else:
try: try:
@@ -85,10 +86,7 @@ def update_config():
if value and isinstance(value[0], dict) and "regex" in value[0]: if value and isinstance(value[0], dict) and "regex" in value[0]:
contains_regex = True contains_regex = True
if contains_regex: target[key] = value if contains_regex else tomlkit.array(str(value))
target[key] = value
else:
target[key] = tomlkit.array(value)
else: else:
# 其他类型使用item方法创建新值 # 其他类型使用item方法创建新值
target[key] = tomlkit.item(value) target[key] = tomlkit.item(value)

View File

@@ -1,16 +1,14 @@
import os import os
from dataclasses import field, dataclass
import tomlkit import tomlkit
import shutil import shutil
from datetime import datetime
from datetime import datetime
from tomlkit import TOMLDocument from tomlkit import TOMLDocument
from tomlkit.items import Table from tomlkit.items import Table
from dataclasses import field, dataclass
from src.common.logger import get_logger
from rich.traceback import install from rich.traceback import install
from src.common.logger import get_logger
from src.config.config_base import ConfigBase from src.config.config_base import ConfigBase
from src.config.official_configs import ( from src.config.official_configs import (
BotConfig, BotConfig,
@@ -80,8 +78,8 @@ def update_config():
# 检查version是否相同 # 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config: if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version: if old_version and new_version and old_version == new_version:
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
return return
@@ -103,7 +101,7 @@ def update_config():
shutil.copy2(template_path, new_config_path) shutil.copy2(template_path, new_config_path)
logger.info(f"已创建新配置文件: {new_config_path}") logger.info(f"已创建新配置文件: {new_config_path}")
def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict): def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
""" """
将source字典的值更新到target字典中如果target中存在相同的键 将source字典的值更新到target字典中如果target中存在相同的键
""" """
@@ -112,8 +110,9 @@ def update_config():
if key == "version": if key == "version":
continue continue
if key in target: if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, Table)): target_value = target[key]
update_dict(target[key], value) if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
update_dict(target_value, value)
else: else:
try: try:
# 对数组类型进行特殊处理 # 对数组类型进行特殊处理
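The side-by-side rendering makes the rewritten update_dict hard to read in one pass, so here is a condensed restatement of the merge pattern the two hunks above apply (recursion over nested tables, version key left alone); the array special-casing shown earlier is deliberately omitted and the function name is illustrative.

from tomlkit import TOMLDocument
from tomlkit.items import Table

def merge_defaults(target: TOMLDocument | dict | Table, source: TOMLDocument | dict) -> None:
    """Copy keys from source into target, recursing into nested tables."""
    for key, value in source.items():
        if key == "version":
            continue  # the template's version stamp is handled separately
        target_value = target.get(key)
        if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
            merge_defaults(target_value, value)
        else:
            target[key] = value  # tomlkit containers wrap plain Python values on assignment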

View File

@@ -43,7 +43,7 @@ class ConfigBase:
field_type = f.type field_type = f.type
try: try:
init_args[field_name] = cls._convert_field(value, field_type) init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
except TypeError as e: except TypeError as e:
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
except Exception as e: except Exception as e:

View File

@@ -1,7 +1,8 @@
from dataclasses import dataclass, field
from typing import Any, Literal
import re import re
from dataclasses import dataclass, field
from typing import Any, Literal, Optional
from src.config.config_base import ConfigBase from src.config.config_base import ConfigBase
""" """
@@ -113,7 +114,7 @@ class ChatConfig(ConfigBase):
exit_focus_threshold: float = 1.0 exit_focus_threshold: float = 1.0
"""自动退出专注聊天的阈值,越低越容易退出专注聊天""" """自动退出专注聊天的阈值,越低越容易退出专注聊天"""
def get_current_talk_frequency(self, chat_stream_id: str = None) -> float: def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
""" """
根据当前时间和聊天流获取对应的 talk_frequency 根据当前时间和聊天流获取对应的 talk_frequency
@@ -138,7 +139,7 @@ class ChatConfig(ConfigBase):
# 如果都没有匹配,返回默认值 # 如果都没有匹配,返回默认值
return self.talk_frequency return self.talk_frequency
def _get_time_based_frequency(self, time_freq_list: list[str]) -> float: def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
""" """
根据时间配置列表获取当前时段的频率 根据时间配置列表获取当前时段的频率
@@ -186,7 +187,7 @@ class ChatConfig(ConfigBase):
return current_frequency return current_frequency
def _get_stream_specific_frequency(self, chat_stream_id: str) -> float: def _get_stream_specific_frequency(self, chat_stream_id: str):
""" """
获取特定聊天流在当前时间的频率 获取特定聊天流在当前时间的频率
@@ -217,7 +218,7 @@ class ChatConfig(ConfigBase):
return None return None
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> str: def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
""" """
解析流配置字符串并生成对应的 chat_id 解析流配置字符串并生成对应的 chat_id

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List, Optional
@dataclass @dataclass
@@ -8,7 +8,7 @@ class Identity:
identity_detail: List[str] # 身份细节描述 identity_detail: List[str] # 身份细节描述
def __init__(self, identity_detail: List[str] = None): def __init__(self, identity_detail: Optional[List[str]] = None):
"""初始化身份特征 """初始化身份特征
Args: Args:

View File

@@ -1,17 +1,18 @@
from typing import Optional
import ast import ast
from src.llm_models.utils_model import LLMRequest
from .personality import Personality
from .identity import Identity
import random import random
import json import json
import os import os
import hashlib import hashlib
from typing import Optional
from rich.traceback import install from rich.traceback import install
from src.common.logger import get_logger from src.common.logger import get_logger
from src.person_info.person_info import get_person_info_manager
from src.config.config import global_config from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import get_person_info_manager
from .personality import Personality
from .identity import Identity
install(extra_lines=3) install(extra_lines=3)
@@ -23,7 +24,7 @@ class Individuality:
def __init__(self): def __init__(self):
# 正常初始化实例属性 # 正常初始化实例属性
self.personality: Optional[Personality] = None self.personality: Personality = None # type: ignore
self.identity: Optional[Identity] = None self.identity: Optional[Identity] = None
self.name = "" self.name = ""
@@ -109,7 +110,7 @@ class Individuality:
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression") existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
if existing_short_impression: if existing_short_impression:
try: try:
existing_data = ast.literal_eval(existing_short_impression) existing_data = ast.literal_eval(existing_short_impression) # type: ignore
if isinstance(existing_data, list) and len(existing_data) >= 1: if isinstance(existing_data, list) and len(existing_data) >= 1:
personality_result = existing_data[0] personality_result = existing_data[0]
except (json.JSONDecodeError, TypeError, IndexError): except (json.JSONDecodeError, TypeError, IndexError):
@@ -128,7 +129,7 @@ class Individuality:
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression") existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
if existing_short_impression: if existing_short_impression:
try: try:
existing_data = ast.literal_eval(existing_short_impression) existing_data = ast.literal_eval(existing_short_impression) # type: ignore
if isinstance(existing_data, list) and len(existing_data) >= 2: if isinstance(existing_data, list) and len(existing_data) >= 2:
identity_result = existing_data[1] identity_result = existing_data[1]
except (json.JSONDecodeError, TypeError, IndexError): except (json.JSONDecodeError, TypeError, IndexError):
@@ -204,6 +205,7 @@ class Individuality:
return prompt_personality return prompt_personality
def get_identity_prompt(self, level: int, x_person: int = 2) -> str: def get_identity_prompt(self, level: int, x_person: int = 2) -> str:
# sourcery skip: assign-if-exp, merge-else-if-into-elif
""" """
获取身份特征的prompt 获取身份特征的prompt
@@ -240,13 +242,13 @@ class Individuality:
if identity_parts: if identity_parts:
details_str = "".join(identity_parts) details_str = "".join(identity_parts)
if x_person in [1, 2]: if x_person in {1, 2}:
return f"{i_pronoun}{details_str}" return f"{i_pronoun}{details_str}"
else: # x_person == 0 else: # x_person == 0
# 无人称时,直接返回细节,不加代词和开头的逗号 # 无人称时,直接返回细节,不加代词和开头的逗号
return f"{details_str}" return f"{details_str}"
else: else:
if x_person in [1, 2]: if x_person in {1, 2}:
return f"{i_pronoun}的身份信息不完整。" return f"{i_pronoun}的身份信息不完整。"
else: # x_person == 0 else: # x_person == 0
return "身份信息不完整。" return "身份信息不完整。"
@@ -441,14 +443,15 @@ class Individuality:
if info_list_json: if info_list_json:
try: try:
info_list = json.loads(info_list_json) if isinstance(info_list_json, str) else info_list_json info_list = json.loads(info_list_json) if isinstance(info_list_json, str) else info_list_json
for item in info_list: keywords.extend(
if isinstance(item, dict) and "info_type" in item: item["info_type"] for item in info_list if isinstance(item, dict) and "info_type" in item
keywords.append(item["info_type"]) )
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
logger.error(f"解析info_list失败: {info_list_json}") logger.error(f"解析info_list失败: {info_list_json}")
return keywords return keywords
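One detail worth flagging in the short_impression handling above: the stored value is read back with ast.literal_eval, but the except clause catches json.JSONDecodeError, TypeError and IndexError, while literal_eval raises ValueError or SyntaxError on malformed input, so a corrupted string would likely escape it. A small tolerant parser, as a sketch (the function name is illustrative):

import ast
from typing import Any, List

def parse_short_impression(raw: Any) -> List[Any]:
    """Best-effort parse of a stored short_impression value into a list."""
    if not raw:
        return []
    if not isinstance(raw, str):
        return raw if isinstance(raw, list) else []
    try:
        data = ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return []
    return data if isinstance(data, list) else []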
async def _create_personality(self, personality_core: str, personality_sides: list) -> str: async def _create_personality(self, personality_core: str, personality_sides: list) -> str:
# sourcery skip: merge-list-append, move-assign
"""使用LLM创建压缩版本的impression """使用LLM创建压缩版本的impression
Args: Args:

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Dict, List
import json import json
from dataclasses import dataclass
from typing import Dict, List, Optional
from pathlib import Path from pathlib import Path
@@ -24,7 +25,7 @@ class Personality:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def __init__(self, personality_core: str = "", personality_sides: List[str] = None): def __init__(self, personality_core: str = "", personality_sides: Optional[List[str]] = None):
if personality_sides is None: if personality_sides is None:
personality_sides = [] personality_sides = []
self.personality_core = personality_core self.personality_core = personality_core
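The constructor change above follows the standard idiom for list parameters: annotate as Optional, default to None, and build the list inside the body so instances never share one mutable default. A self-contained illustration (the class name is made up):

from typing import List, Optional

class Example:
    def __init__(self, sides: Optional[List[str]] = None):
        # A fresh list per instance; a literal `sides: List[str] = []` default
        # would be evaluated once and shared by every call that omits it.
        self.sides: List[str] = [] if sides is None else sides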
@@ -41,7 +42,7 @@ class Personality:
cls._instance = cls() cls._instance = cls()
return cls._instance return cls._instance
def _init_big_five_personality(self): def _init_big_five_personality(self): # sourcery skip: extract-method
"""初始化大五人格特质""" """初始化大五人格特质"""
# 构建文件路径 # 构建文件路径
personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per" personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per"
@@ -63,7 +64,6 @@ class Personality:
else: else:
self.extraversion = 0.3 self.extraversion = 0.3
self.neuroticism = 0.5 self.neuroticism = 0.5
if "认真" in self.personality_core or "负责" in self.personality_sides: if "认真" in self.personality_core or "负责" in self.personality_sides:
self.conscientiousness = 0.9 self.conscientiousness = 0.9
else: else:

View File

@@ -120,12 +120,7 @@ class AsyncTaskManager:
""" """
获取所有任务的状态 获取所有任务的状态
""" """
tasks_status = {} return {task_name: {"status": "done" if task.done() else "running"} for task_name, task in self.tasks.items()}
for task_name, task in self.tasks.items():
tasks_status[task_name] = {
"status": "running" if not task.done() else "done",
}
return tasks_status
async def stop_and_wait_all_tasks(self): async def stop_and_wait_all_tasks(self):
""" """

View File

@@ -2,12 +2,12 @@ import math
import random import random
import time import time
from src.chat.message_receive.message import MessageRecv from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from ..common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.llm_models.utils_model import LLMRequest
from src.manager.async_task_manager import AsyncTask, async_task_manager from src.manager.async_task_manager import AsyncTask, async_task_manager
logger = get_logger("mood") logger = get_logger("mood")
@@ -55,12 +55,12 @@ class ChatMood:
request_type="mood", request_type="mood",
) )
self.last_change_time = 0 self.last_change_time: float = 0
async def update_mood_by_message(self, message: MessageRecv, interested_rate: float): async def update_mood_by_message(self, message: MessageRecv, interested_rate: float):
self.regression_count = 0 self.regression_count = 0
during_last_time = message.message_info.time - self.last_change_time during_last_time = message.message_info.time - self.last_change_time # type: ignore
base_probability = 0.05 base_probability = 0.05
time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time)) time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time))
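For orientation, the multiplier above starts near zero for back-to-back messages and saturates toward 4 as the gap since the last mood change grows; how it combines with base_probability and interested_rate lies outside this hunk, so only the multiplier itself is evaluated here.

import math

def time_multiplier(during_last_time: float) -> float:
    return 4 * (1 - math.exp(-0.01 * during_last_time))

print(round(time_multiplier(60), 2))    # ≈ 1.8 after one minute
print(round(time_multiplier(300), 2))   # ≈ 3.8 after five minutes
print(round(time_multiplier(3600), 2))  # ≈ 4.0, effectively saturated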
@@ -78,7 +78,7 @@ class ChatMood:
if random.random() > update_probability: if random.random() > update_probability:
return return
message_time = message.message_info.time message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id, chat_id=self.chat_id,
timestamp_start=self.last_change_time, timestamp_start=self.last_change_time,
@@ -119,7 +119,7 @@ class ChatMood:
self.mood_state = response self.mood_state = response
self.last_change_time = message_time self.last_change_time = message_time # type: ignore
async def regress_mood(self): async def regress_mood(self):
message_time = time.time() message_time = time.time()

View File

@@ -1,17 +1,18 @@
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import PersonInfo # 新增导入
import copy import copy
import hashlib import hashlib
from typing import Any, Callable, Dict, Union
import datetime import datetime
import asyncio import asyncio
import json
from json_repair import repair_json
from typing import Any, Callable, Dict, Union, Optional
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
import json # 新增导入
from json_repair import repair_json
""" """
PersonInfoManager 类方法功能摘要: PersonInfoManager 类方法功能摘要:
@@ -42,7 +43,7 @@ person_info_default = {
"last_know": None, "last_know": None,
# "user_cardname": None, # This field is not in Peewee model PersonInfo # "user_cardname": None, # This field is not in Peewee model PersonInfo
# "user_avatar": None, # This field is not in Peewee model PersonInfo # "user_avatar": None, # This field is not in Peewee model PersonInfo
"impression": None, # Corrected from persion_impression "impression": None, # Corrected from person_impression
"short_impression": None, "short_impression": None,
"info_list": None, "info_list": None,
"points": None, "points": None,
@@ -106,27 +107,24 @@ class PersonInfoManager:
logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}") logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
return False return False
def get_person_id_by_person_name(self, person_name: str): def get_person_id_by_person_name(self, person_name: str) -> str:
"""根据用户名获取用户ID""" """根据用户名获取用户ID"""
try: try:
record = PersonInfo.get_or_none(PersonInfo.person_name == person_name) record = PersonInfo.get_or_none(PersonInfo.person_name == person_name)
if record: return record.person_id if record else ""
return record.person_id
else:
return ""
except Exception as e: except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}") logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
return "" return ""
@staticmethod @staticmethod
async def create_person_info(person_id: str, data: dict = None): async def create_person_info(person_id: str, data: Optional[dict] = None):
"""创建一个项""" """创建一个项"""
if not person_id: if not person_id:
logger.debug("创建失败personid不存在") logger.debug("创建失败person_id不存在")
return return
_person_info_default = copy.deepcopy(person_info_default) _person_info_default = copy.deepcopy(person_info_default)
model_fields = PersonInfo._meta.fields.keys() model_fields = PersonInfo._meta.fields.keys() # type: ignore
final_data = {"person_id": person_id} final_data = {"person_id": person_id}
@@ -163,9 +161,9 @@ class PersonInfoManager:
await asyncio.to_thread(_db_create_sync, final_data) await asyncio.to_thread(_db_create_sync, final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None): async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
"""更新某一个字段,会补全""" """更新某一个字段,会补全"""
if field_name not in PersonInfo._meta.fields: if field_name not in PersonInfo._meta.fields: # type: ignore
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。") logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
return return
@@ -228,15 +226,13 @@ class PersonInfoManager:
@staticmethod @staticmethod
async def has_one_field(person_id: str, field_name: str): async def has_one_field(person_id: str, field_name: str):
"""判断是否存在某一个字段""" """判断是否存在某一个字段"""
if field_name not in PersonInfo._meta.fields: if field_name not in PersonInfo._meta.fields: # type: ignore
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。") logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
return False return False
def _db_has_field_sync(p_id: str, f_name: str): def _db_has_field_sync(p_id: str, f_name: str):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record: return bool(record)
return True
return False
try: try:
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name) return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
@@ -435,9 +431,7 @@ class PersonInfoManager:
except Exception as e: except Exception as e:
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}") logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
# Fallback to default in case of any error during DB access # Fallback to default in case of any error during DB access
if field_name in person_info_default: return default_value_for_field if field_name in person_info_default else None
return default_value_for_field
return None
@staticmethod @staticmethod
def get_value_sync(person_id: str, field_name: str): def get_value_sync(person_id: str, field_name: str):
@@ -446,8 +440,7 @@ class PersonInfoManager:
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None: if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = [] default_value_for_field = []
record = PersonInfo.get_or_none(PersonInfo.person_id == person_id) if record := PersonInfo.get_or_none(PersonInfo.person_id == person_id):
if record:
val = getattr(record, field_name, None) val = getattr(record, field_name, None)
if field_name in JSON_SERIALIZED_FIELDS: if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str): if isinstance(val, str):
@@ -481,7 +474,7 @@ class PersonInfoManager:
record = await asyncio.to_thread(_db_get_record_sync, person_id) record = await asyncio.to_thread(_db_get_record_sync, person_id)
for field_name in field_names: for field_name in field_names:
if field_name not in PersonInfo._meta.fields: if field_name not in PersonInfo._meta.fields: # type: ignore
if field_name in person_info_default: if field_name in person_info_default:
result[field_name] = copy.deepcopy(person_info_default[field_name]) result[field_name] = copy.deepcopy(person_info_default[field_name])
logger.debug(f"字段'{field_name}'不在Peewee模型中使用默认配置值。") logger.debug(f"字段'{field_name}'不在Peewee模型中使用默认配置值。")
@@ -509,7 +502,7 @@ class PersonInfoManager:
""" """
获取满足条件的字段值字典 获取满足条件的字段值字典
""" """
if field_name not in PersonInfo._meta.fields: if field_name not in PersonInfo._meta.fields: # type: ignore
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义") logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义")
return {} return {}
@@ -531,7 +524,7 @@ class PersonInfoManager:
return {} return {}
async def get_or_create_person( async def get_or_create_person(
self, platform: str, user_id: int, nickname: str = None, user_cardname: str = None, user_avatar: str = None self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None
) -> str: ) -> str:
""" """
根据 platform 和 user_id 获取 person_id。 根据 platform 和 user_id 获取 person_id。
@@ -561,7 +554,7 @@ class PersonInfoManager:
"points": [], "points": [],
"forgotten_points": [], "forgotten_points": [],
} }
model_fields = PersonInfo._meta.fields.keys() model_fields = PersonInfo._meta.fields.keys() # type: ignore
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
await self.create_person_info(person_id, data=filtered_initial_data) await self.create_person_info(person_id, data=filtered_initial_data)
@@ -610,7 +603,9 @@ class PersonInfoManager:
"name_reason", "name_reason",
] ]
valid_fields_to_get = [ valid_fields_to_get = [
f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default f
for f in required_fields
if f in PersonInfo._meta.fields or f in person_info_default # type: ignore
] ]
person_data = await self.get_values(found_person_id, valid_fields_to_get) person_data = await self.get_values(found_person_id, valid_fields_to_get)
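The file above consistently pushes synchronous Peewee queries off the event loop with asyncio.to_thread; a minimal sketch of that pattern, assuming the PersonInfo model import used in this module (the helper name is illustrative):

import asyncio
from typing import Optional

from src.common.database.database_model import PersonInfo  # same import as this module

async def fetch_person_name(person_id: str) -> Optional[str]:
    """Run the blocking Peewee lookup in a worker thread."""
    def _query() -> Optional[str]:
        record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
        return record.person_name if record else None

    return await asyncio.to_thread(_query)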

View File

@@ -3,12 +3,12 @@ import traceback
import os import os
import pickle import pickle
import random import random
from typing import List, Dict from typing import List, Dict, Any
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.relationship_manager import get_relationship_manager from src.person_info.relationship_manager import get_relationship_manager
from src.person_info.person_info import get_person_info_manager, PersonInfoManager from src.person_info.person_info import get_person_info_manager, PersonInfoManager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat, get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive, get_raw_msg_by_timestamp_with_chat_inclusive,
@@ -45,7 +45,7 @@ class RelationshipBuilder:
self.chat_id = chat_id self.chat_id = chat_id
# 新的消息段缓存结构: # 新的消息段缓存结构:
# {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]} # {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]}
self.person_engaged_cache: Dict[str, List[Dict[str, any]]] = {} self.person_engaged_cache: Dict[str, List[Dict[str, Any]]] = {}
# 持久化存储文件路径 # 持久化存储文件路径
self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl") self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl")
@@ -210,11 +210,7 @@ class RelationshipBuilder:
if person_id not in self.person_engaged_cache: if person_id not in self.person_engaged_cache:
return 0 return 0
total_count = 0 return sum(segment["message_count"] for segment in self.person_engaged_cache[person_id])
for segment in self.person_engaged_cache[person_id]:
total_count += segment["message_count"]
return total_count
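To make the sum() rewrite above easier to follow, here is the per-person segment layout this file documents near its top, with made-up sample values:

from typing import Any, Dict, List

person_engaged_cache: Dict[str, List[Dict[str, Any]]] = {
    "person_123": [
        {"start_time": 1_720_000_000.0, "end_time": 1_720_000_300.0,
         "last_msg_time": 1_720_000_290.0, "message_count": 7},
        {"start_time": 1_720_003_600.0, "end_time": 1_720_003_900.0,
         "last_msg_time": 1_720_003_880.0, "message_count": 3},
    ],
}

def total_message_count(person_id: str) -> int:
    return sum(seg["message_count"] for seg in person_engaged_cache.get(person_id, []))

assert total_message_count("person_123") == 10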
def _cleanup_old_segments(self) -> bool: def _cleanup_old_segments(self) -> bool:
"""清理老旧的消息段""" """清理老旧的消息段"""
@@ -289,7 +285,7 @@ class RelationshipBuilder:
self.last_cleanup_time = current_time self.last_cleanup_time = current_time
# 保存缓存 # 保存缓存
if cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0: if cleanup_stats["segments_removed"] > 0 or users_to_remove:
self._save_cache() self._save_cache()
logger.info( logger.info(
f"{self.log_prefix} 清理完成 - 影响用户: {cleanup_stats['users_cleaned']}, 移除消息段: {cleanup_stats['segments_removed']}, 移除用户: {len(users_to_remove)}" f"{self.log_prefix} 清理完成 - 影响用户: {cleanup_stats['users_cleaned']}, 移除消息段: {cleanup_stats['segments_removed']}, 移除用户: {len(users_to_remove)}"
@@ -313,6 +309,7 @@ class RelationshipBuilder:
return False return False
def get_cache_status(self) -> str: def get_cache_status(self) -> str:
# sourcery skip: merge-list-append, merge-list-appends-into-extend
"""获取缓存状态信息,用于调试和监控""" """获取缓存状态信息,用于调试和监控"""
if not self.person_engaged_cache: if not self.person_engaged_cache:
return f"{self.log_prefix} 关系缓存为空" return f"{self.log_prefix} 关系缓存为空"
@@ -357,13 +354,12 @@ class RelationshipBuilder:
self._cleanup_old_segments() self._cleanup_old_segments()
current_time = time.time() current_time = time.time()
latest_messages = get_raw_msg_by_timestamp_with_chat( if latest_messages := get_raw_msg_by_timestamp_with_chat(
self.chat_id, self.chat_id,
self.last_processed_message_time, self.last_processed_message_time,
current_time, current_time,
limit=50, # 获取自上次处理后的消息 limit=50, # 获取自上次处理后的消息
) ):
if latest_messages:
# 处理所有新的非bot消息 # 处理所有新的非bot消息
for latest_msg in latest_messages: for latest_msg in latest_messages:
user_id = latest_msg.get("user_id") user_id = latest_msg.get("user_id")
@@ -414,7 +410,7 @@ class RelationshipBuilder:
# 负责触发关系构建、整合消息段、更新用户印象 # 负责触发关系构建、整合消息段、更新用户印象
# ================================ # ================================
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, any]]): async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]):
"""基于消息段更新用户印象""" """基于消息段更新用户印象"""
original_segment_count = len(segments) original_segment_count = len(segments)
logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象") logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象")

View File

@@ -1,4 +1,5 @@
from typing import Dict, Optional, List from typing import Dict, Optional, List, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from .relationship_builder import RelationshipBuilder from .relationship_builder import RelationshipBuilder
@@ -63,7 +64,7 @@ class RelationshipBuilderManager:
""" """
return list(self.builders.keys()) return list(self.builders.keys())
def get_status(self) -> Dict[str, any]: def get_status(self) -> Dict[str, Any]:
"""获取管理器状态 """获取管理器状态
Returns: Returns:
@@ -94,9 +95,7 @@ class RelationshipBuilderManager:
bool: 是否成功清理 bool: 是否成功清理
""" """
builder = self.get_builder(chat_id) builder = self.get_builder(chat_id)
if builder: return builder.force_cleanup_user_segments(person_id) if builder else False
return builder.force_cleanup_user_segments(person_id)
return False
# 全局管理器实例 # 全局管理器实例

View File

@@ -1,16 +1,19 @@
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
import time import time
import traceback import traceback
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.person_info.person_info import get_person_info_manager
from typing import List, Dict
from json_repair import repair_json
from src.chat.message_receive.chat_stream import get_chat_manager
import json import json
import random import random
from typing import List, Dict, Any
from json_repair import repair_json
from src.common.logger import get_logger
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
logger = get_logger("relationship_fetcher") logger = get_logger("relationship_fetcher")
@@ -62,11 +65,11 @@ class RelationshipFetcher:
self.chat_id = chat_id self.chat_id = chat_id
# 信息获取缓存:记录正在获取的信息请求 # 信息获取缓存:记录正在获取的信息请求
self.info_fetching_cache: List[Dict[str, any]] = [] self.info_fetching_cache: List[Dict[str, Any]] = []
# 信息结果缓存存储已获取的信息结果带TTL # 信息结果缓存存储已获取的信息结果带TTL
self.info_fetched_cache: Dict[str, Dict[str, any]] = {} self.info_fetched_cache: Dict[str, Dict[str, Any]] = {}
# 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknow": bool}}} # 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}}
# LLM模型配置 # LLM模型配置
self.llm_model = LLMRequest( self.llm_model = LLMRequest(
@@ -184,7 +187,7 @@ class RelationshipFetcher:
nickname_str = ",".join(global_config.bot.alias_names) nickname_str = ",".join(global_config.bot.alias_names)
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。" name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
person_name = await person_info_manager.get_value(person_id, "person_name") person_name: str = await person_info_manager.get_value(person_id, "person_name") # type: ignore
info_cache_block = self._build_info_cache_block() info_cache_block = self._build_info_cache_block()
@@ -208,8 +211,7 @@ class RelationshipFetcher:
logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息{content_json.get('none', '')}") logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息{content_json.get('none', '')}")
return None return None
info_type = content_json.get("info_type") if info_type := content_json.get("info_type"):
if info_type:
# 记录信息获取请求 # 记录信息获取请求
self.info_fetching_cache.append( self.info_fetching_cache.append(
{ {
@@ -287,7 +289,7 @@ class RelationshipFetcher:
"ttl": 2, "ttl": 2,
"start_time": start_time, "start_time": start_time,
"person_name": person_name, "person_name": person_name,
"unknow": cached_info == "none", "unknown": cached_info == "none",
} }
logger.info(f"{self.log_prefix} 记得 {person_name}{info_type}: {cached_info}") logger.info(f"{self.log_prefix} 记得 {person_name}{info_type}: {cached_info}")
return return
@@ -321,7 +323,7 @@ class RelationshipFetcher:
"ttl": 2, "ttl": 2,
"start_time": start_time, "start_time": start_time,
"person_name": person_name, "person_name": person_name,
"unknow": True, "unknown": True,
} }
logger.info(f"{self.log_prefix} 完全不认识 {person_name}") logger.info(f"{self.log_prefix} 完全不认识 {person_name}")
await self._save_info_to_cache(person_id, info_type, "none") await self._save_info_to_cache(person_id, info_type, "none")
@@ -353,15 +355,15 @@ class RelationshipFetcher:
if person_id not in self.info_fetched_cache: if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {} self.info_fetched_cache[person_id] = {}
self.info_fetched_cache[person_id][info_type] = { self.info_fetched_cache[person_id][info_type] = {
"info": "unknow" if is_unknown else info_content, "info": "unknown" if is_unknown else info_content,
"ttl": 3, "ttl": 3,
"start_time": start_time, "start_time": start_time,
"person_name": person_name, "person_name": person_name,
"unknow": is_unknown, "unknown": is_unknown,
} }
# 保存到持久化缓存 (info_list) # 保存到持久化缓存 (info_list)
await self._save_info_to_cache(person_id, info_type, info_content if not is_unknown else "none") await self._save_info_to_cache(person_id, info_type, "none" if is_unknown else info_content)
if not is_unknown: if not is_unknown:
logger.info(f"{self.log_prefix} 思考得到,{person_name}{info_type}: {info_content}") logger.info(f"{self.log_prefix} 思考得到,{person_name}{info_type}: {info_content}")
@@ -393,7 +395,7 @@ class RelationshipFetcher:
for info_type in self.info_fetched_cache[person_id]: for info_type in self.info_fetched_cache[person_id]:
person_name = self.info_fetched_cache[person_id][info_type]["person_name"] person_name = self.info_fetched_cache[person_id][info_type]["person_name"]
if not self.info_fetched_cache[person_id][info_type]["unknow"]: if not self.info_fetched_cache[person_id][info_type]["unknown"]:
info_content = self.info_fetched_cache[person_id][info_type]["info"] info_content = self.info_fetched_cache[person_id][info_type]["info"]
person_known_infos.append(f"[{info_type}]{info_content}") person_known_infos.append(f"[{info_type}]{info_content}")
else: else:
@@ -430,6 +432,7 @@ class RelationshipFetcher:
return persons_infos_str return persons_infos_str
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str): async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
# sourcery skip: use-next
"""将提取到的信息保存到 person_info 的 info_list 字段中 """将提取到的信息保存到 person_info 的 info_list 字段中
Args: Args:

View File

@@ -1,5 +1,5 @@
from src.common.logger import get_logger from src.common.logger import get_logger
from src.person_info.person_info import PersonInfoManager, get_person_info_manager from .person_info import PersonInfoManager, get_person_info_manager
import time import time
import random import random
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
@@ -12,7 +12,7 @@ from difflib import SequenceMatcher
import jieba import jieba
from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Dict, Any
logger = get_logger("relation") logger = get_logger("relation")
@@ -28,8 +28,7 @@ class RelationshipManager:
async def is_known_some_one(platform, user_id): async def is_known_some_one(platform, user_id):
"""判断是否认识某人""" """判断是否认识某人"""
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
is_known = await person_info_manager.is_person_known(platform, user_id) return await person_info_manager.is_person_known(platform, user_id)
return is_known
@staticmethod @staticmethod
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str): async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
@@ -110,7 +109,7 @@ class RelationshipManager:
return relation_prompt return relation_prompt
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages=None): async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
"""更新用户印象 """更新用户印象
Args: Args:
@@ -123,7 +122,7 @@ class RelationshipManager:
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
person_name = await person_info_manager.get_value(person_id, "person_name") person_name = await person_info_manager.get_value(person_id, "person_name")
nickname = await person_info_manager.get_value(person_id, "nickname") nickname = await person_info_manager.get_value(person_id, "nickname")
know_times = await person_info_manager.get_value(person_id, "know_times") or 0 know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore
alias_str = ", ".join(global_config.bot.alias_names) alias_str = ", ".join(global_config.bot.alias_names)
# personality_block =get_individuality().get_personality_prompt(x_person=2, level=2) # personality_block =get_individuality().get_personality_prompt(x_person=2, level=2)
@@ -142,13 +141,13 @@ class RelationshipManager:
# 遍历消息,构建映射 # 遍历消息,构建映射
for msg in user_messages: for msg in user_messages:
await person_info_manager.get_or_create_person( await person_info_manager.get_or_create_person(
platform=msg.get("chat_info_platform"), platform=msg.get("chat_info_platform"), # type: ignore
user_id=msg.get("user_id"), user_id=msg.get("user_id"), # type: ignore
nickname=msg.get("user_nickname"), nickname=msg.get("user_nickname"), # type: ignore
user_cardname=msg.get("user_cardname"), user_cardname=msg.get("user_cardname"), # type: ignore
) )
replace_user_id = msg.get("user_id") replace_user_id: str = msg.get("user_id") # type: ignore
replace_platform = msg.get("chat_info_platform") replace_platform: str = msg.get("chat_info_platform") # type: ignore
replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id) replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id)
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name") replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
@@ -354,8 +353,8 @@ class RelationshipManager:
person_name = await person_info_manager.get_value(person_id, "person_name") person_name = await person_info_manager.get_value(person_id, "person_name")
nickname = await person_info_manager.get_value(person_id, "nickname") nickname = await person_info_manager.get_value(person_id, "nickname")
know_times = await person_info_manager.get_value(person_id, "know_times") or 0 know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore
attitude = await person_info_manager.get_value(person_id, "attitude") or 50 attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore
# 根据熟悉度,调整印象和简短印象的最大长度 # 根据熟悉度,调整印象和简短印象的最大长度
if know_times > 300: if know_times > 300:
@@ -414,9 +413,7 @@ class RelationshipManager:
if len(remaining_points) < 10: if len(remaining_points) < 10:
# 如果还没达到30条直接保留 # 如果还没达到30条直接保留
remaining_points.append(point) remaining_points.append(point)
else: elif random.random() < keep_probability:
# 随机决定是否保留
if random.random() < keep_probability:
# 保留这个点,随机移除一个已保留的点 # 保留这个点,随机移除一个已保留的点
idx_to_remove = random.randrange(len(remaining_points)) idx_to_remove = random.randrange(len(remaining_points))
points_to_move.append(remaining_points[idx_to_remove]) points_to_move.append(remaining_points[idx_to_remove])
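The branch rewritten above keeps the first few points unconditionally, then only probabilistically admits new ones while evicting a random survivor; the surrounding bookkeeping continues past this hunk, so the following is an approximation of the idea rather than the author's exact logic (all names are illustrative).

import random
from typing import Any, List, Tuple

def retain_points(points: List[Any], keep_probability: float, capacity: int = 10) -> Tuple[List[Any], List[Any]]:
    kept: List[Any] = []
    moved: List[Any] = []  # roughly what the original moves into forgotten points
    for point in points:
        if len(kept) < capacity:
            kept.append(point)
        elif random.random() < keep_probability:
            idx = random.randrange(len(kept))
            moved.append(kept[idx])  # evict a random survivor to make room
            kept[idx] = point
        else:
            moved.append(point)
    return kept, moved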
@@ -520,7 +517,7 @@ class RelationshipManager:
new_attitude = int(relation_value_json.get("attitude", 50)) new_attitude = int(relation_value_json.get("attitude", 50))
# 获取当前的关系值 # 获取当前的关系值
old_attitude = await person_info_manager.get_value(person_id, "attitude") or 50 old_attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore
# 更新熟悉度 # 更新熟悉度
if new_attitude > 25: if new_attitude > 25:

View File

@@ -65,9 +65,9 @@ def get_replyer(
async def generate_reply( async def generate_reply(
chat_stream=None, chat_stream: Optional[ChatStream] = None,
chat_id: str = None, chat_id: Optional[str] = None,
action_data: Dict[str, Any] = None, action_data: Optional[Dict[str, Any]] = None,
reply_to: str = "", reply_to: str = "",
extra_info: str = "", extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None, available_actions: Optional[Dict[str, ActionInfo]] = None,
@@ -78,25 +78,25 @@ async def generate_reply(
model_configs: Optional[List[Dict[str, Any]]] = None, model_configs: Optional[List[Dict[str, Any]]] = None,
request_type: str = "", request_type: str = "",
enable_timeout: bool = False, enable_timeout: bool = False,
) -> Tuple[bool, List[Tuple[str, Any]]]: ) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""生成回复 """生成回复
Args: Args:
chat_stream: 聊天流对象(优先) chat_stream: 聊天流对象(优先)
action_data: 动作数据
chat_id: 聊天ID备用 chat_id: 聊天ID备用
action_data: 动作数据
enable_splitter: 是否启用消息分割器 enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器 enable_chinese_typo: 是否启用错字生成器
return_prompt: 是否返回提示词 return_prompt: 是否返回提示词
Returns: Returns:
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合) Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
""" """
try: try:
# 获取回复器 # 获取回复器
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type) replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type)
if not replyer: if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器") logger.error("[GeneratorAPI] 无法获取回复器")
return False, [] return False, [], None
logger.debug("[GeneratorAPI] 开始生成回复") logger.debug("[GeneratorAPI] 开始生成回复")
@@ -109,7 +109,8 @@ async def generate_reply(
enable_timeout=enable_timeout, enable_timeout=enable_timeout,
enable_tool=enable_tool, enable_tool=enable_tool,
) )
reply_set = []
if content:
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo) reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
if success: if success:
@@ -118,19 +119,19 @@ async def generate_reply(
logger.warning("[GeneratorAPI] 回复生成失败") logger.warning("[GeneratorAPI] 回复生成失败")
if return_prompt: if return_prompt:
return success, reply_set or [], prompt return success, reply_set, prompt
else: else:
return success, reply_set or [] return success, reply_set, None
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}") logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
return False, [] return False, [], None
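With the widened return type above, callers of generate_reply always unpack three values; a hedged usage sketch (the chat stream and message text are placeholders, and the shape of each reply_set entry simply follows the List[Tuple[str, Any]] annotation):

async def demo(chat_stream) -> None:
    # assumes generate_reply is imported from this module and chat_stream is a live ChatStream
    success, reply_set, prompt = await generate_reply(
        chat_stream=chat_stream,
        reply_to="某人:今天吃什么?",
        return_prompt=True,
    )
    if success:
        for reply_type, reply_payload in reply_set:
            print(reply_type, reply_payload)
    else:
        print("generation failed; prompt for debugging:", prompt)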
async def rewrite_reply( async def rewrite_reply(
chat_stream=None, chat_stream: Optional[ChatStream] = None,
reply_data: Dict[str, Any] = None, reply_data: Optional[Dict[str, Any]] = None,
chat_id: str = None, chat_id: Optional[str] = None,
enable_splitter: bool = True, enable_splitter: bool = True,
enable_chinese_typo: bool = True, enable_chinese_typo: bool = True,
model_configs: Optional[List[Dict[str, Any]]] = None, model_configs: Optional[List[Dict[str, Any]]] = None,
@@ -158,7 +159,8 @@ async def rewrite_reply(
# 调用回复器重写回复 # 调用回复器重写回复
success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {}) success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {})
reply_set = []
if content:
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo) reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
if success: if success:
@@ -166,7 +168,7 @@ async def rewrite_reply(
else: else:
logger.warning("[GeneratorAPI] 重写回复失败") logger.warning("[GeneratorAPI] 重写回复失败")
return success, reply_set or [] return success, reply_set
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}") logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")

View File

@@ -34,7 +34,7 @@ class ToolExecutor:
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。 可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
""" """
def __init__(self, chat_id: str = None, enable_cache: bool = True, cache_ttl: int = 3): def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
"""初始化工具执行器 """初始化工具执行器
Args: Args:
@@ -62,8 +62,8 @@ class ToolExecutor:
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'}TTL={cache_ttl}") logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'}TTL={cache_ttl}")
async def execute_from_chat_message( async def execute_from_chat_message(
self, target_message: str, chat_history: list[str], sender: str, return_details: bool = False self, target_message: str, chat_history: str, sender: str, return_details: bool = False
) -> List[Dict] | Tuple[List[Dict], List[str], str]: ) -> Tuple[List[Dict], List[str], str]:
"""从聊天消息执行工具 """从聊天消息执行工具
Args: Args:
@@ -79,16 +79,14 @@ class ToolExecutor:
# 首先检查缓存 # 首先检查缓存
cache_key = self._generate_cache_key(target_message, chat_history, sender) cache_key = self._generate_cache_key(target_message, chat_history, sender)
cached_result = self._get_from_cache(cache_key) if cached_result := self._get_from_cache(cache_key):
if cached_result:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行") logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
if return_details: if not return_details:
return cached_result, [], "使用缓存结果"
# 从缓存结果中提取工具名称 # 从缓存结果中提取工具名称
used_tools = [result.get("tool_name", "unknown") for result in cached_result] used_tools = [result.get("tool_name", "unknown") for result in cached_result]
return cached_result, used_tools, "使用缓存结果" return cached_result, used_tools, "使用缓存结果"
else:
return cached_result
# 缓存未命中,执行工具调用 # 缓存未命中,执行工具调用
# 获取可用工具 # 获取可用工具
@@ -134,7 +132,7 @@ class ToolExecutor:
if return_details: if return_details:
return tool_results, used_tools, prompt return tool_results, used_tools, prompt
else: else:
return tool_results return tool_results, [], ""
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]: async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
"""执行工具调用 """执行工具调用
@@ -207,7 +205,7 @@ class ToolExecutor:
return tool_results, used_tools return tool_results, used_tools
def _generate_cache_key(self, target_message: str, chat_history: list[str], sender: str) -> str: def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键 """生成缓存键
Args: Args:
@@ -267,10 +265,7 @@ class ToolExecutor:
return return
expired_keys = [] expired_keys = []
for cache_key, cache_item in self.tool_cache.items(): expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
if cache_item["ttl"] <= 0:
expired_keys.append(cache_key)
for key in expired_keys: for key in expired_keys:
del self.tool_cache[key] del self.tool_cache[key]
@@ -355,7 +350,7 @@ class ToolExecutor:
"ttl_distribution": ttl_distribution, "ttl_distribution": ttl_distribution,
} }
def set_cache_config(self, enable_cache: bool = None, cache_ttl: int = None): def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
"""动态修改缓存配置 """动态修改缓存配置
Args: Args:
@@ -366,7 +361,7 @@ class ToolExecutor:
self.enable_cache = enable_cache self.enable_cache = enable_cache
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}") logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
if cache_ttl is not None and cache_ttl > 0: if cache_ttl > 0:
self.cache_ttl = cache_ttl self.cache_ttl = cache_ttl
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}") logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
@@ -380,7 +375,7 @@ init_tool_executor_prompt()
# 1. 基础使用 - 从聊天消息执行工具启用缓存默认TTL=3 # 1. 基础使用 - 从聊天消息执行工具启用缓存默认TTL=3
executor = ToolExecutor(executor_id="my_executor") executor = ToolExecutor(executor_id="my_executor")
results = await executor.execute_from_chat_message( results, _, _ = await executor.execute_from_chat_message(
talking_message_str="今天天气怎么样?现在几点了?", talking_message_str="今天天气怎么样?现在几点了?",
is_group_chat=False is_group_chat=False
) )
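One loose end worth noting: the module-level usage example above still constructs ToolExecutor(executor_id="my_executor") and passes talking_message_str/is_group_chat, which no longer match the updated signature in this diff (a required chat_id, target_message/chat_history/sender parameters, and an always-three-tuple return). A call matching the new signature would look roughly like this, with placeholder values:

async def demo_tool_executor() -> None:
    executor = ToolExecutor(chat_id="chat_42", enable_cache=True, cache_ttl=3)
    results, used_tools, prompt = await executor.execute_from_chat_message(
        target_message="今天天气怎么样?现在几点了?",
        chat_history="",   # the updated annotation expects a str
        sender="小明",
        return_details=True,
    )
    for result in results:
        print(result.get("tool_name", "unknown"), result)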