完成所有类型注解的修复

This commit is contained in:
UnCLAS-Prommer
2025-07-13 00:19:54 +08:00
parent d2ad6ea1d8
commit 7ef0bfb7c8
32 changed files with 358 additions and 434 deletions

View File

@@ -4,6 +4,7 @@ import asyncio
import random
import ast
import re
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
@@ -161,13 +162,13 @@ class DefaultReplyer:
async def generate_reply_with_context(
self,
reply_data: Dict[str, Any] = None,
reply_data: Optional[Dict[str, Any]] = None,
reply_to: str = "",
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_tool: bool = True,
enable_timeout: bool = False,
) -> Tuple[bool, Optional[str]]:
) -> Tuple[bool, Optional[str], Optional[str]]:
"""
回复器 (Replier): 核心逻辑,负责生成回复文本。
(已整合原 HeartFCGenerator 的功能)
@@ -225,14 +226,14 @@ class DefaultReplyer:
except Exception as llm_e:
# 精简报错信息
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
return False, None # LLM 调用失败则无法生成回复
return False, None, prompt # LLM 调用失败则无法生成回复
return True, content, prompt
except Exception as e:
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
traceback.print_exc()
return False, None
return False, None, prompt
async def rewrite_reply_with_context(
self,
@@ -368,7 +369,7 @@ class DefaultReplyer:
memory_str += f"- {running_memory['content']}\n"
return memory_str
async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True):
async def build_tool_info(self, chat_history, reply_data: Optional[Dict], enable_tool: bool = True):
"""构建工具信息块
Args:
@@ -393,7 +394,7 @@ class DefaultReplyer:
try:
# 使用工具执行器获取信息
tool_results = await self.tool_executor.execute_from_chat_message(
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
sender=sender, target_message=text, chat_history=chat_history, return_details=False
)
@@ -468,7 +469,7 @@ class DefaultReplyer:
async def build_prompt_reply_context(
self,
reply_data=None,
reply_data: Dict[str, Any],
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_timeout: bool = False,
enable_tool: bool = True,
@@ -549,7 +550,7 @@ class DefaultReplyer:
),
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_half, target), "build_memory_block"),
self._time_and_run_task(
self.build_tool_info(reply_data, chat_talking_prompt_half, enable_tool=enable_tool), "build_tool_info"
self.build_tool_info(chat_talking_prompt_half, reply_data, enable_tool=enable_tool), "build_tool_info"
),
)
@@ -806,7 +807,7 @@ class DefaultReplyer:
response_set: List[Tuple[str, str]],
thinking_id: str = "",
display_message: str = "",
) -> Optional[MessageSending]:
) -> Optional[List[Tuple[str, bool]]]:
# sourcery skip: assign-if-exp, boolean-if-exp-identity, remove-unnecessary-cast
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
chat = self.chat_stream
@@ -869,7 +870,7 @@ class DefaultReplyer:
try:
if (
bot_message.is_private_message()
or bot_message.reply.processed_plain_text != "[System Trigger Context]"
or bot_message.reply.processed_plain_text != "[System Trigger Context]" # type: ignore
or mark_head
):
set_reply = False
@@ -910,7 +911,7 @@ class DefaultReplyer:
is_emoji: bool,
thinking_start_time: float,
display_message: str,
anchor_message: MessageRecv = None,
anchor_message: Optional[MessageRecv] = None,
) -> MessageSending:
"""构建单个发送消息"""

View File

@@ -1,8 +1,8 @@
from typing import Dict, Any, Optional, List
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer
from src.common.logger import get_logger
logger = get_logger("ReplyerManager")

View File

@@ -1,6 +1,7 @@
import time # 导入 time 模块以获取当前时间
import random
import re
from typing import List, Dict, Any, Tuple, Optional
from rich.traceback import install
@@ -88,8 +89,8 @@ def get_actions_by_timestamp_with_chat(
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time > timestamp_start)
& (ActionRecords.time < timestamp_end)
& (ActionRecords.time > timestamp_start) # type: ignore
& (ActionRecords.time < timestamp_end) # type: ignore
)
if limit > 0:
@@ -113,8 +114,8 @@ def get_actions_by_timestamp_with_chat_inclusive(
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time >= timestamp_start)
& (ActionRecords.time <= timestamp_end)
& (ActionRecords.time >= timestamp_start) # type: ignore
& (ActionRecords.time <= timestamp_end) # type: ignore
)
if limit > 0:
@@ -331,7 +332,7 @@ def _build_readable_messages_internal(
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = person_info_manager.get_value_sync(person_id, "person_name")
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
if not person_name:
@@ -911,8 +912,8 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
person_ids_set = set() # 使用集合来自动去重
for msg in messages:
platform = msg.get("user_platform")
user_id = msg.get("user_id")
platform: str = msg.get("user_platform") # type: ignore
user_id: str = msg.get("user_id") # type: ignore
# 检查必要信息是否存在 且 不是机器人自己
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:

View File

@@ -1,7 +1,8 @@
import ast
import json
import logging
from typing import Any, Dict, TypeVar, List, Union, Tuple
import ast
from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional
# 定义类型变量用于泛型类型提示
T = TypeVar("T")
@@ -30,18 +31,14 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
# 尝试标准的 JSON 解析
return json.loads(json_str)
except json.JSONDecodeError:
# 如果标准解析失败,尝试将单引号替换为双引号再解析
# (注意:这种替换可能不安全,如果字符串内容本身包含引号)
# 更安全的方式是用 ast.literal_eval
# 如果标准解析失败,尝试用 ast.literal_eval 解析
try:
# logger.debug(f"标准JSON解析失败尝试用 ast.literal_eval 解析: {json_str[:100]}...")
result = ast.literal_eval(json_str)
# 确保结果是字典(因为我们通常期望参数是字典)
if isinstance(result, dict):
return result
else:
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
return default_value
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
return default_value
except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...")
return default_value
@@ -53,7 +50,9 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
return default_value
def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]:
def extract_tool_call_arguments(
tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
从LLM工具调用对象中提取参数
@@ -77,14 +76,12 @@ def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[s
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
return default_result
# 提取arguments
arguments_str = function_data.get("arguments", "{}")
if not arguments_str:
if arguments_str := function_data.get("arguments", "{}"):
# 解析JSON
return safe_json_loads(arguments_str, default_result)
else:
return default_result
# 解析JSON
return safe_json_loads(arguments_str, default_result)
except Exception as e:
logger.error(f"提取工具调用参数时出错: {e}")
return default_result

View File

@@ -1,12 +1,12 @@
from typing import Dict, Any, Optional, List, Union
import re
from contextlib import asynccontextmanager
import asyncio
import contextvars
from src.common.logger import get_logger
# import traceback
from rich.traceback import install
from contextlib import asynccontextmanager
from typing import Dict, Any, Optional, List, Union
from src.common.logger import get_logger
install(extra_lines=3)
@@ -32,6 +32,7 @@ class PromptContext:
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
# sourcery skip: hoist-statement-from-if, use-contextlib-suppress
"""创建一个异步的临时提示模板作用域"""
# 保存当前上下文并设置新上下文
if context_id is not None:
@@ -88,8 +89,7 @@ class PromptContext:
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
target_context = context_id or self._current_context
if target_context:
if target_context := context_id or self._current_context:
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
@@ -151,7 +151,7 @@ class Prompt(str):
@staticmethod
def _process_escaped_braces(template) -> str:
"""处理模板中的转义花括号,将 \{\} 替换为临时标记"""
"""处理模板中的转义花括号,将 \{\} 替换为临时标记""" # type: ignore
# 如果传入的是列表,将其转换为字符串
if isinstance(template, list):
template = "\n".join(str(item) for item in template)
@@ -195,14 +195,8 @@ class Prompt(str):
obj._kwargs = kwargs
# 修改自动注册逻辑
if should_register:
if global_prompt_manager._context._current_context:
# 如果存在当前上下文,则注册到上下文中
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
pass
else:
# 否则注册到全局管理器
global_prompt_manager.register(obj)
if should_register and not global_prompt_manager._context._current_context:
global_prompt_manager.register(obj)
return obj
@classmethod
@@ -276,15 +270,13 @@ class Prompt(str):
self.name,
args=list(args) if args else self._args,
_should_register=False,
**kwargs if kwargs else self._kwargs,
**kwargs or self._kwargs,
)
# print(f"prompt build result: {ret} name: {ret.name} ")
return str(ret)
def __str__(self) -> str:
if self._kwargs or self._args:
return super().__str__()
return self.template
return super().__str__() if self._kwargs or self._args else self.template
def __repr__(self) -> str:
return f"Prompt(template='{self.template}', name='{self.name}')"

View File

@@ -1,18 +1,17 @@
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
import asyncio
import concurrent.futures
import json
import os
import glob
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
from src.manager.async_task_manager import AsyncTask
from ...common.database.database import db # This db is the Peewee database instance
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic")
@@ -76,14 +75,14 @@ class OnlineTimeRecordTask(AsyncTask):
with db.atomic(): # Use atomic operations for schema changes
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
async def run(self):
async def run(self): # sourcery skip: use-named-expression
try:
current_time = datetime.now()
extended_end_time = current_time + timedelta(minutes=1)
if self.record_id:
# 如果有记录,则更新结束时间
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore
updated_rows = query.execute()
if updated_rows == 0:
# Record might have been deleted or ID is stale, try to find/create
@@ -94,7 +93,7 @@ class OnlineTimeRecordTask(AsyncTask):
# Look for a record whose end_timestamp is recent enough to be considered ongoing
recent_record = (
OnlineTime.select()
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore
.order_by(OnlineTime.end_timestamp.desc())
.first()
)
@@ -123,15 +122,15 @@ def _format_online_time(online_seconds: int) -> str:
:param online_seconds: 在线时间(秒)
:return: 格式化后的在线时间字符串
"""
total_oneline_time = timedelta(seconds=online_seconds)
total_online_time = timedelta(seconds=online_seconds)
days = total_oneline_time.days
hours = total_oneline_time.seconds // 3600
minutes = (total_oneline_time.seconds // 60) % 60
seconds = total_oneline_time.seconds % 60
days = total_online_time.days
hours = total_online_time.seconds // 3600
minutes = (total_online_time.seconds // 60) % 60
seconds = total_online_time.seconds % 60
if days > 0:
# 如果在线时间超过1天则格式化为"X天X小时X分钟"
return f"{total_oneline_time.days}{hours}小时{minutes}分钟{seconds}"
return f"{total_online_time.days}{hours}小时{minutes}分钟{seconds}"
elif hours > 0:
# 如果在线时间超过1小时则格式化为"X小时X分钟X秒"
return f"{hours}小时{minutes}分钟{seconds}"
@@ -163,7 +162,7 @@ class StatisticOutputTask(AsyncTask):
now = datetime.now()
if "deploy_time" in local_storage:
# 如果存在部署时间,则使用该时间作为全量统计的起始时间
deploy_time = datetime.fromtimestamp(local_storage["deploy_time"])
deploy_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
else:
# 否则,使用最大时间范围,并记录部署时间为当前时间
deploy_time = datetime(2000, 1, 1)
@@ -252,7 +251,7 @@ class StatisticOutputTask(AsyncTask):
# 创建后台任务,不等待完成
collect_task = asyncio.create_task(
loop.run_in_executor(executor, self._collect_all_statistics, now)
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
)
stats = await collect_task
@@ -260,8 +259,8 @@ class StatisticOutputTask(AsyncTask):
# 创建并发的输出任务
output_tasks = [
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
]
# 等待所有输出任务完成
@@ -320,7 +319,7 @@ class StatisticOutputTask(AsyncTask):
# 以最早的时间戳为起始时间获取记录
# Assuming LLMUsage.timestamp is a DateTimeField
query_start_time = collect_period[-1][1]
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_timestamp = record.timestamp # This is already a datetime object
for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
@@ -388,7 +387,7 @@ class StatisticOutputTask(AsyncTask):
query_start_time = collect_period[-1][1]
# Assuming OnlineTime.end_timestamp is a DateTimeField
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore
# record.end_timestamp and record.start_timestamp are datetime objects
record_end_timestamp = record.end_timestamp
record_start_timestamp = record.start_timestamp
@@ -428,7 +427,7 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp):
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time # This is a float timestamp
chat_id = None
@@ -661,7 +660,7 @@ class StatisticOutputTask(AsyncTask):
if "last_full_statistics" in local_storage:
# 如果存在上次完整统计数据,则使用该数据进行增量统计
last_stat = local_storage["last_full_statistics"] # 上次完整统计数据
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
@@ -727,6 +726,7 @@ class StatisticOutputTask(AsyncTask):
return stat
def _convert_defaultdict_to_dict(self, data):
# sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks
"""递归转换defaultdict为普通dict"""
if isinstance(data, defaultdict):
# 转换defaultdict为普通dict
@@ -812,8 +812,7 @@ class StatisticOutputTask(AsyncTask):
# 全局阶段平均时间
if stats[FOCUS_AVG_TIMES_BY_STAGE]:
output.append("全局阶段平均时间:")
for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items():
output.append(f" {stage}: {avg_time:.3f}")
output.extend(f" {stage}: {avg_time:.3f}" for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items())
output.append("")
# Action类型比例
@@ -1050,7 +1049,7 @@ class StatisticOutputTask(AsyncTask):
]
tab_content_list.append(
_format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"]))
_format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"])) # type: ignore
)
# 添加Focus统计内容
@@ -1212,6 +1211,7 @@ class StatisticOutputTask(AsyncTask):
f.write(html_template)
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
# sourcery skip: for-append-to-extend, list-comprehension, use-any
"""生成Focus统计独立分页的HTML内容"""
# 为每个时间段准备Focus数据
@@ -1313,12 +1313,11 @@ class StatisticOutputTask(AsyncTask):
# 聊天流Action选择比例对比表横向表格
focus_chat_action_ratios_rows = ""
if stat_data.get("focus_action_ratios_by_chat"):
# 获取所有action类型按全局频率排序
all_action_types_for_ratio = sorted(
stat_data[FOCUS_ACTION_RATIOS].keys(), key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x], reverse=True
)
if all_action_types_for_ratio:
if all_action_types_for_ratio := sorted(
stat_data[FOCUS_ACTION_RATIOS].keys(),
key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x],
reverse=True,
):
# 为每个聊天流生成数据行(按循环数排序)
chat_ratio_rows = []
for chat_id in sorted(
@@ -1379,16 +1378,11 @@ class StatisticOutputTask(AsyncTask):
if period_name == "all_time":
from src.manager.local_store_manager import local_storage
start_time = datetime.fromtimestamp(local_storage["deploy_time"])
time_range = (
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)
start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
else:
start_time = datetime.now() - period_delta
time_range = (
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)
time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# 生成该时间段的Focus统计HTML
section_html = f"""
<div class="focus-period-section">
@@ -1681,16 +1675,10 @@ class StatisticOutputTask(AsyncTask):
if period_name == "all_time":
from src.manager.local_store_manager import local_storage
start_time = datetime.fromtimestamp(local_storage["deploy_time"])
time_range = (
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)
start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
else:
start_time = datetime.now() - period_delta
time_range = (
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)
time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# 生成该时间段的版本对比HTML
section_html = f"""
<div class="version-period-section">
@@ -1865,7 +1853,7 @@ class StatisticOutputTask(AsyncTask):
# 查询LLM使用记录
query_start_time = start_time
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_time = record.timestamp
# 找到对应的时间间隔索引
@@ -1875,7 +1863,7 @@ class StatisticOutputTask(AsyncTask):
if 0 <= interval_index < len(time_points):
# 累加总花费数据
cost = record.cost or 0.0
total_cost_data[interval_index] += cost
total_cost_data[interval_index] += cost # type: ignore
# 累加按模型分类的花费
model_name = record.model_name or "unknown"
@@ -1892,7 +1880,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
for message in Messages.select().where(Messages.time >= query_start_timestamp):
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time
# 找到对应的时间间隔索引
@@ -1982,6 +1970,7 @@ class StatisticOutputTask(AsyncTask):
}
def _generate_chart_tab(self, chart_data: dict) -> str:
# sourcery skip: extract-duplicate-method, move-assign-in-block
"""生成图表选项卡HTML内容"""
# 生成不同颜色的调色板
@@ -2293,7 +2282,7 @@ class AsyncStatisticOutputTask(AsyncTask):
# 数据收集任务
collect_task = asyncio.create_task(
loop.run_in_executor(executor, self._collect_all_statistics, now)
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
)
stats = await collect_task
@@ -2301,8 +2290,8 @@ class AsyncStatisticOutputTask(AsyncTask):
# 创建并发的输出任务
output_tasks = [
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
]
# 等待所有输出任务完成

View File

@@ -1,7 +1,8 @@
import asyncio
from time import perf_counter
from functools import wraps
from typing import Optional, Dict, Callable
import asyncio
from rich.traceback import install
install(extra_lines=3)
@@ -88,10 +89,10 @@ class Timer:
self.name = name
self.storage = storage
self.elapsed = None
self.elapsed: float = None # type: ignore
self.auto_unit = auto_unit
self.start = None
self.start: float = None # type: ignore
@staticmethod
def _validate_types(name, storage):
@@ -120,7 +121,7 @@ class Timer:
return None
wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
wrapper.__timer__ = self # 保留计时器引用
wrapper.__timer__ = self # 保留计时器引用 # type: ignore
return wrapper
def __enter__(self):

View File

@@ -7,10 +7,10 @@ import math
import os
import random
import time
import jieba
from collections import defaultdict
from pathlib import Path
import jieba
from pypinyin import Style, pinyin
from src.common.logger import get_logger
@@ -104,7 +104,7 @@ class ChineseTypoGenerator:
try:
return "\u4e00" <= char <= "\u9fff"
except Exception as e:
logger.debug(e)
logger.debug(str(e))
return False
def _get_pinyin(self, sentence):
@@ -138,7 +138,7 @@ class ChineseTypoGenerator:
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
if not py[-1].isdigit():
# 为非数字结尾的拼音添加数字声调1
return py + "1"
return f"{py}1"
base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调

View File

@@ -1,23 +1,21 @@
import random
import re
import time
from collections import Counter
import jieba
import numpy as np
from collections import Counter
from maim_message import UserInfo
from typing import Optional, Tuple, Dict
from src.common.logger import get_logger
# from src.mood.mood_manager import mood_manager
from ..message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator
from ...config.config import global_config
from ...common.message_repository import find_messages, count_messages
from typing import Optional, Tuple, Dict
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from .typo_generator import ChineseTypoGenerator
logger = get_logger("chat_utils")
@@ -31,11 +29,7 @@ def db_message_to_str(message_dict: dict) -> str:
logger.debug(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try:
name = "[(%s)%s]%s" % (
message_dict["user_id"],
message_dict.get("user_nickname", ""),
message_dict.get("user_cardname", ""),
)
name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
except Exception:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "")
@@ -58,11 +52,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
and message.message_info.additional_config.get("is_mentioned") is not None
):
try:
reply_probability = float(message.message_info.additional_config.get("is_mentioned"))
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
is_mentioned = True
return is_mentioned, reply_probability
except Exception as e:
logger.warning(e)
logger.warning(str(e))
logger.warning(
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
)
@@ -135,20 +129,17 @@ def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, c
if not recent_messages:
return []
message_detailed_plain_text = ""
message_detailed_plain_text_list = []
# 反转消息列表,使最新的消息在最后
recent_messages.reverse()
if combine:
for msg_db_data in recent_messages:
message_detailed_plain_text += str(msg_db_data["detailed_plain_text"])
return message_detailed_plain_text
else:
for msg_db_data in recent_messages:
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
return message_detailed_plain_text_list
return "".join(str(msg_db_data["detailed_plain_text"]) for msg_db_data in recent_messages)
message_detailed_plain_text_list = []
for msg_db_data in recent_messages:
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
return message_detailed_plain_text_list
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
@@ -204,10 +195,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
len_text = len(text)
if len_text < 3:
if random.random() < 0.01:
return list(text) # 如果文本很短且触发随机条件,直接按字符分割
else:
return [text]
return list(text) if random.random() < 0.01 else [text]
# 定义分隔符
separators = {"", ",", " ", "", ";"}
@@ -352,10 +340,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
max_length = global_config.response_splitter.max_length * 2
max_sentence_num = global_config.response_splitter.max_sentence_num
# 如果基本上是中文,则进行长度过滤
if get_western_ratio(cleaned_text) < 0.1:
if len(cleaned_text) > max_length:
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
return ["懒得说"]
if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length:
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
return ["懒得说"]
typo_generator = ChineseTypoGenerator(
error_rate=global_config.chinese_typo.error_rate,
@@ -420,7 +407,7 @@ def calculate_typing_time(
# chinese_time *= 1 / typing_speed_multiplier
# english_time *= 1 / typing_speed_multiplier
# 计算中文字符数
chinese_chars = sum(1 for char in input_string if "\u4e00" <= char <= "\u9fff")
chinese_chars = sum("\u4e00" <= char <= "\u9fff" for char in input_string)
# 如果只有一个中文字符使用3倍时间
if chinese_chars == 1 and len(input_string.strip()) == 1:
@@ -429,11 +416,7 @@ def calculate_typing_time(
# 正常计算所有字符的输入时间
total_time = 0.0
for char in input_string:
if "\u4e00" <= char <= "\u9fff": # 判断是否为中文字符
total_time += chinese_time
else: # 其他字符(如英文)
total_time += english_time
total_time += chinese_time if "\u4e00" <= char <= "\u9fff" else english_time
if is_emoji:
total_time = 1
@@ -453,18 +436,14 @@ def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2)
if norm1 == 0 or norm2 == 0:
return 0
return dot_product / (norm1 * norm2)
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
def text_to_vector(text):
"""将文本转换为词频向量"""
# 分词
words = jieba.lcut(text)
# 统计词频
word_freq = Counter(words)
return word_freq
return Counter(words)
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
@@ -491,9 +470,7 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
def truncate_message(message: str, max_length=20) -> str:
"""截断消息,使其不超过指定长度"""
if len(message) > max_length:
return message[:max_length] + "..."
return message
return f"{message[:max_length]}..." if len(message) > max_length else message
def protect_kaomoji(sentence):
@@ -522,7 +499,7 @@ def protect_kaomoji(sentence):
placeholder_to_kaomoji = {}
for idx, match in enumerate(kaomoji_matches):
kaomoji = match[0] if match[0] else match[1]
kaomoji = match[0] or match[1]
placeholder = f"__KAOMOJI_{idx}__"
sentence = sentence.replace(kaomoji, placeholder, 1)
placeholder_to_kaomoji[placeholder] = kaomoji
@@ -563,7 +540,7 @@ def get_western_ratio(paragraph):
if not alnum_chars:
return 0.0
western_count = sum(1 for char in alnum_chars if is_english_letter(char))
western_count = sum(bool(is_english_letter(char)) for char in alnum_chars)
return western_count / len(alnum_chars)
@@ -610,6 +587,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
"""将时间戳转换为人类可读的时间格式
Args:
@@ -621,7 +599,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
"""
if mode == "normal":
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
if mode == "normal_no_YMD":
elif mode == "normal_no_YMD":
return time.strftime("%H:%M:%S", time.localtime(timestamp))
elif mode == "relative":
now = time.time()
@@ -640,7 +618,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
else:
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":"
else: # mode = "lite" or unknown
# 只返回时分秒格式,喵~
# 只返回时分秒格式
return time.strftime("%H:%M:%S", time.localtime(timestamp))
@@ -670,8 +648,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
elif chat_stream.user_info: # It's a private chat
is_group_chat = False
user_info = chat_stream.user_info
platform = chat_stream.platform
user_id = user_info.user_id
platform: str = chat_stream.platform # type: ignore
user_id: str = user_info.user_id # type: ignore
# Initialize target_info with basic info
target_info = {

View File

@@ -3,21 +3,20 @@ import os
import time
import hashlib
import uuid
import io
import asyncio
import numpy as np
from typing import Optional, Tuple
from PIL import Image
import io
import numpy as np
import asyncio
from rich.traceback import install
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_image")
@@ -111,7 +110,7 @@ class ImageManager:
return f"[表情包,含义看起来是:{cached_description}]"
# 调用AI获取描述
if image_format == "gif" or image_format == "GIF":
if image_format in ["gif", "GIF"]:
image_base64_processed = self.transform_gif(image_base64)
if image_base64_processed is None:
logger.warning("GIF转换失败无法获取描述")
@@ -258,6 +257,7 @@ class ImageManager:
@staticmethod
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
# sourcery skip: use-contextlib-suppress
"""将GIF转换为水平拼接的静态图像, 跳过相似的帧
Args:
@@ -351,7 +351,7 @@ class ImageManager:
# 创建拼接图像
total_width = target_width * len(resized_frames)
# 防止总宽度为0
if total_width == 0 and len(resized_frames) > 0:
if total_width == 0 and resized_frames:
logger.warning("计算出的总宽度为0但有选中帧可能目标宽度太小")
# 至少给点宽度吧
total_width = len(resized_frames)
@@ -368,10 +368,7 @@ class ImageManager:
# 转换为base64
buffer = io.BytesIO()
combined_image.save(buffer, format="JPEG", quality=85) # 保存为JPEG
result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
return result_base64
return base64.b64encode(buffer.getvalue()).decode("utf-8")
except MemoryError:
logger.error("GIF转换失败: 内存不足可能是GIF太大或帧数太多")
return None # 内存不够啦
@@ -380,6 +377,7 @@ class ImageManager:
return None # 其他错误也返回None
async def process_image(self, image_base64: str) -> Tuple[str, str]:
# sourcery skip: hoist-if-from-if
"""处理图片并返回图片ID和描述
Args:
@@ -418,17 +416,9 @@ class ImageManager:
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
else:
# print(f"图片已存在: {existing_image.image_id}")
# print(f"图片描述: {existing_image.description}")
# print(f"图片计数: {existing_image.count}")
# 更新计数
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
else:
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())

View File

@@ -54,11 +54,11 @@ class DBWrapper:
return getattr(get_db(), name)
def __getitem__(self, key):
return get_db()[key]
return get_db()[key] # type: ignore
# 全局数据库访问点
memory_db: Database = DBWrapper()
memory_db: Database = DBWrapper() # type: ignore
# 定义数据库文件路径
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))

View File

@@ -406,9 +406,7 @@ def initialize_database():
existing_columns = {row[1] for row in cursor.fetchall()}
model_fields = set(model._meta.fields.keys())
# 检查并添加缺失字段(原有逻辑)
missing_fields = model_fields - existing_columns
if missing_fields:
if missing_fields := model_fields - existing_columns:
logger.warning(f"'{table_name}' 缺失字段: {missing_fields}")
for field_name, field_obj in model._meta.fields.items():
@@ -424,10 +422,7 @@ def initialize_database():
"DateTimeField": "DATETIME",
}.get(field_type, "TEXT")
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
if field_obj.null:
alter_sql += " NULL"
else:
alter_sql += " NOT NULL"
alter_sql += " NULL" if field_obj.null else " NOT NULL"
if hasattr(field_obj, "default") and field_obj.default is not None:
# 正确处理不同类型的默认值
default_value = field_obj.default

View File

@@ -1,16 +1,16 @@
import logging
# 使用基于时间戳的文件处理器,简单的轮转份数限制
from pathlib import Path
from typing import Callable, Optional
import logging
import json
import threading
import time
from datetime import datetime, timedelta
import structlog
import toml
from pathlib import Path
from typing import Callable, Optional
from datetime import datetime, timedelta
# 创建logs目录
LOG_DIR = Path("logs")
LOG_DIR.mkdir(exist_ok=True)
@@ -160,7 +160,7 @@ def close_handlers():
_console_handler = None
def remove_duplicate_handlers():
def remove_duplicate_handlers(): # sourcery skip: for-append-to-extend, list-comprehension
"""移除重复的handler特别是文件handler"""
root_logger = logging.getLogger()
@@ -184,7 +184,7 @@ def remove_duplicate_handlers():
# 读取日志配置
def load_log_config():
def load_log_config(): # sourcery skip: use-contextlib-suppress
"""从配置文件加载日志设置"""
config_path = Path("config/bot_config.toml")
default_config = {
@@ -365,7 +365,7 @@ MODULE_COLORS = {
"component_registry": "\033[38;5;214m", # 橙黄色
"stream_api": "\033[38;5;220m", # 黄色
"config_api": "\033[38;5;226m", # 亮黄色
"hearflow_api": "\033[38;5;154m", # 黄绿色
"heartflow_api": "\033[38;5;154m", # 黄绿色
"action_apis": "\033[38;5;118m", # 绿色
"independent_apis": "\033[38;5;82m", # 绿色
"llm_api": "\033[38;5;46m", # 亮绿色
@@ -412,6 +412,7 @@ class ModuleColoredConsoleRenderer:
"""自定义控制台渲染器,为不同模块提供不同颜色"""
def __init__(self, colors=True):
# sourcery skip: merge-duplicate-blocks, remove-redundant-if
self._colors = colors
self._config = LOG_CONFIG
@@ -443,6 +444,7 @@ class ModuleColoredConsoleRenderer:
self._enable_full_content_colors = False
def __call__(self, logger, method_name, event_dict):
# sourcery skip: merge-duplicate-blocks
"""渲染日志消息"""
# 获取基本信息
timestamp = event_dict.get("timestamp", "")
@@ -662,7 +664,7 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
"""获取logger实例支持按名称绑定"""
if name is None:
return raw_logger
logger = binds.get(name)
logger = binds.get(name) # type: ignore
if logger is None:
logger: structlog.stdlib.BoundLogger = structlog.get_logger(name).bind(logger_name=name)
binds[name] = logger
@@ -671,8 +673,8 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
def configure_logging(
level: str = "INFO",
console_level: str = None,
file_level: str = None,
console_level: Optional[str] = None,
file_level: Optional[str] = None,
max_bytes: int = 5 * 1024 * 1024,
backup_count: int = 30,
log_dir: str = "logs",
@@ -729,14 +731,11 @@ def reload_log_config():
global LOG_CONFIG
LOG_CONFIG = load_log_config()
# 重新设置handler的日志级别
file_handler = get_file_handler()
if file_handler:
if file_handler := get_file_handler():
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO))
console_handler = get_console_handler()
if console_handler:
if console_handler := get_console_handler():
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO))
@@ -780,8 +779,7 @@ def set_console_log_level(level: str):
global LOG_CONFIG
LOG_CONFIG["console_log_level"] = level.upper()
console_handler = get_console_handler()
if console_handler:
if console_handler := get_console_handler():
console_handler.setLevel(getattr(logging, level.upper(), logging.INFO))
# 重新设置root logger级别
@@ -800,8 +798,7 @@ def set_file_log_level(level: str):
global LOG_CONFIG
LOG_CONFIG["file_log_level"] = level.upper()
file_handler = get_file_handler()
if file_handler:
if file_handler := get_file_handler():
file_handler.setLevel(getattr(logging, level.upper(), logging.INFO))
# 重新设置root logger级别
@@ -933,13 +930,12 @@ def format_json_for_logging(data, indent=2, ensure_ascii=False):
Returns:
str: 格式化后的JSON字符串
"""
if isinstance(data, str):
# 如果是JSON字符串先解析再格式化
parsed_data = json.loads(data)
return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii)
else:
if not isinstance(data, str):
# 如果是对象,直接格式化
return json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)
# 如果是JSON字符串先解析再格式化
parsed_data = json.loads(data)
return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii)
def cleanup_old_logs():

View File

@@ -8,7 +8,7 @@ from src.config.config import global_config
global_api = None
def get_global_api() -> MessageServer:
def get_global_api() -> MessageServer: # sourcery skip: extract-method
"""获取全局MessageServer实例"""
global global_api
if global_api is None:
@@ -36,9 +36,8 @@ def get_global_api() -> MessageServer:
kwargs["custom_logger"] = maim_message_logger
# 添加token认证
if maim_message_config.auth_token:
if len(maim_message_config.auth_token) > 0:
kwargs["enable_token"] = True
if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
kwargs["enable_token"] = True
if maim_message_config.use_custom:
# 添加WSS模式支持

View File

@@ -1,9 +1,11 @@
from src.common.database.database_model import Messages # 更改导入
from src.common.logger import get_logger
import traceback
from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
from src.common.database.database_model import Messages
from src.common.logger import get_logger
logger = get_logger(__name__)

View File

@@ -23,7 +23,7 @@ class TelemetryHeartBeatTask(AsyncTask):
self.server_url = TELEMETRY_SERVER_URL
"""遥测服务地址"""
self.client_uuid = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None
self.client_uuid: str | None = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None # type: ignore
"""客户端UUID"""
self.info_dict = self._get_sys_info()
@@ -72,7 +72,7 @@ class TelemetryHeartBeatTask(AsyncTask):
timeout=aiohttp.ClientTimeout(total=5), # 设置超时时间为5秒
) as response:
logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client")
logger.debug(local_storage["deploy_time"])
logger.debug(local_storage["deploy_time"]) # type: ignore
logger.debug(f"Response status: {response.status}")
if response.status == 200:
@@ -93,7 +93,7 @@ class TelemetryHeartBeatTask(AsyncTask):
except Exception as e:
import traceback
error_msg = str(e) if str(e) else "未知错误"
error_msg = str(e) or "未知错误"
logger.warning(
f"请求UUID出错不过你还是可以正常使用麦麦: {type(e).__name__}: {error_msg}"
) # 可能是网络问题
@@ -114,11 +114,11 @@ class TelemetryHeartBeatTask(AsyncTask):
"""向服务器发送心跳"""
headers = {
"Client-UUID": self.client_uuid,
"User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}",
"User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}", # type: ignore
}
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
logger.debug(headers)
logger.debug(str(headers))
try:
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
@@ -151,7 +151,7 @@ class TelemetryHeartBeatTask(AsyncTask):
except Exception as e:
import traceback
error_msg = str(e) if str(e) else "未知错误"
error_msg = str(e) or "未知错误"
logger.warning(f"(此消息不会影响正常使用)状态未发生: {type(e).__name__}: {error_msg}")
logger.debug(f"完整错误信息: {traceback.format_exc()}")

View File

@@ -1,5 +1,6 @@
import shutil
import tomlkit
from tomlkit.items import Table
from pathlib import Path
from datetime import datetime
@@ -45,8 +46,8 @@ def update_config():
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version")
new_version = new_config["inner"].get("version")
old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
print(f"检测到版本号相同 (v{old_version}),跳过更新")
# 如果version相同恢复旧配置文件并返回
@@ -62,7 +63,7 @@ def update_config():
if key == "version":
continue
if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)):
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
update_dict(target[key], value)
else:
try:
@@ -85,10 +86,7 @@ def update_config():
if value and isinstance(value[0], dict) and "regex" in value[0]:
contains_regex = True
if contains_regex:
target[key] = value
else:
target[key] = tomlkit.array(value)
target[key] = value if contains_regex else tomlkit.array(str(value))
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)

View File

@@ -1,16 +1,14 @@
import os
from dataclasses import field, dataclass
import tomlkit
import shutil
from datetime import datetime
from datetime import datetime
from tomlkit import TOMLDocument
from tomlkit.items import Table
from src.common.logger import get_logger
from dataclasses import field, dataclass
from rich.traceback import install
from src.common.logger import get_logger
from src.config.config_base import ConfigBase
from src.config.official_configs import (
BotConfig,
@@ -80,8 +78,8 @@ def update_config():
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version")
new_version = new_config["inner"].get("version")
old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
return
@@ -103,7 +101,7 @@ def update_config():
shutil.copy2(template_path, new_config_path)
logger.info(f"已创建新配置文件: {new_config_path}")
def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
@@ -112,8 +110,9 @@ def update_config():
if key == "version":
continue
if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
update_dict(target[key], value)
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理

View File

@@ -43,7 +43,7 @@ class ConfigBase:
field_type = f.type
try:
init_args[field_name] = cls._convert_field(value, field_type)
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
except TypeError as e:
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
except Exception as e:

View File

@@ -1,7 +1,8 @@
from dataclasses import dataclass, field
from typing import Any, Literal
import re
from dataclasses import dataclass, field
from typing import Any, Literal, Optional
from src.config.config_base import ConfigBase
"""
@@ -113,7 +114,7 @@ class ChatConfig(ConfigBase):
exit_focus_threshold: float = 1.0
"""自动退出专注聊天的阈值,越低越容易退出专注聊天"""
def get_current_talk_frequency(self, chat_stream_id: str = None) -> float:
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
"""
根据当前时间和聊天流获取对应的 talk_frequency
@@ -138,7 +139,7 @@ class ChatConfig(ConfigBase):
# 如果都没有匹配,返回默认值
return self.talk_frequency
def _get_time_based_frequency(self, time_freq_list: list[str]) -> float:
def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
"""
根据时间配置列表获取当前时段的频率
@@ -186,7 +187,7 @@ class ChatConfig(ConfigBase):
return current_frequency
def _get_stream_specific_frequency(self, chat_stream_id: str) -> float:
def _get_stream_specific_frequency(self, chat_stream_id: str):
"""
获取特定聊天流在当前时间的频率
@@ -217,7 +218,7 @@ class ChatConfig(ConfigBase):
return None
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> str:
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
"""
解析流配置字符串并生成对应的 chat_id

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List
from typing import List, Optional
@dataclass
@@ -8,7 +8,7 @@ class Identity:
identity_detail: List[str] # 身份细节描述
def __init__(self, identity_detail: List[str] = None):
def __init__(self, identity_detail: Optional[List[str]] = None):
"""初始化身份特征
Args:

View File

@@ -1,17 +1,18 @@
from typing import Optional
import ast
from src.llm_models.utils_model import LLMRequest
from .personality import Personality
from .identity import Identity
import random
import json
import os
import hashlib
from typing import Optional
from rich.traceback import install
from src.common.logger import get_logger
from src.person_info.person_info import get_person_info_manager
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import get_person_info_manager
from .personality import Personality
from .identity import Identity
install(extra_lines=3)
@@ -23,7 +24,7 @@ class Individuality:
def __init__(self):
# 正常初始化实例属性
self.personality: Optional[Personality] = None
self.personality: Personality = None # type: ignore
self.identity: Optional[Identity] = None
self.name = ""
@@ -109,7 +110,7 @@ class Individuality:
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
if existing_short_impression:
try:
existing_data = ast.literal_eval(existing_short_impression)
existing_data = ast.literal_eval(existing_short_impression) # type: ignore
if isinstance(existing_data, list) and len(existing_data) >= 1:
personality_result = existing_data[0]
except (json.JSONDecodeError, TypeError, IndexError):
@@ -128,7 +129,7 @@ class Individuality:
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
if existing_short_impression:
try:
existing_data = ast.literal_eval(existing_short_impression)
existing_data = ast.literal_eval(existing_short_impression) # type: ignore
if isinstance(existing_data, list) and len(existing_data) >= 2:
identity_result = existing_data[1]
except (json.JSONDecodeError, TypeError, IndexError):
@@ -204,6 +205,7 @@ class Individuality:
return prompt_personality
def get_identity_prompt(self, level: int, x_person: int = 2) -> str:
# sourcery skip: assign-if-exp, merge-else-if-into-elif
"""
获取身份特征的prompt
@@ -240,13 +242,13 @@ class Individuality:
if identity_parts:
details_str = "".join(identity_parts)
if x_person in [1, 2]:
if x_person in {1, 2}:
return f"{i_pronoun}{details_str}"
else: # x_person == 0
# 无人称时,直接返回细节,不加代词和开头的逗号
return f"{details_str}"
else:
if x_person in [1, 2]:
if x_person in {1, 2}:
return f"{i_pronoun}的身份信息不完整。"
else: # x_person == 0
return "身份信息不完整。"
@@ -441,14 +443,15 @@ class Individuality:
if info_list_json:
try:
info_list = json.loads(info_list_json) if isinstance(info_list_json, str) else info_list_json
for item in info_list:
if isinstance(item, dict) and "info_type" in item:
keywords.append(item["info_type"])
keywords.extend(
item["info_type"] for item in info_list if isinstance(item, dict) and "info_type" in item
)
except (json.JSONDecodeError, TypeError):
logger.error(f"解析info_list失败: {info_list_json}")
return keywords
async def _create_personality(self, personality_core: str, personality_sides: list) -> str:
# sourcery skip: merge-list-append, move-assign
"""使用LLM创建压缩版本的impression
Args:

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Dict, List
import json
from dataclasses import dataclass
from typing import Dict, List, Optional
from pathlib import Path
@@ -24,7 +25,7 @@ class Personality:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, personality_core: str = "", personality_sides: List[str] = None):
def __init__(self, personality_core: str = "", personality_sides: Optional[List[str]] = None):
if personality_sides is None:
personality_sides = []
self.personality_core = personality_core
@@ -41,7 +42,7 @@ class Personality:
cls._instance = cls()
return cls._instance
def _init_big_five_personality(self):
def _init_big_five_personality(self): # sourcery skip: extract-method
"""初始化大五人格特质"""
# 构建文件路径
personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per"
@@ -63,7 +64,6 @@ class Personality:
else:
self.extraversion = 0.3
self.neuroticism = 0.5
if "认真" in self.personality_core or "负责" in self.personality_sides:
self.conscientiousness = 0.9
else:

View File

@@ -120,12 +120,7 @@ class AsyncTaskManager:
"""
获取所有任务的状态
"""
tasks_status = {}
for task_name, task in self.tasks.items():
tasks_status[task_name] = {
"status": "running" if not task.done() else "done",
}
return tasks_status
return {task_name: {"status": "done" if task.done() else "running"} for task_name, task in self.tasks.items()}
async def stop_and_wait_all_tasks(self):
"""

View File

@@ -2,12 +2,12 @@ import math
import random
import time
from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from ..common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.llm_models.utils_model import LLMRequest
from src.manager.async_task_manager import AsyncTask, async_task_manager
logger = get_logger("mood")
@@ -55,12 +55,12 @@ class ChatMood:
request_type="mood",
)
self.last_change_time = 0
self.last_change_time: float = 0
async def update_mood_by_message(self, message: MessageRecv, interested_rate: float):
self.regression_count = 0
during_last_time = message.message_info.time - self.last_change_time
during_last_time = message.message_info.time - self.last_change_time # type: ignore
base_probability = 0.05
time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time))
@@ -78,7 +78,7 @@ class ChatMood:
if random.random() > update_probability:
return
message_time = message.message_info.time
message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
@@ -119,7 +119,7 @@ class ChatMood:
self.mood_state = response
self.last_change_time = message_time
self.last_change_time = message_time # type: ignore
async def regress_mood(self):
message_time = time.time()

View File

@@ -1,17 +1,18 @@
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import PersonInfo # 新增导入
import copy
import hashlib
from typing import Any, Callable, Dict, Union
import datetime
import asyncio
import json
from json_repair import repair_json
from typing import Any, Callable, Dict, Union, Optional
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import json # 新增导入
from json_repair import repair_json
"""
PersonInfoManager 类方法功能摘要:
@@ -42,7 +43,7 @@ person_info_default = {
"last_know": None,
# "user_cardname": None, # This field is not in Peewee model PersonInfo
# "user_avatar": None, # This field is not in Peewee model PersonInfo
"impression": None, # Corrected from persion_impression
"impression": None, # Corrected from person_impression
"short_impression": None,
"info_list": None,
"points": None,
@@ -106,27 +107,24 @@ class PersonInfoManager:
logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
return False
def get_person_id_by_person_name(self, person_name: str):
def get_person_id_by_person_name(self, person_name: str) -> str:
"""根据用户名获取用户ID"""
try:
record = PersonInfo.get_or_none(PersonInfo.person_name == person_name)
if record:
return record.person_id
else:
return ""
return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
return ""
@staticmethod
async def create_person_info(person_id: str, data: dict = None):
async def create_person_info(person_id: str, data: Optional[dict] = None):
"""创建一个项"""
if not person_id:
logger.debug("创建失败personid不存在")
logger.debug("创建失败person_id不存在")
return
_person_info_default = copy.deepcopy(person_info_default)
model_fields = PersonInfo._meta.fields.keys()
model_fields = PersonInfo._meta.fields.keys() # type: ignore
final_data = {"person_id": person_id}
@@ -163,9 +161,9 @@ class PersonInfoManager:
await asyncio.to_thread(_db_create_sync, final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
"""更新某一个字段,会补全"""
if field_name not in PersonInfo._meta.fields:
if field_name not in PersonInfo._meta.fields: # type: ignore
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
return
@@ -228,15 +226,13 @@ class PersonInfoManager:
@staticmethod
async def has_one_field(person_id: str, field_name: str):
"""判断是否存在某一个字段"""
if field_name not in PersonInfo._meta.fields:
if field_name not in PersonInfo._meta.fields: # type: ignore
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
return False
def _db_has_field_sync(p_id: str, f_name: str):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
return True
return False
return bool(record)
try:
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
@@ -435,9 +431,7 @@ class PersonInfoManager:
except Exception as e:
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
# Fallback to default in case of any error during DB access
if field_name in person_info_default:
return default_value_for_field
return None
return default_value_for_field if field_name in person_info_default else None
@staticmethod
def get_value_sync(person_id: str, field_name: str):
@@ -446,8 +440,7 @@ class PersonInfoManager:
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = []
record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if record:
if record := PersonInfo.get_or_none(PersonInfo.person_id == person_id):
val = getattr(record, field_name, None)
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
@@ -481,7 +474,7 @@ class PersonInfoManager:
record = await asyncio.to_thread(_db_get_record_sync, person_id)
for field_name in field_names:
if field_name not in PersonInfo._meta.fields:
if field_name not in PersonInfo._meta.fields: # type: ignore
if field_name in person_info_default:
result[field_name] = copy.deepcopy(person_info_default[field_name])
logger.debug(f"字段'{field_name}'不在Peewee模型中使用默认配置值。")
@@ -509,7 +502,7 @@ class PersonInfoManager:
"""
获取满足条件的字段值字典
"""
if field_name not in PersonInfo._meta.fields:
if field_name not in PersonInfo._meta.fields: # type: ignore
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义")
return {}
@@ -531,7 +524,7 @@ class PersonInfoManager:
return {}
async def get_or_create_person(
self, platform: str, user_id: int, nickname: str = None, user_cardname: str = None, user_avatar: str = None
self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None
) -> str:
"""
根据 platform 和 user_id 获取 person_id。
@@ -561,7 +554,7 @@ class PersonInfoManager:
"points": [],
"forgotten_points": [],
}
model_fields = PersonInfo._meta.fields.keys()
model_fields = PersonInfo._meta.fields.keys() # type: ignore
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
await self.create_person_info(person_id, data=filtered_initial_data)
@@ -610,7 +603,9 @@ class PersonInfoManager:
"name_reason",
]
valid_fields_to_get = [
f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
f
for f in required_fields
if f in PersonInfo._meta.fields or f in person_info_default # type: ignore
]
person_data = await self.get_values(found_person_id, valid_fields_to_get)

View File

@@ -3,12 +3,12 @@ import traceback
import os
import pickle
import random
from typing import List, Dict
from typing import List, Dict, Any
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.relationship_manager import get_relationship_manager
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
@@ -45,7 +45,7 @@ class RelationshipBuilder:
self.chat_id = chat_id
# 新的消息段缓存结构:
# {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]}
self.person_engaged_cache: Dict[str, List[Dict[str, any]]] = {}
self.person_engaged_cache: Dict[str, List[Dict[str, Any]]] = {}
# 持久化存储文件路径
self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl")
@@ -210,11 +210,7 @@ class RelationshipBuilder:
if person_id not in self.person_engaged_cache:
return 0
total_count = 0
for segment in self.person_engaged_cache[person_id]:
total_count += segment["message_count"]
return total_count
return sum(segment["message_count"] for segment in self.person_engaged_cache[person_id])
def _cleanup_old_segments(self) -> bool:
"""清理老旧的消息段"""
@@ -289,7 +285,7 @@ class RelationshipBuilder:
self.last_cleanup_time = current_time
# 保存缓存
if cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0:
if cleanup_stats["segments_removed"] > 0 or users_to_remove:
self._save_cache()
logger.info(
f"{self.log_prefix} 清理完成 - 影响用户: {cleanup_stats['users_cleaned']}, 移除消息段: {cleanup_stats['segments_removed']}, 移除用户: {len(users_to_remove)}"
@@ -313,6 +309,7 @@ class RelationshipBuilder:
return False
def get_cache_status(self) -> str:
# sourcery skip: merge-list-append, merge-list-appends-into-extend
"""获取缓存状态信息,用于调试和监控"""
if not self.person_engaged_cache:
return f"{self.log_prefix} 关系缓存为空"
@@ -357,13 +354,12 @@ class RelationshipBuilder:
self._cleanup_old_segments()
current_time = time.time()
latest_messages = get_raw_msg_by_timestamp_with_chat(
if latest_messages := get_raw_msg_by_timestamp_with_chat(
self.chat_id,
self.last_processed_message_time,
current_time,
limit=50, # 获取自上次处理后的消息
)
if latest_messages:
):
# 处理所有新的非bot消息
for latest_msg in latest_messages:
user_id = latest_msg.get("user_id")
@@ -414,7 +410,7 @@ class RelationshipBuilder:
# 负责触发关系构建、整合消息段、更新用户印象
# ================================
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, any]]):
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]):
"""基于消息段更新用户印象"""
original_segment_count = len(segments)
logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象")

View File

@@ -1,4 +1,5 @@
from typing import Dict, Optional, List
from typing import Dict, Optional, List, Any
from src.common.logger import get_logger
from .relationship_builder import RelationshipBuilder
@@ -63,7 +64,7 @@ class RelationshipBuilderManager:
"""
return list(self.builders.keys())
def get_status(self) -> Dict[str, any]:
def get_status(self) -> Dict[str, Any]:
"""获取管理器状态
Returns:
@@ -94,9 +95,7 @@ class RelationshipBuilderManager:
bool: 是否成功清理
"""
builder = self.get_builder(chat_id)
if builder:
return builder.force_cleanup_user_segments(person_id)
return False
return builder.force_cleanup_user_segments(person_id) if builder else False
# 全局管理器实例

View File

@@ -1,16 +1,19 @@
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
import time
import traceback
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.person_info.person_info import get_person_info_manager
from typing import List, Dict
from json_repair import repair_json
from src.chat.message_receive.chat_stream import get_chat_manager
import json
import random
from typing import List, Dict, Any
from json_repair import repair_json
from src.common.logger import get_logger
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
logger = get_logger("relationship_fetcher")
@@ -62,11 +65,11 @@ class RelationshipFetcher:
self.chat_id = chat_id
# 信息获取缓存:记录正在获取的信息请求
self.info_fetching_cache: List[Dict[str, any]] = []
self.info_fetching_cache: List[Dict[str, Any]] = []
# 信息结果缓存存储已获取的信息结果带TTL
self.info_fetched_cache: Dict[str, Dict[str, any]] = {}
# 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknow": bool}}}
self.info_fetched_cache: Dict[str, Dict[str, Any]] = {}
# 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}}
# LLM模型配置
self.llm_model = LLMRequest(
@@ -184,7 +187,7 @@ class RelationshipFetcher:
nickname_str = ",".join(global_config.bot.alias_names)
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
person_info_manager = get_person_info_manager()
person_name = await person_info_manager.get_value(person_id, "person_name")
person_name: str = await person_info_manager.get_value(person_id, "person_name") # type: ignore
info_cache_block = self._build_info_cache_block()
@@ -208,8 +211,7 @@ class RelationshipFetcher:
logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息{content_json.get('none', '')}")
return None
info_type = content_json.get("info_type")
if info_type:
if info_type := content_json.get("info_type"):
# 记录信息获取请求
self.info_fetching_cache.append(
{
@@ -287,7 +289,7 @@ class RelationshipFetcher:
"ttl": 2,
"start_time": start_time,
"person_name": person_name,
"unknow": cached_info == "none",
"unknown": cached_info == "none",
}
logger.info(f"{self.log_prefix} 记得 {person_name}{info_type}: {cached_info}")
return
@@ -321,7 +323,7 @@ class RelationshipFetcher:
"ttl": 2,
"start_time": start_time,
"person_name": person_name,
"unknow": True,
"unknown": True,
}
logger.info(f"{self.log_prefix} 完全不认识 {person_name}")
await self._save_info_to_cache(person_id, info_type, "none")
@@ -353,15 +355,15 @@ class RelationshipFetcher:
if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {}
self.info_fetched_cache[person_id][info_type] = {
"info": "unknow" if is_unknown else info_content,
"info": "unknown" if is_unknown else info_content,
"ttl": 3,
"start_time": start_time,
"person_name": person_name,
"unknow": is_unknown,
"unknown": is_unknown,
}
# 保存到持久化缓存 (info_list)
await self._save_info_to_cache(person_id, info_type, info_content if not is_unknown else "none")
await self._save_info_to_cache(person_id, info_type, "none" if is_unknown else info_content)
if not is_unknown:
logger.info(f"{self.log_prefix} 思考得到,{person_name}{info_type}: {info_content}")
@@ -393,7 +395,7 @@ class RelationshipFetcher:
for info_type in self.info_fetched_cache[person_id]:
person_name = self.info_fetched_cache[person_id][info_type]["person_name"]
if not self.info_fetched_cache[person_id][info_type]["unknow"]:
if not self.info_fetched_cache[person_id][info_type]["unknown"]:
info_content = self.info_fetched_cache[person_id][info_type]["info"]
person_known_infos.append(f"[{info_type}]{info_content}")
else:
@@ -430,6 +432,7 @@ class RelationshipFetcher:
return persons_infos_str
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
# sourcery skip: use-next
"""将提取到的信息保存到 person_info 的 info_list 字段中
Args:

View File

@@ -1,5 +1,5 @@
from src.common.logger import get_logger
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from .person_info import PersonInfoManager, get_person_info_manager
import time
import random
from src.llm_models.utils_model import LLMRequest
@@ -12,7 +12,7 @@ from difflib import SequenceMatcher
import jieba
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Dict, Any
logger = get_logger("relation")
@@ -28,8 +28,7 @@ class RelationshipManager:
async def is_known_some_one(platform, user_id):
"""判断是否认识某人"""
person_info_manager = get_person_info_manager()
is_known = await person_info_manager.is_person_known(platform, user_id)
return is_known
return await person_info_manager.is_person_known(platform, user_id)
@staticmethod
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
@@ -110,7 +109,7 @@ class RelationshipManager:
return relation_prompt
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages=None):
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
"""更新用户印象
Args:
@@ -123,7 +122,7 @@ class RelationshipManager:
person_info_manager = get_person_info_manager()
person_name = await person_info_manager.get_value(person_id, "person_name")
nickname = await person_info_manager.get_value(person_id, "nickname")
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore
alias_str = ", ".join(global_config.bot.alias_names)
# personality_block =get_individuality().get_personality_prompt(x_person=2, level=2)
@@ -142,13 +141,13 @@ class RelationshipManager:
# 遍历消息,构建映射
for msg in user_messages:
await person_info_manager.get_or_create_person(
platform=msg.get("chat_info_platform"),
user_id=msg.get("user_id"),
nickname=msg.get("user_nickname"),
user_cardname=msg.get("user_cardname"),
platform=msg.get("chat_info_platform"), # type: ignore
user_id=msg.get("user_id"), # type: ignore
nickname=msg.get("user_nickname"), # type: ignore
user_cardname=msg.get("user_cardname"), # type: ignore
)
replace_user_id = msg.get("user_id")
replace_platform = msg.get("chat_info_platform")
replace_user_id: str = msg.get("user_id") # type: ignore
replace_platform: str = msg.get("chat_info_platform") # type: ignore
replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id)
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
@@ -354,8 +353,8 @@ class RelationshipManager:
person_name = await person_info_manager.get_value(person_id, "person_name")
nickname = await person_info_manager.get_value(person_id, "nickname")
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
attitude = await person_info_manager.get_value(person_id, "attitude") or 50
know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore
attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore
# 根据熟悉度,调整印象和简短印象的最大长度
if know_times > 300:
@@ -414,16 +413,14 @@ class RelationshipManager:
if len(remaining_points) < 10:
# 如果还没达到30条直接保留
remaining_points.append(point)
elif random.random() < keep_probability:
# 保留这个点,随机移除一个已保留的点
idx_to_remove = random.randrange(len(remaining_points))
points_to_move.append(remaining_points[idx_to_remove])
remaining_points[idx_to_remove] = point
else:
# 随机决定是否保留
if random.random() < keep_probability:
# 保留这个点,随机移除一个已保留的点
idx_to_remove = random.randrange(len(remaining_points))
points_to_move.append(remaining_points[idx_to_remove])
remaining_points[idx_to_remove] = point
else:
# 不保留这个点
points_to_move.append(point)
# 不保留这个点
points_to_move.append(point)
# 更新points和forgotten_points
current_points = remaining_points
@@ -520,7 +517,7 @@ class RelationshipManager:
new_attitude = int(relation_value_json.get("attitude", 50))
# 获取当前的关系值
old_attitude = await person_info_manager.get_value(person_id, "attitude") or 50
old_attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore
# 更新熟悉度
if new_attitude > 25:

View File

@@ -65,9 +65,9 @@ def get_replyer(
async def generate_reply(
chat_stream=None,
chat_id: str = None,
action_data: Dict[str, Any] = None,
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None,
reply_to: str = "",
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
@@ -78,25 +78,25 @@ async def generate_reply(
model_configs: Optional[List[Dict[str, Any]]] = None,
request_type: str = "",
enable_timeout: bool = False,
) -> Tuple[bool, List[Tuple[str, Any]]]:
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""生成回复
Args:
chat_stream: 聊天流对象(优先)
action_data: 动作数据
chat_id: 聊天ID备用
action_data: 动作数据
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
return_prompt: 是否返回提示词
Returns:
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, []
return False, [], None
logger.debug("[GeneratorAPI] 开始生成回复")
@@ -109,8 +109,9 @@ async def generate_reply(
enable_timeout=enable_timeout,
enable_tool=enable_tool,
)
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
reply_set = []
if content:
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
if success:
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
@@ -118,19 +119,19 @@ async def generate_reply(
logger.warning("[GeneratorAPI] 回复生成失败")
if return_prompt:
return success, reply_set or [], prompt
return success, reply_set, prompt
else:
return success, reply_set or []
return success, reply_set, None
except Exception as e:
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
return False, []
return False, [], None
async def rewrite_reply(
chat_stream=None,
reply_data: Dict[str, Any] = None,
chat_id: str = None,
chat_stream: Optional[ChatStream] = None,
reply_data: Optional[Dict[str, Any]] = None,
chat_id: Optional[str] = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
model_configs: Optional[List[Dict[str, Any]]] = None,
@@ -158,15 +159,16 @@ async def rewrite_reply(
# 调用回复器重写回复
success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {})
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
reply_set = []
if content:
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
if success:
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项")
else:
logger.warning("[GeneratorAPI] 重写回复失败")
return success, reply_set or []
return success, reply_set
except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")

View File

@@ -34,7 +34,7 @@ class ToolExecutor:
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
"""
def __init__(self, chat_id: str = None, enable_cache: bool = True, cache_ttl: int = 3):
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
"""初始化工具执行器
Args:
@@ -62,8 +62,8 @@ class ToolExecutor:
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'}TTL={cache_ttl}")
async def execute_from_chat_message(
self, target_message: str, chat_history: list[str], sender: str, return_details: bool = False
) -> List[Dict] | Tuple[List[Dict], List[str], str]:
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
) -> Tuple[List[Dict], List[str], str]:
"""从聊天消息执行工具
Args:
@@ -79,16 +79,14 @@ class ToolExecutor:
# 首先检查缓存
cache_key = self._generate_cache_key(target_message, chat_history, sender)
cached_result = self._get_from_cache(cache_key)
if cached_result:
if cached_result := self._get_from_cache(cache_key):
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
if return_details:
# 从缓存结果中提取工具名称
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
return cached_result, used_tools, "使用缓存结果"
else:
return cached_result
if not return_details:
return cached_result, [], "使用缓存结果"
# 从缓存结果中提取工具名称
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
return cached_result, used_tools, "使用缓存结果"
# 缓存未命中,执行工具调用
# 获取可用工具
@@ -134,7 +132,7 @@ class ToolExecutor:
if return_details:
return tool_results, used_tools, prompt
else:
return tool_results
return tool_results, [], ""
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
"""执行工具调用
@@ -207,7 +205,7 @@ class ToolExecutor:
return tool_results, used_tools
def _generate_cache_key(self, target_message: str, chat_history: list[str], sender: str) -> str:
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键
Args:
@@ -267,10 +265,7 @@ class ToolExecutor:
return
expired_keys = []
for cache_key, cache_item in self.tool_cache.items():
if cache_item["ttl"] <= 0:
expired_keys.append(cache_key)
expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
for key in expired_keys:
del self.tool_cache[key]
@@ -355,7 +350,7 @@ class ToolExecutor:
"ttl_distribution": ttl_distribution,
}
def set_cache_config(self, enable_cache: bool = None, cache_ttl: int = None):
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
"""动态修改缓存配置
Args:
@@ -366,7 +361,7 @@ class ToolExecutor:
self.enable_cache = enable_cache
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
if cache_ttl is not None and cache_ttl > 0:
if cache_ttl > 0:
self.cache_ttl = cache_ttl
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
@@ -380,7 +375,7 @@ init_tool_executor_prompt()
# 1. 基础使用 - 从聊天消息执行工具启用缓存默认TTL=3
executor = ToolExecutor(executor_id="my_executor")
results = await executor.execute_from_chat_message(
results, _, _ = await executor.execute_from_chat_message(
talking_message_str="今天天气怎么样?现在几点了?",
is_group_chat=False
)