fix:调整目录结构,优化hfc prompt,移除日程,移除动态和llm判断willing模式,

This commit is contained in:
SengokuCola
2025-05-13 18:37:55 +08:00
parent 6376da0682
commit fed71bccad
131 changed files with 422 additions and 1500 deletions

View File

@@ -0,0 +1,443 @@
from src.config.config import global_config
from typing import List, Dict, Any, Tuple # 确保类型提示被导入
import time # 导入 time 模块以获取当前时间
import random
import re
from src.common.message_repository import find_messages, count_messages
from src.chat.person_info.person_info import person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable
def get_raw_msg_by_timestamp(
    timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
    """
    Fetch messages whose time lies strictly between timestamp_start and timestamp_end.

    limit: maximum number of messages to return; 0 means no limit.
    limit_mode: only relevant when limit > 0. 'earliest' asks for the earliest records,
        'latest' for the latest ones. Defaults to 'latest'.
    """
    # An explicit ascending time sort is only requested for unlimited queries; with a limit,
    # no sort is passed and limit_mode is forwarded to find_messages instead.
    ordering = None
    if limit == 0:
        ordering = [("time", 1)]
    time_window = {"$gt": timestamp_start, "$lt": timestamp_end}
    return find_messages(
        message_filter={"time": time_window},
        sort=ordering,
        limit=limit,
        limit_mode=limit_mode,
    )
def get_raw_msg_by_timestamp_with_chat(
    chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
    """Fetch messages of one chat whose time lies strictly between the two timestamps.

    limit: maximum number of messages to return; 0 means no limit.
    limit_mode: only relevant when limit > 0. 'earliest' asks for the earliest records,
        'latest' for the latest ones. Defaults to 'latest'.
    """
    # An explicit ascending time sort is only requested when no limit is given.
    ordering = None
    if limit == 0:
        ordering = [("time", 1)]
    criteria = {
        "chat_id": chat_id,
        "time": {"$gt": timestamp_start, "$lt": timestamp_end},
    }
    # limit_mode is handed to find_messages unchanged.
    return find_messages(message_filter=criteria, sort=ordering, limit=limit, limit_mode=limit_mode)
def get_raw_msg_by_timestamp_with_chat_users(
    chat_id: str,
    timestamp_start: float,
    timestamp_end: float,
    person_ids: list,
    limit: int = 0,
    limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
    """Fetch messages sent by the given users in one chat, strictly between the two timestamps.

    limit: maximum number of messages to return; 0 means no limit.
    limit_mode: only relevant when limit > 0. 'earliest' asks for the earliest records,
        'latest' for the latest ones. Defaults to 'latest'.
    """
    # An explicit ascending time sort is only requested when no limit is given.
    ordering = None
    if limit == 0:
        ordering = [("time", 1)]
    time_window = {"$gt": timestamp_start, "$lt": timestamp_end}
    sender_clause = {"$in": person_ids}
    criteria = {"chat_id": chat_id, "time": time_window, "user_id": sender_clause}
    return find_messages(message_filter=criteria, sort=ordering, limit=limit, limit_mode=limit_mode)
def get_raw_msg_by_timestamp_random(
    timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
    """
    Pick a random message inside the time range, take its chat_id, and return that
    chat's messages inside the same time range.

    Returns an empty list when the time range contains no message at all.
    """
    # Every message in the range is fetched; only its chat_id is used for the draw.
    candidates = get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
    if not candidates:
        return []
    picked_chat_id = random.choice(candidates)["chat_id"]
    # limit / limit_mode apply to the second, chat-scoped query only.
    return get_raw_msg_by_timestamp_with_chat(picked_chat_id, timestamp_start, timestamp_end, limit, limit_mode)
def get_raw_msg_by_timestamp_with_users(
    timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
    """Fetch messages sent by the given users across *all chats*, strictly between the two timestamps.

    limit: maximum number of messages to return; 0 means no limit.
    limit_mode: only relevant when limit > 0. 'earliest' asks for the earliest records,
        'latest' for the latest ones. Defaults to 'latest'.
    """
    # An explicit ascending time sort is only requested when no limit is given.
    ordering = None
    if limit == 0:
        ordering = [("time", 1)]
    criteria = {
        "time": {"$gt": timestamp_start, "$lt": timestamp_end},
        "user_id": {"$in": person_ids},
    }
    return find_messages(message_filter=criteria, sort=ordering, limit=limit, limit_mode=limit_mode)
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
    """Fetch messages older than the given timestamp, requesting an ascending time sort.

    limit: maximum number of messages to return; 0 means no limit.
    """
    return find_messages(
        message_filter={"time": {"$lt": timestamp}},
        sort=[("time", 1)],
        limit=limit,
    )
def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
    """Fetch messages of one chat older than the given timestamp, requesting an ascending time sort.

    limit: maximum number of messages to return; 0 means no limit.
    """
    criteria = {"chat_id": chat_id, "time": {"$lt": timestamp}}
    return find_messages(message_filter=criteria, sort=[("time", 1)], limit=limit)
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
    """Fetch messages of the given users older than the timestamp, requesting an ascending time sort.

    limit: maximum number of messages to return; 0 means no limit.
    """
    criteria = {
        "time": {"$lt": timestamp},
        "user_id": {"$in": person_ids},
    }
    return find_messages(message_filter=criteria, sort=[("time", 1)], limit=limit)
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float = None) -> int:
    """
    Count the messages of a chat between timestamp_start and timestamp_end (both exclusive).

    When timestamp_end is None, the current time is used as the upper bound.
    Returns 0 without querying when the start is not earlier than the end.
    """
    upper_bound = time.time() if timestamp_end is None else timestamp_end
    if timestamp_start >= upper_bound:
        # Empty or inverted window: there cannot be any new message.
        return 0
    time_window = {"$gt": timestamp_start, "$lt": upper_bound}
    return count_messages(message_filter={"chat_id": chat_id, "time": time_window})
def num_new_messages_since_with_users(
    chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
) -> int:
    """Count the messages the given users sent in a chat strictly between the two timestamps."""
    # No users to look for: answer 0 without querying.
    if not person_ids:
        return 0
    time_window = {"$gt": timestamp_start, "$lt": timestamp_end}
    criteria = {"chat_id": chat_id, "time": time_window, "user_id": {"$in": person_ids}}
    return count_messages(message_filter=criteria)
async def _build_readable_messages_internal(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
"""
内部辅助函数,构建可读消息字符串和原始消息详情列表。
Args:
messages: 消息字典列表。
replace_bot_name: 是否将机器人的 user_id 替换为 ""
merge_messages: 是否合并来自同一用户的连续消息。
timestamp_mode: 时间戳的显示模式 ('relative', 'absolute', etc.)。传递给 translate_timestamp_to_human_readable。
truncate: 是否根据消息的新旧程度截断过长的消息内容。
Returns:
包含格式化消息的字符串和原始消息详情列表 (时间戳, 发送者名称, 内容) 的元组。
"""
if not messages:
return "", []
message_details_raw: List[Tuple[float, str, str]] = []
# 1 & 2: 获取发送者信息并提取消息组件
for msg in messages:
user_info = msg.get("user_info", {})
platform = user_info.get("platform")
user_id = user_info.get("user_id")
user_nickname = user_info.get("user_nickname")
user_cardname = user_info.get("user_cardname")
timestamp = msg.get("time")
content = msg.get("processed_plain_text", "") # 默认空字符串
# 检查必要信息是否存在
if not all([platform, user_id, timestamp is not None]):
continue
person_id = person_info_manager.get_person_id(platform, user_id)
# 根据 replace_bot_name 参数决定是否替换机器人名称
if replace_bot_name and user_id == global_config.BOT_QQ:
person_name = f"{global_config.BOT_NICKNAME}(你)"
else:
person_name = await person_info_manager.get_value(person_id, "person_name")
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
if not person_name:
if user_cardname:
person_name = f"昵称:{user_cardname}"
elif user_nickname:
person_name = f"{user_nickname}"
else:
person_name = "某人"
# 检查是否有 回复<aaa:bbb> 字段
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, content)
if match:
aaa = match.group(1)
bbb = match.group(2)
reply_person_id = person_info_manager.get_person_id(platform, bbb)
reply_person_name = await person_info_manager.get_value(reply_person_id, "person_name")
if not reply_person_name:
reply_person_name = aaa
# 在内容前加上回复信息
content = re.sub(reply_pattern, f"回复 {reply_person_name}", content, count=1)
# 检查是否有 @<aaa:bbb> 字段 @<{member_info.get('nickname')}:{member_info.get('user_id')}>
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
at_matches = list(re.finditer(at_pattern, content))
if at_matches:
new_content = ""
last_end = 0
for m in at_matches:
new_content += content[last_end : m.start()]
aaa = m.group(1)
bbb = m.group(2)
at_person_id = person_info_manager.get_person_id(platform, bbb)
at_person_name = await person_info_manager.get_value(at_person_id, "person_name")
if not at_person_name:
at_person_name = aaa
new_content += f"@{at_person_name}"
last_end = m.end()
new_content += content[last_end:]
content = new_content
target_str = "这是QQ的一个功能用于提及某人但没那么明显"
if target_str in content:
if random.random() < 0.6:
content = content.replace(target_str, "")
if content != "":
message_details_raw.append((timestamp, person_name, content))
if not message_details_raw:
return "", []
message_details_raw.sort(key=lambda x: x[0]) # 按时间戳(第一个元素)升序排序,越早的消息排在前面
# 应用截断逻辑 (如果 truncate 为 True)
message_details: List[Tuple[float, str, str]] = []
n_messages = len(message_details_raw)
if truncate and n_messages > 0:
for i, (timestamp, name, content) in enumerate(message_details_raw):
percentile = i / n_messages # 计算消息在列表中的位置百分比 (0 <= percentile < 1)
original_len = len(content)
limit = -1 # 默认不截断
if percentile < 0.2: # 60% 之前的消息 (即最旧的 60%)
limit = 50
replace_content = "......(记不清了)"
elif percentile < 0.5: # 60% 之前的消息 (即最旧的 60%)
limit = 100
replace_content = "......(有点记不清了)"
elif percentile < 0.7: # 60% 到 80% 之前的消息 (即中间的 20%)
limit = 200
replace_content = "......(内容太长了)"
elif percentile < 1.0: # 80% 到 100% 之前的消息 (即较新的 20%)
limit = 300
replace_content = "......(太长了)"
truncated_content = content
if 0 < limit < original_len:
truncated_content = f"{content[:limit]}{replace_content}"
message_details.append((timestamp, name, truncated_content))
else:
# 如果不截断,直接使用原始列表
message_details = message_details_raw
# 3: 合并连续消息 (如果 merge_messages 为 True)
merged_messages = []
if merge_messages and message_details:
# 初始化第一个合并块
current_merge = {
"name": message_details[0][1],
"start_time": message_details[0][0],
"end_time": message_details[0][0],
"content": [message_details[0][2]],
}
for i in range(1, len(message_details)):
timestamp, name, content = message_details[i]
# 如果是同一个人发送的连续消息且时间间隔小于等于60秒
if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60):
current_merge["content"].append(content)
current_merge["end_time"] = timestamp # 更新最后消息时间
else:
# 保存上一个合并块
merged_messages.append(current_merge)
# 开始新的合并块
current_merge = {"name": name, "start_time": timestamp, "end_time": timestamp, "content": [content]}
# 添加最后一个合并块
merged_messages.append(current_merge)
elif message_details: # 如果不合并消息,则每个消息都是一个独立的块
for timestamp, name, content in message_details:
merged_messages.append(
{
"name": name,
"start_time": timestamp, # 起始和结束时间相同
"end_time": timestamp,
"content": [content], # 内容只有一个元素
}
)
# 4 & 5: 格式化为字符串
output_lines = []
for _i, merged in enumerate(merged_messages):
# 使用指定的 timestamp_mode 格式化时间
readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
header = f"{readable_time}{merged['name']} 说:"
output_lines.append(header)
# 将内容合并,并添加缩进
for line in merged["content"]:
stripped_line = line.strip()
if stripped_line: # 过滤空行
# 移除末尾句号,添加分号 - 这个逻辑似乎有点奇怪,暂时保留
if stripped_line.endswith(""):
stripped_line = stripped_line[:-1]
# 如果内容被截断,结尾已经是 ...(内容太长),不再添加分号
if not stripped_line.endswith("(内容太长)"):
output_lines.append(f"{stripped_line};")
else:
output_lines.append(stripped_line) # 直接添加截断后的内容
output_lines.append("\n") # 在每个消息块后添加换行,保持可读性
# 移除可能的多余换行,然后合并
formatted_string = "".join(output_lines).strip()
# 返回格式化后的字符串和 *应用截断后* 的 message_details 列表
# 注意:如果外部调用者需要原始未截断的内容,可能需要调整返回策略
return formatted_string, message_details
async def build_readable_messages_with_list(
    messages: List[Dict[str, Any]],
    replace_bot_name: bool = True,
    merge_messages: bool = False,
    timestamp_mode: str = "relative",
    truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
    """
    Render a message list as readable text and also return the (timestamp, name, content) list.

    All formatting switches are forwarded unchanged to _build_readable_messages_internal.
    """
    rendered_text, message_details = await _build_readable_messages_internal(
        messages,
        replace_bot_name=replace_bot_name,
        merge_messages=merge_messages,
        timestamp_mode=timestamp_mode,
        truncate=truncate,
    )
    return rendered_text, message_details
async def build_readable_messages(
    messages: List[Dict[str, Any]],
    replace_bot_name: bool = True,
    merge_messages: bool = False,
    timestamp_mode: str = "relative",
    read_mark: float = 0.0,
    truncate: bool = False,
) -> str:
    """
    Render a message list as readable text.

    With a positive read_mark, messages are split at that timestamp and a "read" marker
    line is inserted between the two parts. The remaining switches control formatting.
    """
    # No usable read mark: render everything in one go.
    if read_mark <= 0:
        whole_text, _ = await _build_readable_messages_internal(
            messages, replace_bot_name, merge_messages, timestamp_mode, truncate
        )
        return whole_text

    # Split at the read mark; a message without "time" counts as time 0.
    read_part = [m for m in messages if m.get("time", 0) <= read_mark]
    unread_part = [m for m in messages if m.get("time", 0) > read_mark]

    # The read part honours `truncate`.
    read_text, _ = await _build_readable_messages_internal(
        read_part, replace_bot_name, merge_messages, timestamp_mode, truncate
    )
    # NOTE(review): the unread part is rendered without passing `truncate`, so it is never
    # shortened here — confirm this is intended.
    unread_text, _ = await _build_readable_messages_internal(
        unread_part,
        replace_bot_name,
        merge_messages,
        timestamp_mode,
    )

    marker_time = translate_timestamp_to_human_readable(read_mark, mode=timestamp_mode)
    marker = f"\n--- 以上消息是你已经思考过的内容已读 (标记时间: {marker_time}) ---\n--- 请关注以下未读的新消息---\n"

    # An empty side contributes nothing, so one concatenation covers all non-empty cases.
    if read_text or unread_text:
        return f"{read_text}{marker}{unread_text}"
    # Neither side produced text: return the bare marker line.
    return marker.strip()
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
    """
    Collect the distinct person_ids of the senders in a message list (the bot itself is ignored).

    Args:
        messages: list of message dicts.

    Returns:
        A list of unique person_ids, in no particular order.
    """
    unique_ids = set()  # a set removes duplicates for us
    for entry in messages:
        sender = entry.get("user_info", {})
        platform, user_id = sender.get("platform"), sender.get("user_id")
        # Both fields are needed to resolve a person
        if not (platform and user_id):
            continue
        # Leave out the bot's own messages
        if user_id == global_config.BOT_QQ:
            continue
        resolved = person_info_manager.get_person_id(platform, user_id)
        # Keep only ids that actually resolved
        if resolved:
            unique_ids.add(resolved)
    return list(unique_ids)

View File

@@ -0,0 +1,234 @@
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv, MessageSending, Message
from src.common.database import db
import time
import traceback
from typing import List
class InfoCatcher:
    """Collects what happened around one reply and stores it in the thinking_log collection."""

    def __init__(self):
        self.chat_history = []  # history before the trigger message (up to 3x the context length)
        self.context_length = global_config.observation_context_size
        self.chat_history_in_thinking = []  # messages that arrived while the reply was produced
        self.chat_history_after_response = []  # messages after the reply; not filled by this class

        self.chat_id = ""
        self.trigger_response_text = ""
        self.response_text = ""

        self.trigger_response_time = 0
        self.trigger_response_message = None

        self.response_time = 0
        self.response_messages = []

        # Data of the heartflow mode
        self.heartflow_data = {
            "heart_flow_prompt": "",
            "sub_heartflow_before": "",
            "sub_heartflow_now": "",
            "sub_heartflow_after": "",
            "sub_heartflow_model": "",
            "prompt": "",
            "response": "",
            "model": "",
        }

        # Data of the reasoning mode
        self.reasoning_data = {"thinking_log": "", "prompt": "", "response": "", "model": ""}

        # Durations of the individual steps
        self.timing_results = {
            "interested_rate_time": 0,
            "sub_heartflow_observe_time": 0,
            "sub_heartflow_step_time": 0,
            "make_response_time": 0,
        }

    def catch_decide_to_response(self, message: MessageRecv):
        """Record the trigger message, the decision time and the history preceding it."""
        self.trigger_response_message = message
        self.trigger_response_text = message.detailed_plain_text

        self.trigger_response_time = time.time()

        self.chat_id = message.chat_stream.stream_id

        self.chat_history = self.get_message_from_db_before_msg(message)

    def catch_after_observe(self, obs_duration: float):
        """Record how long the observation step took."""
        self.timing_results["sub_heartflow_observe_time"] = obs_duration

    def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
        """Record the sub-heartflow step duration and the previous / current mind.

        The last element of past_mind is stored as the previous mind; an empty (or None)
        past_mind stores "" instead of raising IndexError. The former len() check had two
        identical branches and is therefore collapsed.
        """
        self.timing_results["sub_heartflow_step_time"] = step_duration
        self.heartflow_data["sub_heartflow_before"] = past_mind[-1] if past_mind else ""
        self.heartflow_data["sub_heartflow_now"] = current_mind

    def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
        """Record prompt, response, reasoning log and model name of the LLM call.

        Everything is stored in reasoning_data regardless of the response mode;
        heartflow_data is not touched here.
        """
        self.reasoning_data["thinking_log"] = reasoning_content
        self.reasoning_data["prompt"] = prompt
        self.reasoning_data["response"] = response
        self.reasoning_data["model"] = model_name

        self.response_text = response

    def catch_after_generate_response(self, response_duration: float):
        """Record how long generating the response took."""
        self.timing_results["make_response_time"] = response_duration

    def catch_after_response(
        self, response_duration: float, response_message: List[str], first_bot_msg: MessageSending
    ):
        """Record the sent response and the messages that arrived between trigger and reply."""
        self.timing_results["make_response_time"] = response_duration
        self.response_time = time.time()
        self.response_messages.extend(response_message)

        self.chat_history_in_thinking = self.get_message_from_db_between_msgs(
            self.trigger_response_message, first_bot_msg
        )

    @staticmethod
    def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
        """Return the stored messages of the start message's chat strictly between both messages.

        Results are sorted by time, newest first. Any error is printed and yields [].
        """
        try:
            time_start = message_start.message_info.time
            time_end = message_end.message_info.time
            chat_id = message_start.chat_stream.stream_id

            print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")

            # Same chat_id, time strictly between start and end
            messages_between = db.messages.find(
                {"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
            ).sort("time", -1)

            result = list(messages_between)
            print(f"查询结果数量: {len(result)}")
            if result:
                print(f"第一条消息时间: {result[0]['time']}")
                print(f"最后一条消息时间: {result[-1]['time']}")
            return result
        except Exception as e:
            print(f"获取消息时出错: {str(e)}")
            return []

    def get_message_from_db_before_msg(self, message: MessageRecv):
        """Return up to context_length * 3 stored messages of the chat, newest first.

        The filter selects records whose message_id compares lower than the given message's id.
        """
        message_id = message.message_info.message_id
        chat_id = message.chat_stream.stream_id

        messages_before = (
            db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
            .sort("time", -1)
            .limit(self.context_length * 3)
        )

        return list(messages_before)

    def message_list_to_dict(self, message_list):
        """Reduce messages to {"time", "user_nickname", "processed_plain_text"} dicts.

        Accepts database records (nickname nested under "user_info") as well as message
        objects, which message_to_dict converts to a flat dict. Entries that convert to
        nothing are skipped and missing fields become None, so one odd entry no longer
        makes done_catch drop the whole log.
        """
        result = []
        for message in message_list:
            if not isinstance(message, dict):
                message = self.message_to_dict(message)
            if not message:
                continue
            user_info = message.get("user_info") or {}
            lite_message = {
                "time": message.get("time"),
                # nested for database records, flat for converted message objects
                "user_nickname": user_info.get("user_nickname", message.get("user_nickname")),
                "processed_plain_text": message.get("processed_plain_text"),
            }
            result.append(lite_message)
        return result

    @staticmethod
    def message_to_dict(message):
        """Convert a message object to a flat dict; dicts pass through, falsy input gives None."""
        if not message:
            return None
        if isinstance(message, dict):
            return message
        return {
            "time": message.message_info.time,
            "user_id": message.message_info.user_info.user_id,
            "user_nickname": message.message_info.user_info.user_nickname,
            "processed_plain_text": message.processed_plain_text,
        }

    def done_catch(self):
        """Store everything collected so far in the thinking_log collection.

        Returns True on success; on any error the error is printed and False is returned.
        """
        try:
            # Message objects are converted to plain dicts so the record can be stored
            thinking_log_data = {
                "chat_id": self.chat_id,
                "trigger_text": self.trigger_response_text,
                "response_text": self.response_text,
                "trigger_info": {
                    "time": self.trigger_response_time,
                    "message": self.message_to_dict(self.trigger_response_message),
                },
                "response_info": {
                    "time": self.response_time,
                    "message": self.response_messages,
                },
                "timing_results": self.timing_results,
                "chat_history": self.message_list_to_dict(self.chat_history),
                "chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
                "chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
                # Both mode-specific dicts are always stored
                "heartflow_data": self.heartflow_data,
                "reasoning_data": self.reasoning_data,
            }

            db.thinking_log.insert_one(thinking_log_data)

            return True
        except Exception as e:
            print(f"存储思考日志时出错: {str(e)} 喵~")
            print(traceback.format_exc())
            return False
class InfoCatcherManager:
    """Registry handing out one InfoCatcher per thinking_id."""

    def __init__(self):
        # thinking_id -> InfoCatcher
        self.info_catchers = {}

    def get_info_catcher(self, thinking_id: str) -> InfoCatcher:
        """Return the catcher registered for thinking_id, creating it on first request."""
        try:
            return self.info_catchers[thinking_id]
        except KeyError:
            catcher = InfoCatcher()
            self.info_catchers[thinking_id] = catcher
            return catcher


info_catcher_manager = InfoCatcherManager()

View File

@@ -0,0 +1,226 @@
import json
import logging
from typing import Any, Dict, TypeVar, List, Union, Tuple
import ast
# Type variable used for generic type hints
T = TypeVar("T")
# Module-level logger
logger = logging.getLogger("json_utils")
def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
    """
    Parse a JSON string without raising; return a default on failure.

    Standard JSON is tried first. If that fails with a decode error, the text is parsed as a
    Python literal (which covers single-quoted dicts); that fallback only accepts a dict.

    Args:
        json_str: the text to parse
        default_value: value returned when parsing fails

    Returns:
        The parsed Python object, or default_value when parsing fails
    """
    if not (json_str and isinstance(json_str, str)):
        logger.warning(f"safe_json_loads 接收到非字符串输入: {type(json_str)}, 值: {json_str}")
        return default_value

    # First attempt: strict JSON
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        # Not valid JSON: fall through to the literal parser below
        pass
    except Exception as e:
        logger.error(f"JSON解析过程中发生意外错误: {e}, 字符串: {json_str[:100]}...")
        return default_value

    # Second attempt: ast.literal_eval, which is safer than swapping quote characters
    try:
        literal = ast.literal_eval(json_str)
    except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
        logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...")
        return default_value
    except Exception as e:
        logger.error(f"使用 ast.literal_eval 解析时发生意外错误: {e}, 字符串: {json_str[:100]}...")
        return default_value

    # Callers expect a dict of arguments from this path
    if isinstance(literal, dict):
        return literal
    logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(literal)}, 内容: {literal}")
    return default_value
def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]:
    """
    Extract the arguments of an LLM tool-call object.

    Args:
        tool_call: the tool-call dict
        default_value: value returned when extraction fails (an empty dict when falsy)

    Returns:
        The parsed argument dict, or the default when extraction or parsing fails
    """
    fallback = default_value or {}

    if not (tool_call and isinstance(tool_call, dict)):
        logger.error(f"无效的工具调用对象: {tool_call}")
        return fallback

    try:
        # The arguments live under the "function" entry
        function_section = tool_call.get("function", {})
        if not (function_section and isinstance(function_section, dict)):
            logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
            return fallback

        raw_arguments = function_section.get("arguments", "{}")
        if raw_arguments:
            # Parse the JSON text
            return safe_json_loads(raw_arguments, fallback)
        return fallback
    except Exception as e:
        logger.error(f"提取工具调用参数时出错: {e}")
        return fallback
def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str:
    """
    Serialise a Python object to a JSON string without raising.

    Args:
        obj: the object to serialise
        default_value: value returned when serialisation fails
        ensure_ascii: escape non-ASCII characters (default False, so e.g. Chinese stays readable)
        pretty: indent the output by two spaces

    Returns:
        The JSON string, or default_value when serialisation fails
    """
    dump_options = {"ensure_ascii": ensure_ascii, "indent": 2 if pretty else None}
    try:
        return json.dumps(obj, **dump_options)
    except Exception as e:
        # Type errors get their own message; everything else counts as unexpected
        if isinstance(e, TypeError):
            logger.error(f"JSON序列化失败(类型错误): {e}")
        else:
            logger.error(f"JSON序列化过程中发生意外错误: {e}")
        return default_value
def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]:
    """
    Normalise an LLM response (e.g. a tuple) into a list.

    Args:
        response: the raw LLM response
        log_prefix: prefix for log lines

    Returns:
        Tuple (success flag, normalised response list, error message)
    """
    logger.debug(f"{log_prefix}原始人 LLM响应: {response}")

    if response is None:
        return False, [], "LLM响应为None"

    # Log the incoming type before any conversion
    logger.debug(f"{log_prefix}LLM响应原始类型: {type(response).__name__}")

    if isinstance(response, tuple):
        logger.debug(f"{log_prefix}将元组响应转换为列表")
        response = list(response)

    # Anything that is not a list by now cannot be handled
    if not isinstance(response, list):
        return False, [], f"无法处理的LLM响应类型: {type(response).__name__}"

    # A three-element response carries the tool calls in its last slot;
    # a tuple there is turned into a list as well
    if len(response) == 3 and isinstance(response[2], tuple):
        logger.debug(f"{log_prefix}将工具调用元组转换为列表")
        response[2] = list(response[2])

    return True, response, ""
def process_llm_tool_calls(
    tool_calls: List[Dict[str, Any]], log_prefix: str = ""
) -> Tuple[bool, List[Dict[str, Any]], str]:
    """
    Validate the tool-call list taken from an LLM response.

    Args:
        tool_calls: the tool-call list exactly as found in the LLM response
        log_prefix: prefix for log lines

    Returns:
        Tuple (success flag, list of valid tool calls, error message)
    """
    # An empty list simply means "no tool calls"; that is not an error
    if not tool_calls:
        return True, [], "工具调用列表为空"

    accepted = []
    for i, tool_call in enumerate(tool_calls):
        if not isinstance(tool_call, dict):
            logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}, 内容: {tool_call}")
            continue

        # Basic structure: type must be "function"
        if tool_call.get("type") != "function":
            logger.warning(
                f"{log_prefix}工具调用[{i}]不是function类型: type={tool_call.get('type', '未定义')}, 内容: {tool_call}"
            )
            continue

        # A missing key yields None, which also fails the dict check
        func_details = tool_call.get("function")
        if not isinstance(func_details, dict):
            logger.warning(f"{log_prefix}工具调用[{i}]缺少'function'字段或其类型不正确: {tool_call}")
            continue

        if not isinstance(func_details.get("name"), str):
            logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'name'或类型不正确: {func_details}")
            continue

        # 'arguments' must be present and be a string ...
        args_value = func_details.get("arguments")
        if not isinstance(args_value, str):
            logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'arguments'字符串: {func_details}")
            continue

        # ... that parses to a dict
        parsed_args = safe_json_loads(args_value, None)
        if not isinstance(parsed_args, dict):
            logger.warning(
                f"{log_prefix}工具调用[{i}]的'arguments'无法解析为有效的JSON字典, "
                f"原始字符串: {args_value[:100]}..., 解析结果类型: {type(parsed_args).__name__}"
            )
            continue

        # All checks passed: keep the original tool_call object
        accepted.append(tool_call)

    # The input was not empty, so an empty result means every entry was rejected
    if not accepted:
        return False, [], "所有工具调用格式均无效"

    return True, accepted, ""

View File

@@ -0,0 +1,88 @@
import sys
import loguru
from enum import Enum
class LogClassification(Enum):
    """Kinds of logger output that LogModule.setup_logger can configure."""

    # default console format
    BASE = "base"
    # memory system: console plus logs/memory.log
    MEMORY = "memory"
    # emoji system: console plus logs/emoji.log
    EMOJI = "emoji"
    # chat: console only
    CHAT = "chat"
    # prompt builder: console only
    PBUILDER = "promptbuilder"
class LogModule:
    """Configures the loguru logger with one of the predefined output formats."""

    logger = loguru.logger.opt()

    def __init__(self):
        pass

    def setup_logger(self, log_type: LogClassification):
        """Configure the log format.

        Existing handlers are removed first, then the handlers for log_type are added.

        Args:
            log_type: one of BASE (default console log), CHAT (chat log),
                PBUILDER (prompt builder log), MEMORY (memory system, console + file),
                EMOJI (emoji system, console + file). Unknown values fall back to BASE.

        Returns:
            The configured logger.
        """
        # Remove the handlers registered so far
        self.logger.remove()

        # Base format (a stray " d" in front of the module name has been removed)
        base_format = (
            "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | "
            "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
        )

        chat_format = (
            "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | "
            "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
        )

        # Memory system format
        memory_format = (
            "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | "
            "<light-magenta>海马体</light-magenta> | <level>{message}</level>"
        )

        # Emoji system format
        emoji_format = (
            "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>表情包</yellow> | "
            "<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
        )

        promptbuilder_format = (
            "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>Prompt</yellow> | "
            "<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
        )

        # Pick format and sinks by log type. Console sinks other than BASE pass no level.
        if log_type == LogClassification.CHAT:
            self.logger.add(sys.stderr, format=chat_format)
        elif log_type == LogClassification.PBUILDER:
            self.logger.add(sys.stderr, format=promptbuilder_format)
        elif log_type == LogClassification.MEMORY:
            # Console and file
            self.logger.add(sys.stderr, format=memory_format)
            self.logger.add("logs/memory.log", format=memory_format, level="INFO", rotation="1 day", retention="7 days")
        elif log_type == LogClassification.EMOJI:
            # Console and file
            self.logger.add(sys.stderr, format=emoji_format)
            self.logger.add("logs/emoji.log", format=emoji_format, level="INFO", rotation="1 day", retention="7 days")
        else:  # BASE
            self.logger.add(sys.stderr, format=base_format, level="INFO")

        return self.logger

View File

@@ -0,0 +1,237 @@
from typing import Dict, Any, Optional, List, Union
import re
from contextlib import asynccontextmanager
import asyncio
from src.common.logger import get_module_logger
# import traceback
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("prompt_build")
class PromptContext:
    """Keeps prompt templates per context id and tracks which context is current."""

    def __init__(self):
        # context id -> {prompt name -> Prompt}
        self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
        self._current_context: Optional[str] = None
        # Guards reads and writes of the two attributes above
        self._context_lock = asyncio.Lock()

    @asynccontextmanager
    async def async_scope(self, context_id: str):
        """Make context_id the current context for the duration of the `async with` body."""
        async with self._context_lock:
            self._context_prompts.setdefault(context_id, {})
            outer_context = self._current_context
            self._current_context = context_id
        try:
            yield self
        finally:
            # Put back whatever context was current before entering
            async with self._context_lock:
                self._current_context = outer_context

    async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
        """Return the prompt called `name` from the current context, or None."""
        async with self._context_lock:
            active = self._current_context
            if not active:
                return None
            scoped_prompts = self._context_prompts[active]
            if name in scoped_prompts:
                return scoped_prompts[name]
            return None

    async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
        """Store a prompt in context_id, or in the current context when context_id is not given.

        Nothing is stored when neither is available.
        """
        async with self._context_lock:
            scope_id = context_id or self._current_context
            if not scope_id:
                return
            bucket = self._context_prompts.setdefault(scope_id, {})
            bucket[prompt.name] = prompt
class PromptManager:
    """Holds the global prompt templates and gives access to context-scoped ones."""

    def __init__(self):
        self._prompts = {}  # global prompts by name
        self._counter = 0  # source of generated names
        self._context = PromptContext()
        self._lock = asyncio.Lock()

    @asynccontextmanager
    async def async_message_scope(self, message_id: str):
        """Open an async temporary prompt scope for handling one message."""
        async with self._context.async_scope(message_id):
            yield self

    async def get_prompt_async(self, name: str) -> "Prompt":
        """Look a prompt up in the current context first, then among the global prompts.

        Raises:
            KeyError: when no prompt of that name exists in either place.
        """
        scoped = await self._context.get_prompt_async(name)
        if scoped is not None:
            return scoped
        # Not in the current context: use the global templates
        async with self._lock:
            if name in self._prompts:
                return self._prompts[name]
            raise KeyError(f"Prompt '{name}' not found")

    def generate_name(self, template: str) -> str:
        """Produce a name for an unnamed prompt; the template text is not used."""
        self._counter += 1
        serial = self._counter
        return f"prompt_{serial}"

    def register(self, prompt: "Prompt") -> None:
        """Register a prompt globally, naming it first when it has no name."""
        if not prompt.name:
            prompt.name = self.generate_name(prompt.template)
        key = prompt.name
        self._prompts[key] = prompt

    def add_prompt(self, name: str, fstr: str) -> "Prompt":
        """Create a Prompt from a template string, store it globally and return it."""
        created = Prompt(fstr, name=name)
        self._prompts[created.name] = created
        return created

    async def format_prompt(self, name: str, **kwargs) -> str:
        """Fetch the prompt called `name` and format it with kwargs."""
        found = await self.get_prompt_async(name)
        return found.format(**kwargs)
# Global singleton
global_prompt_manager = PromptManager()
class Prompt(str):
    """A named prompt template that is also a ``str``.

    The instance keeps the raw template in ``template``, the placeholder
    names in ``args`` and any values supplied at construction in ``_args`` /
    ``_kwargs``.  When values were supplied the string value is the formatted
    text; otherwise ``str(prompt)`` yields the raw template.

    Literal braces are written as ``\\{`` and ``\\}`` in a template.
    """

    # Temporary markers standing in for escaped braces while the template goes
    # through ``str.format``.
    _TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
    _TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"

    @staticmethod
    def _process_escaped_braces(template: str) -> str:
        r"""Replace the escaped braces ``\{`` and ``\}`` with temporary markers."""
        return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)

    @staticmethod
    def _restore_escaped_braces(template: str) -> str:
        """Turn the temporary markers back into real brace characters."""
        return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")

    @staticmethod
    def _parse_template_args(processed_template: str) -> List[str]:
        """Return the unique, non-empty placeholder names in first-seen order.

        ``processed_template`` must already have its escaped braces replaced.
        """
        names = []
        for expr in re.findall(r"\{(.*?)}", processed_template):
            if expr and expr not in names:
                names.append(expr)
        return names

    def __new__(cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs):
        """Create a prompt from ``fstr``.

        :param fstr: template text with ``{placeholder}`` fields
        :param name: registry name; the manager generates one when it is empty
        :param args: positional values, matched to placeholders in order
        :param kwargs: keyword values; the private ``_should_register`` flag
            (default ``True``) controls automatic registration
        :raises ValueError: when formatting with the supplied values fails
        """
        # Normalise a tuple of positional values to a list.
        if isinstance(args, tuple):
            args = list(args)
        should_register = kwargs.pop("_should_register", True)

        template_args = cls._parse_template_args(cls._process_escaped_braces(fstr))

        # Format immediately when initial values were supplied.
        if kwargs or args:
            formatted = cls._format_template(fstr, args=args, kwargs=kwargs)
            obj = super().__new__(cls, formatted)
        else:
            obj = super().__new__(cls, "")

        obj.template = fstr
        obj.name = name
        obj.args = template_args
        obj._args = args or []
        obj._kwargs = kwargs

        # Register globally only when no context is active.  With an active
        # context nothing is registered here; create_async() performs the
        # context-scoped registration.
        if should_register and not global_prompt_manager._context._current_context:
            global_prompt_manager.register(obj)
        return obj

    @classmethod
    async def create_async(
        cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs
    ):
        """Create a prompt and, if a context is active, register it there."""
        prompt = cls(fstr, name, args, **kwargs)
        if global_prompt_manager._context._current_context:
            await global_prompt_manager._context.register_async(prompt)
        return prompt

    @classmethod
    def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
        """Substitute ``args`` and ``kwargs`` into ``template``.

        Positional values are bound to placeholders in order of first
        appearance.  Values that are themselves ``Prompt`` objects are
        formatted first.  Positional and keyword values are applied in a
        single ``str.format`` pass, so each may fill part of the template and
        substituted text is never re-interpreted as a template.  A positional
        value wins over a keyword value for the same placeholder.

        :raises ValueError: when more positional values than placeholders are
            given, or when a placeholder has no value
        """
        kwargs = kwargs or {}
        processed_template = cls._process_escaped_braces(template)
        template_args = cls._parse_template_args(processed_template)

        formatted_args = {}
        formatted_kwargs = {}

        # Positional values.
        if args:
            for i, arg in enumerate(args):
                if i >= len(template_args):
                    logger.error(
                        f"构建提示词模板失败,解析到的参数列表{template_args},长度为{len(template_args)},输入的参数列表为{args},提示词模板为{template}"
                    )
                    raise ValueError("格式化模板失败")
                if isinstance(arg, Prompt):
                    formatted_args[template_args[i]] = arg.format(**kwargs)
                else:
                    formatted_args[template_args[i]] = arg

        # Keyword values; a nested Prompt is formatted with the other keywords.
        for key, value in kwargs.items():
            if isinstance(value, Prompt):
                remaining_kwargs = {k: v for k, v in kwargs.items() if k != key}
                formatted_kwargs[key] = value.format(**remaining_kwargs)
            else:
                formatted_kwargs[key] = value

        try:
            if args or kwargs:
                # One pass only: positional entries are unpacked last so they
                # take precedence over keyword entries.
                processed_template = processed_template.format(**{**formatted_kwargs, **formatted_args})
            return cls._restore_escaped_braces(processed_template)
        except (IndexError, KeyError) as e:
            raise ValueError(
                f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
            ) from e

    def format(self, *args, **kwargs) -> "str":
        """Return the template formatted with the given values.

        Values passed here replace the ones stored at construction; when none
        are passed, the stored ones are used.  The temporary prompt built for
        this is never registered.
        """
        ret = type(self)(
            self.template,
            self.name,
            args=list(args) if args else self._args,
            _should_register=False,
            **kwargs if kwargs else self._kwargs,
        )
        return str(ret)

    def __str__(self) -> str:
        # Formatted text when values were stored, raw template otherwise.
        if self._kwargs or self._args:
            return super().__str__()
        return self.template

    def __repr__(self) -> str:
        return f"Prompt(template='{self.template}', name='{self.name}')"

760
src/chat/utils/statistic.py Normal file
View File

@@ -0,0 +1,760 @@
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
from src.common.logger import get_module_logger
from src.manager.async_task_manager import AsyncTask
from ...common.database import db
from src.manager.local_store_manager import local_storage
logger = get_module_logger("maibot_statistic")
# Keys of the per-period statistics dictionaries built by the collectors below.
# LLM request totals and per-type / per-user / per-model breakdowns:
TOTAL_REQ_CNT = "total_requests"
TOTAL_COST = "total_cost"
REQ_CNT_BY_TYPE = "requests_by_type"
REQ_CNT_BY_USER = "requests_by_user"
REQ_CNT_BY_MODEL = "requests_by_model"
# Input (prompt) tokens:
IN_TOK_BY_TYPE = "in_tokens_by_type"
IN_TOK_BY_USER = "in_tokens_by_user"
IN_TOK_BY_MODEL = "in_tokens_by_model"
# Output (completion) tokens:
OUT_TOK_BY_TYPE = "out_tokens_by_type"
OUT_TOK_BY_USER = "out_tokens_by_user"
OUT_TOK_BY_MODEL = "out_tokens_by_model"
# Input + output tokens:
TOTAL_TOK_BY_TYPE = "tokens_by_type"
TOTAL_TOK_BY_USER = "tokens_by_user"
TOTAL_TOK_BY_MODEL = "tokens_by_model"
# Cost breakdowns:
COST_BY_TYPE = "costs_by_type"
COST_BY_USER = "costs_by_user"
COST_BY_MODEL = "costs_by_model"
# Online time (seconds) and message counts:
ONLINE_TIME = "online_time"
TOTAL_MSG_CNT = "total_messages"
MSG_CNT_BY_CHAT = "messages_by_chat"
class OnlineTimeRecordTask(AsyncTask):
    """Periodic task that records how long the bot has been online.

    Every run (``run_interval=60``) pushes the ``end_timestamp`` of an
    ``online_time`` document one minute past the current time, creating the
    document or adopting a recent one on the first run.
    """

    def __init__(self):
        super().__init__(task_name="Online Time Record Task", run_interval=60)

        # _id of the online_time document this task is extending; None until
        # run() has created or adopted one.
        self.record_id: str | None = None

        self._init_database()

    @staticmethod
    def _init_database():
        """Make sure the ``online_time`` collection and its index exist."""
        if "online_time" not in db.list_collection_names():
            db.create_collection("online_time")

        # The previous guard tested a tuple for membership in list_indexes(),
        # which yields index documents, so it was always true.
        # NOTE(review): assumes `db` is a PyMongo database, where create_index
        # is a no-op for an index that already exists - confirm.
        db.online_time.create_index([("end_timestamp", 1)])

    async def run(self):
        """Extend the current online-time record, or start/adopt one.

        Any exception is logged and swallowed so the periodic task keeps
        running.
        """
        try:
            current_time = datetime.now()
            # The end is marked one minute ahead; the next run moves it again.
            new_end = current_time + timedelta(minutes=1)

            if not self.record_id:
                # Adopt a record that ended within the last minute, so a quick
                # restart continues the same online session.
                recent_record = db.online_time.find_one(
                    {"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}}
                )
                if not recent_record:
                    self.record_id = db.online_time.insert_one(
                        {
                            "start_timestamp": current_time,
                            "end_timestamp": new_end,
                        }
                    ).inserted_id
                    return
                self.record_id = recent_record["_id"]

            db.online_time.update_one(
                {"_id": self.record_id},
                {"$set": {"end_timestamp": new_end}},
            )
        except Exception:
            logger.exception("在线时间记录失败")
def _format_online_time(online_seconds: int) -> str:
"""
格式化在线时间
:param online_seconds: 在线时间(秒)
:return: 格式化后的在线时间字符串
"""
total_oneline_time = timedelta(seconds=online_seconds)
days = total_oneline_time.days
hours = total_oneline_time.seconds // 3600
minutes = (total_oneline_time.seconds // 60) % 60
seconds = total_oneline_time.seconds % 60
if days > 0:
# 如果在线时间超过1天则格式化为"X天X小时X分钟"
total_oneline_time_str = f"{total_oneline_time.days}{hours}小时{minutes}分钟{seconds}"
elif hours > 0:
# 如果在线时间超过1小时则格式化为"X小时X分钟X秒"
total_oneline_time_str = f"{hours}小时{minutes}分钟{seconds}"
else:
# 其他情况格式化为"X分钟X秒"
total_oneline_time_str = f"{minutes}分钟{seconds}"
return total_oneline_time_str
class StatisticOutputTask(AsyncTask):
    """Periodic task that aggregates usage statistics and publishes them.

    Each run collects LLM-request, online-time and message statistics for
    every period in ``stat_period``, logs the last-hour summary and rewrites
    the HTML report at ``record_file_path``.
    """

    SEP_LINE = "-" * 84

    def __init__(self, record_file_path: str = "maibot_statistics.html"):
        """
        :param record_file_path: path of the HTML report file to write
        """
        # Starts without delay (wait_before_start=0) and re-runs every 300 seconds.
        super().__init__(task_name="Statistics Data Output Task", wait_before_start=0, run_interval=300)

        # {chat id: (contact/group display name, timestamp of the message the
        # name was taken from)}.  The timestamp lets a newer message refresh
        # the name.
        self.name_mapping: Dict[str, Tuple[str, float]] = {}

        # Path of the HTML report file.
        self.record_file_path: str = record_file_path

        now = datetime.now()
        if "deploy_time" in local_storage:
            # A stored deploy time is the start of the all-time period.
            deploy_time = datetime.fromtimestamp(local_storage["deploy_time"])
        else:
            # Otherwise cover the widest range once and remember "now" as the
            # deploy time.
            deploy_time = datetime(2000, 1, 1)
            local_storage["deploy_time"] = now.timestamp()

        # [(period key, period length, description), ...]; the "all_time"
        # entry is required by the rest of the class.
        self.stat_period: List[Tuple[str, timedelta, str]] = [
            ("all_time", now - deploy_time, "自部署以来"),
            ("last_7_days", timedelta(days=7), "最近7天"),
            ("last_24_hours", timedelta(days=1), "最近24小时"),
            ("last_hour", timedelta(hours=1), "最近1小时"),
        ]

    def _chat_display_name(self, chat_id: str) -> str:
        """Return the known display name for ``chat_id``, or the id itself.

        Chats that only occur in the persisted all-time statistics have no
        entry in ``name_mapping``; falling back to the id keeps the report
        from failing with ``KeyError``.
        """
        name = self.name_mapping.get(chat_id, (None, 0.0))[0]
        return str(name) if name else chat_id

    def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
        """
        Log the last-hour statistics.

        :param stats: statistics of all periods, keyed by period key
        :param now: reference "current" time of this run
        """
        # The header shows where the last-hour window starts.
        last_hour_start = now - timedelta(hours=1)
        output = [
            self.SEP_LINE,
            f" 最近1小时的统计数据 (自{last_hour_start.strftime('%Y-%m-%d %H:%M:%S')}开始,详细信息见文件:{self.record_file_path})",
            self.SEP_LINE,
            self._format_total_stat(stats["last_hour"]),
            "",
            self._format_model_classified_stat(stats["last_hour"]),
            "",
            self._format_chat_stat(stats["last_hour"]),
            self.SEP_LINE,
            "",
        ]

        logger.info("\n" + "\n".join(output))

    async def run(self):
        """Collect the statistics, log them and rewrite the HTML report.

        Exceptions are logged and swallowed so the periodic task keeps running.
        """
        try:
            now = datetime.now()
            stats = self._collect_all_statistics(now)
            self._statistic_console_output(stats, now)
            self._generate_html_report(stats, now)
        except Exception as e:
            logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}")

    # -- Statistics collection --

    @staticmethod
    def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
        """
        Collect LLM request statistics for the given periods.

        :param collect_period: ``(period key, period start)`` pairs
        :return: ``{period key: statistics}``; empty when no period is given
        """
        if not collect_period:
            return {}

        # Latest start first.  A sorted copy is used so the caller's list is
        # left untouched.
        periods = sorted(collect_period, key=lambda x: x[1], reverse=True)

        stats = {
            period_key: {
                # Total number of LLM requests
                TOTAL_REQ_CNT: 0,
                # Request counts
                REQ_CNT_BY_TYPE: defaultdict(int),
                REQ_CNT_BY_USER: defaultdict(int),
                REQ_CNT_BY_MODEL: defaultdict(int),
                # Input tokens
                IN_TOK_BY_TYPE: defaultdict(int),
                IN_TOK_BY_USER: defaultdict(int),
                IN_TOK_BY_MODEL: defaultdict(int),
                # Output tokens
                OUT_TOK_BY_TYPE: defaultdict(int),
                OUT_TOK_BY_USER: defaultdict(int),
                OUT_TOK_BY_MODEL: defaultdict(int),
                # Total tokens
                TOTAL_TOK_BY_TYPE: defaultdict(int),
                TOTAL_TOK_BY_USER: defaultdict(int),
                TOTAL_TOK_BY_MODEL: defaultdict(int),
                # Total cost
                TOTAL_COST: 0.0,
                # Cost breakdowns
                COST_BY_TYPE: defaultdict(float),
                COST_BY_USER: defaultdict(float),
                COST_BY_MODEL: defaultdict(float),
            }
            for period_key, _ in periods
        }

        # Fetch everything since the earliest period start.
        for record in db.llm_usage.find({"timestamp": {"$gte": periods[-1][1]}}):
            record_timestamp = record.get("timestamp")
            request_type = record.get("request_type", "unknown")
            user_id = str(record.get("user_id", "unknown"))
            model_name = record.get("model_name", "unknown")
            prompt_tokens = record.get("prompt_tokens", 0)
            completion_tokens = record.get("completion_tokens", 0)
            total_tokens = prompt_tokens + completion_tokens
            cost = record.get("cost", 0.0)

            for idx, (_, period_start) in enumerate(periods):
                if record_timestamp >= period_start:
                    # A record inside this period is also inside every period
                    # that starts earlier, so update all of them and stop.
                    for period_key, _ in periods[idx:]:
                        period_stats = stats[period_key]
                        period_stats[TOTAL_REQ_CNT] += 1

                        period_stats[REQ_CNT_BY_TYPE][request_type] += 1
                        period_stats[REQ_CNT_BY_USER][user_id] += 1
                        period_stats[REQ_CNT_BY_MODEL][model_name] += 1

                        period_stats[IN_TOK_BY_TYPE][request_type] += prompt_tokens
                        period_stats[IN_TOK_BY_USER][user_id] += prompt_tokens
                        period_stats[IN_TOK_BY_MODEL][model_name] += prompt_tokens

                        period_stats[OUT_TOK_BY_TYPE][request_type] += completion_tokens
                        period_stats[OUT_TOK_BY_USER][user_id] += completion_tokens
                        period_stats[OUT_TOK_BY_MODEL][model_name] += completion_tokens

                        period_stats[TOTAL_TOK_BY_TYPE][request_type] += total_tokens
                        period_stats[TOTAL_TOK_BY_USER][user_id] += total_tokens
                        period_stats[TOTAL_TOK_BY_MODEL][model_name] += total_tokens

                        period_stats[TOTAL_COST] += cost
                        period_stats[COST_BY_TYPE][request_type] += cost
                        period_stats[COST_BY_USER][user_id] += cost
                        period_stats[COST_BY_MODEL][model_name] += cost
                    break
        return stats

    @staticmethod
    def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
        """
        Collect online-time statistics for the given periods.

        :param collect_period: ``(period key, period start)`` pairs
        :param now: reference "current" time; record ends are clamped to it
        :return: ``{period key: {ONLINE_TIME: seconds}}``; empty when no period is given
        """
        if not collect_period:
            return {}

        # Latest start first; sorted copy, the caller's list is left untouched.
        periods = sorted(collect_period, key=lambda x: x[1], reverse=True)

        stats = {period_key: {ONLINE_TIME: 0.0} for period_key, _ in periods}

        for record in db.online_time.find({"end_timestamp": {"$gte": periods[-1][1]}}):
            end_timestamp: datetime = record.get("end_timestamp")
            start_timestamp: datetime = record.get("start_timestamp")
            for idx, (_, period_start) in enumerate(periods):
                if end_timestamp >= period_start:
                    # end_timestamp is written ahead of time, so it may lie in
                    # the future; use `now` as the end in that case.
                    if end_timestamp > now:
                        end_timestamp = now
                    # The record also overlaps every period that starts
                    # earlier.  Count only the part inside each period.
                    for period_key, _period_start in periods[idx:]:
                        effective_start = max(start_timestamp, _period_start)
                        stats[period_key][ONLINE_TIME] += (end_timestamp - effective_start).total_seconds()
                    break
        return stats

    def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
        """
        Collect message statistics for the given periods.

        Also refreshes ``name_mapping`` from the scanned messages.

        :param collect_period: ``(period key, period start)`` pairs
        :return: ``{period key: statistics}``; empty when no period is given
        """
        if not collect_period:
            return {}

        # Latest start first; sorted copy, the caller's list is left untouched.
        periods = sorted(collect_period, key=lambda x: x[1], reverse=True)

        stats = {
            period_key: {
                TOTAL_MSG_CNT: 0,
                MSG_CNT_BY_CHAT: defaultdict(int),
            }
            for period_key, _ in periods
        }

        for message in db.messages.find({"time": {"$gte": periods[-1][1].timestamp()}}):
            chat_info = message.get("chat_info", None)
            user_info = message.get("user_info", None)  # sender of the message
            message_time = message.get("time", 0)

            group_info = chat_info.get("group_info") if chat_info else None
            if group_info is not None:
                # Group chat: key and name come from the group.
                chat_id = f"g{group_info.get('group_id')}"
                chat_name = group_info.get("group_name", f"{group_info.get('group_id')}")
            elif user_info:
                # No group information: fall back to the sender.
                chat_id = f"u{user_info['user_id']}"
                chat_name = user_info["user_nickname"]
            else:
                continue  # neither group nor user information

            if chat_id in self.name_mapping:
                if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]:
                    # Name changed and this message is newer than the stored one.
                    self.name_mapping[chat_id] = (chat_name, message_time)
            else:
                self.name_mapping[chat_id] = (chat_name, message_time)

            for idx, (_, period_start) in enumerate(periods):
                if message_time >= period_start.timestamp():
                    # Also counts for every period that starts earlier.
                    for period_key, _ in periods[idx:]:
                        stats[period_key][TOTAL_MSG_CNT] += 1
                        stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
                    break

        return stats

    def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
        """
        Collect the statistics of every period.

        When a previous full result is stored, the "all_time" period is only
        collected since that result and the stored values are added on top.

        :param now: reference "current" time
        :return: ``{period key: statistics}``
        """
        last_all_time_stat = None

        if "last_full_statistics_timestamp" in local_storage and "last_full_statistics" in local_storage:
            # Incremental mode: "all_time" starts at the last stored result.
            last_full_stat_ts: float = local_storage["last_full_statistics_timestamp"]
            last_all_time_stat = local_storage["last_full_statistics"]
            self.stat_period = [item for item in self.stat_period if item[0] != "all_time"]
            self.stat_period.append(("all_time", now - datetime.fromtimestamp(last_full_stat_ts), "自部署以来的"))

        stat_start_timestamp = [(period[0], now - period[1]) for period in self.stat_period]

        stat = {item[0]: {} for item in self.stat_period}

        model_req_stat = self._collect_model_request_for_period(stat_start_timestamp)
        online_time_stat = self._collect_online_time_for_period(stat_start_timestamp, now)
        message_count_stat = self._collect_message_count_for_period(stat_start_timestamp)

        # Merge the three kinds of statistics per period.
        for period_key, _ in stat_start_timestamp:
            stat[period_key].update(model_req_stat[period_key])
            stat[period_key].update(online_time_stat[period_key])
            stat[period_key].update(message_count_stat[period_key])

        if last_all_time_stat:
            # Add the stored all-time values to the increment.
            for key, val in last_all_time_stat.items():
                if isinstance(val, dict):
                    for sub_key, sub_val in val.items():
                        stat["all_time"][key][sub_key] += sub_val
                else:
                    stat["all_time"][key] += val

        # Persist the new full result and its timestamp for the next run.
        local_storage["last_full_statistics_timestamp"] = now.timestamp()
        local_storage["last_full_statistics"] = stat["all_time"]

        return stat

    # -- Statistics formatting --

    @staticmethod
    def _format_total_stat(stats: Dict[str, Any]) -> str:
        """
        Format the totals of one period as text.
        """
        output = [
            f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}",
            f"总消息数: {stats[TOTAL_MSG_CNT]}",
            f"总请求数: {stats[TOTAL_REQ_CNT]}",
            f"总花费: {stats[TOTAL_COST]:.4f}¥",
            "",
        ]

        return "\n".join(output)

    @staticmethod
    def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
        """
        Format the per-model statistics of one period as a text table.

        Returns an empty string when the period has no requests.
        """
        if stats[TOTAL_REQ_CNT] <= 0:
            return ""

        data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.4f}¥"

        output = [
            "按模型分类统计:",
            " 模型名称 调用次数 输入Token 输出Token Token总量 累计花费",
        ]
        for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()):
            # Keep the name column at 32 characters.
            name = model_name[:29] + "..." if len(model_name) > 32 else model_name
            in_tokens = stats[IN_TOK_BY_MODEL][model_name]
            out_tokens = stats[OUT_TOK_BY_MODEL][model_name]
            tokens = stats[TOTAL_TOK_BY_MODEL][model_name]
            cost = stats[COST_BY_MODEL][model_name]
            output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost))

        output.append("")
        return "\n".join(output)

    def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
        """
        Format the per-chat message counts of one period as a text table.

        Returns an empty string when the period has no messages.
        """
        if stats[TOTAL_MSG_CNT] <= 0:
            return ""

        output = ["聊天消息统计:", " 联系人/群组名称 消息数量"]
        for chat_id, count in sorted(stats[MSG_CNT_BY_CHAT].items()):
            output.append(f"{self._chat_display_name(chat_id)[:32]:<32} {count:>10}")

        output.append("")
        return "\n".join(output)

    def _generate_html_report(self, stat: dict[str, Any], now: datetime):
        """
        Write the statistics as an HTML report to ``record_file_path``.

        :param stat: statistics of all periods, keyed by period key
        :param now: reference "current" time
        """
        # Local import: names shown in the report (chat names, model names,
        # request types, user ids) are escaped before being placed in markup.
        from html import escape

        tab_list = [
            f'<button class="tab-link" onclick="showTab(event, \'{period[0]}\')">{period[2]}</button>'
            for period in self.stat_period
        ]

        def _format_stat_data(stat_data: dict[str, Any], div_id: str, start_time: datetime) -> str:
            """
            Render the statistics of one period as an HTML div block.

            :param stat_data: statistics of the period
            :param div_id: id of the div (the period key)
            :param start_time: start of the period
            """
            # Per-model table rows
            model_rows = "\n".join([
                f"<tr>"
                f"<td>{escape(str(model_name))}</td>"
                f"<td>{count}</td>"
                f"<td>{stat_data[IN_TOK_BY_MODEL][model_name]}</td>"
                f"<td>{stat_data[OUT_TOK_BY_MODEL][model_name]}</td>"
                f"<td>{stat_data[TOTAL_TOK_BY_MODEL][model_name]}</td>"
                f"<td>{stat_data[COST_BY_MODEL][model_name]:.4f} ¥</td>"
                f"</tr>"
                for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
            ])
            # Per-request-type table rows
            type_rows = "\n".join([
                f"<tr>"
                f"<td>{escape(str(req_type))}</td>"
                f"<td>{count}</td>"
                f"<td>{stat_data[IN_TOK_BY_TYPE][req_type]}</td>"
                f"<td>{stat_data[OUT_TOK_BY_TYPE][req_type]}</td>"
                f"<td>{stat_data[TOTAL_TOK_BY_TYPE][req_type]}</td>"
                f"<td>{stat_data[COST_BY_TYPE][req_type]:.4f} ¥</td>"
                f"</tr>"
                for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
            ])
            # Per-user table rows
            user_rows = "\n".join([
                f"<tr>"
                f"<td>{escape(str(user_id))}</td>"
                f"<td>{count}</td>"
                f"<td>{stat_data[IN_TOK_BY_USER][user_id]}</td>"
                f"<td>{stat_data[OUT_TOK_BY_USER][user_id]}</td>"
                f"<td>{stat_data[TOTAL_TOK_BY_USER][user_id]}</td>"
                f"<td>{stat_data[COST_BY_USER][user_id]:.4f} ¥</td>"
                f"</tr>"
                for user_id, count in sorted(stat_data[REQ_CNT_BY_USER].items())
            ])
            # Per-chat message rows
            chat_rows = "\n".join([
                f"<tr><td>{escape(self._chat_display_name(chat_id))}</td><td>{count}</td></tr>"
                for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
            ])
            return f"""
<div id=\"{div_id}\" class=\"tab-content\">
<p class=\"info-item\">
<strong>统计时段: </strong>
{start_time.strftime("%Y-%m-%d %H:%M:%S")} ~ {now.strftime("%Y-%m-%d %H:%M:%S")}
</p>
<p class=\"info-item\"><strong>总在线时间: </strong>{_format_online_time(stat_data[ONLINE_TIME])}</p>
<p class=\"info-item\"><strong>总消息数: </strong>{stat_data[TOTAL_MSG_CNT]}</p>
<p class=\"info-item\"><strong>总请求数: </strong>{stat_data[TOTAL_REQ_CNT]}</p>
<p class=\"info-item\"><strong>总花费: </strong>{stat_data[TOTAL_COST]:.4f} ¥</p>
<h2>按模型分类统计</h2>
<table>
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th></tr></thead>
<tbody>
{model_rows}
</tbody>
</table>
<h2>按请求类型分类统计</h2>
<table>
<thead>
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th></tr>
</thead>
<tbody>
{type_rows}
</tbody>
</table>
<h2>按用户分类统计</h2>
<table>
<thead>
<tr><th>用户名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th></tr>
</thead>
<tbody>
{user_rows}
</tbody>
</table>
<h2>聊天消息统计</h2>
<table>
<thead>
<tr><th>联系人/群组名称</th><th>消息数量</th></tr>
</thead>
<tbody>
{chat_rows}
</tbody>
</table>
</div>
"""

        # Panels follow the same order as the tab buttons, because the page
        # script activates the first button together with the first panel.
        # The all-time panel shows the deploy time as its start; the stored
        # period length may only cover the span since the last full result.
        deploy_time = datetime.fromtimestamp(local_storage["deploy_time"])
        tab_content_list = [
            _format_stat_data(
                stat[period[0]],
                period[0],
                deploy_time if period[0] == "all_time" else now - period[1],
            )
            for period in self.stat_period
        ]

        joined_tab_list = "\n".join(tab_list)
        joined_tab_content = "\n".join(tab_content_list)

        html_template = (
            """
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MaiBot运行统计报告</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
margin: 0;
padding: 20px;
background-color: #f4f7f6;
color: #333;
line-height: 1.6;
}
.container {
max-width: 900px;
margin: 20px auto;
background-color: #fff;
padding: 25px;
border-radius: 8px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
h1, h2 {
color: #2c3e50;
border-bottom: 2px solid #3498db;
padding-bottom: 10px;
margin-top: 0;
}
h1 {
text-align: center;
font-size: 2em;
}
h2 {
font-size: 1.5em;
margin-top: 30px;
}
p {
margin-bottom: 10px;
}
.info-item {
background-color: #ecf0f1;
padding: 8px 12px;
border-radius: 4px;
margin-bottom: 8px;
font-size: 0.95em;
}
.info-item strong {
color: #2980b9;
}
table {
width: 100%;
border-collapse: collapse;
margin-top: 15px;
font-size: 0.9em;
}
th, td {
border: 1px solid #ddd;
padding: 10px;
text-align: left;
}
th {
background-color: #3498db;
color: white;
font-weight: bold;
}
tr:nth-child(even) {
background-color: #f9f9f9;
}
.footer {
text-align: center;
margin-top: 30px;
font-size: 0.8em;
color: #7f8c8d;
}
.tabs {
overflow: hidden;
background: #ecf0f1;
display: flex;
}
.tabs button {
background: inherit; border: none; outline: none;
padding: 14px 16px; cursor: pointer;
transition: 0.3s; font-size: 16px;
}
.tabs button:hover {
background-color: #d4dbdc;
}
.tabs button.active {
background-color: #b3bbbd;
}
.tab-content {
display: none;
padding: 20px;
background-color: #fff;
border: 1px solid #ccc;
}
.tab-content.active {
display: block;
}
</style>
</head>
<body>
"""
            + f"""
<div class="container">
<h1>MaiBot运行统计报告</h1>
<p class="info-item"><strong>统计截止时间:</strong> {now.strftime("%Y-%m-%d %H:%M:%S")}</p>
<div class="tabs">
{joined_tab_list}
</div>
{joined_tab_content}
</div>
"""
            # Plain (non-f) string: braces are written once.
            + """
<script>
let i, tab_content, tab_links;
tab_content = document.getElementsByClassName("tab-content");
tab_links = document.getElementsByClassName("tab-link");
tab_content[0].classList.add("active");
tab_links[0].classList.add("active");
function showTab(evt, tabName) {
for (i = 0; i < tab_content.length; i++) tab_content[i].classList.remove("active");
for (i = 0; i < tab_links.length; i++) tab_links[i].classList.remove("active");
document.getElementById(tabName).classList.add("active");
evt.currentTarget.classList.add("active");
}
</script>
</body>
</html>
"""
        )

        with open(self.record_file_path, "w", encoding="utf-8") as f:
            f.write(html_template)

View File

@@ -0,0 +1,155 @@
from time import perf_counter
from functools import wraps
from typing import Optional, Dict, Callable
import asyncio
from rich.traceback import install
install(extra_lines=3)
"""
# 更好的计时器
使用形式:
- 上下文
- 装饰器
- 直接实例化
使用场景:
- 使用Timer在需要测量代码执行时间时如性能测试、计时器工具Timer类是更可靠、高精度的选择。
- 使用time.time()的场景:当需要记录实际时间点(如日志、时间戳)时使用,但避免用它测量时间间隔。
使用方式:
【装饰器】
time_dict = {}
@Timer("计数", time_dict)
def func():
pass
print(time_dict)
【上下文_1】
def func():
with Timer() as t:
pass
print(t)
print(t.human_readable)
【上下文_2】
def func():
time_dict = {}
with Timer("计数", time_dict):
pass
print(time_dict)
【直接实例化】
a = Timer()
print(a) # 直接输出当前 perf_counter 值
参数:
- name计时器的名字默认为 None
- storage计时器结果存储字典默认为 None
- auto_unit自动选择单位毫秒或秒默认为 True自动根据时间切换毫秒或秒
- do_type_check是否进行类型检查默认为 False不进行类型检查
属性human_readable
自定义错误TimerTypeError
"""
class TimerTypeError(TypeError):
    """Raised by Timer's optional type check when an argument has the wrong type."""

    __slots__ = ()

    def __init__(self, param, expected_type, actual_type):
        """
        :param param: name of the offending parameter
        :param expected_type: description of the expected type
        :param actual_type: the type that was actually received
        """
        message = f"参数 '{param}' 类型错误,期望 {expected_type},实际得到 {actual_type.__name__}"
        super().__init__(message)
class Timer:
    """
    Timer with three usage modes:

    1. Decorator: measures the run time of a function or coroutine function.
    2. Context manager: measures the body of a ``with`` block.
    3. Direct instantiation: if ``__enter__`` is never called, printing the
       object shows the current ``perf_counter`` value.

    When both ``name`` and ``storage`` are set, the elapsed seconds are
    written to ``storage[name]`` on exit.
    """

    __slots__ = ("name", "storage", "elapsed", "auto_unit", "start")

    def __init__(
        self,
        name: Optional[str] = None,
        storage: Optional[Dict[str, float]] = None,
        auto_unit: bool = True,
        do_type_check: bool = False,
    ):
        """
        :param name: name of the timer; also the key used in ``storage``
        :param storage: dict that receives the elapsed seconds
        :param auto_unit: let ``human_readable`` switch between ms and s
        :param do_type_check: validate the types of ``name`` and ``storage``
        :raises TimerTypeError: when ``do_type_check`` is set and a type is wrong
        """
        if do_type_check:
            self._validate_types(name, storage)

        self.name = name
        self.storage = storage
        self.elapsed = None  # seconds; None until a timing has finished

        self.auto_unit = auto_unit
        self.start = None  # perf_counter value at __enter__; None before that

    @staticmethod
    def _validate_types(name, storage):
        """Raise ``TimerTypeError`` unless name is str/None and storage is dict/None."""
        if name is not None and not isinstance(name, str):
            raise TimerTypeError("name", "Optional[str]", type(name))

        if storage is not None and not isinstance(storage, dict):
            raise TimerTypeError("storage", "Optional[dict]", type(storage))

    def __call__(self, func: Optional[Callable] = None) -> Callable:
        """Decorator mode: wrap ``func`` so each call is timed by this timer.

        Called without a function, returns a decorator that builds a new
        ``Timer`` named after the decorated function when this one has no name.
        """
        if func is None:
            return lambda f: Timer(name=self.name or f.__name__, storage=self.storage, auto_unit=self.auto_unit)(f)

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            with self:
                return await func(*args, **kwargs)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
        wrapper.__timer__ = self  # keep a reference to the timer on the wrapper
        return wrapper

    def __enter__(self):
        """Context-manager entry: remember the start time."""
        self.start = perf_counter()
        return self

    def __exit__(self, *args):
        """Context-manager exit: compute and record the elapsed time.

        Returns ``False`` so exceptions from the block propagate.
        """
        self.elapsed = perf_counter() - self.start
        self._record_time()
        return False

    def _record_time(self):
        """Store the elapsed seconds in ``storage[name]`` when both are set."""
        if self.storage is not None and self.name:
            self.storage[self.name] = self.elapsed

    @property
    def human_readable(self) -> str:
        """Elapsed time as text with its unit.

        With ``auto_unit`` the value is shown in milliseconds below one second
        and in seconds otherwise; without it, always in seconds.  Every value
        carries its unit so the two scales cannot be confused.
        """
        if self.elapsed is None:
            return "未计时"

        if self.auto_unit:
            if self.elapsed < 1:
                return f"{self.elapsed * 1000:.2f}毫秒"
            return f"{self.elapsed:.2f}秒"
        return f"{self.elapsed:.4f}秒"

    def __str__(self):
        if self.start is not None:
            if self.elapsed is None:
                # Started but not finished: show the running time.
                current_elapsed = perf_counter() - self.start
                return f"<Timer {self.name or '匿名'} [计时中: {current_elapsed:.4f}秒]>"
            return f"<Timer {self.name or '匿名'} [{self.human_readable}]>"
        # Never started: print the current perf_counter value.
        return f"{perf_counter()}"

View File

@@ -0,0 +1,477 @@
"""
错别字生成器 - 基于拼音和字频的中文错别字生成工具
"""
import json
import math
import os
import random
import time
from collections import defaultdict
from pathlib import Path
import jieba
from pypinyin import Style, pinyin
from src.common.logger import get_module_logger
logger = get_module_logger("typo_gen")
class ChineseTypoGenerator:
def __init__(self, error_rate=0.3, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3, max_freq_diff=200):
    """
    Set up the typo generator and load its lookup tables.

    Parameters:
        error_rate: probability of replacing a single character
        min_freq: minimum character-frequency threshold
        tone_error_rate: probability of a tone error
        word_replace_rate: probability of replacing a whole word
        max_freq_diff: largest allowed frequency difference
    """
    # Tunable probabilities and thresholds.
    (
        self.error_rate,
        self.min_freq,
        self.tone_error_rate,
        self.word_replace_rate,
        self.max_freq_diff,
    ) = (error_rate, min_freq, tone_error_rate, word_replace_rate, max_freq_diff)

    # Lookup tables: pinyin -> characters first, then the character frequencies.
    self.pinyin_dict = self._create_pinyin_dict()
    self.char_frequency = self._load_or_create_char_frequency()
def _load_or_create_char_frequency(self):
    """
    Load the character-frequency table, building and caching it if needed.

    The table is read from ``depends-data/char_frequency.json`` when that file
    exists.  Otherwise it is derived from jieba's ``dict.txt``: every Chinese
    character of a word receives the word's frequency, and the totals are
    scaled so the most frequent character is 1000.  The result is written to
    the cache file.

    Returns:
        dict mapping a character to its normalised frequency
    """
    cache_file = Path("depends-data/char_frequency.json")

    # Use the cache when it exists.
    if cache_file.exists():
        with open(cache_file, "r", encoding="utf-8") as f:
            return json.load(f)

    char_freq = defaultdict(int)
    dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")

    with open(dict_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            # Skip blank or short lines instead of failing on the unpack.
            if len(parts) < 2:
                continue
            chinese_chars = [char for char in parts[0] if self._is_chinese_char(char)]
            if not chinese_chars:
                continue
            freq = int(parts[1])
            # Each character of the word receives the word's frequency.
            for char in chinese_chars:
                char_freq[char] += freq

    # Normalise so the most frequent character is 1000.
    max_freq = max(char_freq.values())
    normalized_freq = {char: freq / max_freq * 1000 for char, freq in char_freq.items()}

    # The cache directory may not exist yet on a fresh checkout.
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    with open(cache_file, "w", encoding="utf-8") as f:
        json.dump(normalized_freq, f, ensure_ascii=False, indent=2)

    return normalized_freq
@staticmethod
def _create_pinyin_dict():
    """
    Build the mapping from a pinyin syllable to the characters read that way.

    Covers the code points from 0x4E00 up to, but not including, 0x9FFF.
    Readings use ``Style.TONE3``; a character whose lookup raises is skipped.
    """
    mapping = defaultdict(list)

    for code_point in range(0x4E00, 0x9FFF):
        hanzi = chr(code_point)
        try:
            reading = pinyin(hanzi, style=Style.TONE3)[0][0]
            mapping[reading].append(hanzi)
        except Exception:
            continue

    return mapping
@staticmethod
def _is_chinese_char(char):
"""
判断是否为汉字
"""
try:
return "\u4e00" <= char <= "\u9fff"
except Exception as e:
logger.debug(e)
return False
def _get_pinyin(self, sentence):
"""
将中文句子拆分成单个汉字并获取其拼音
"""
# 将句子拆分成单个字符
characters = list(sentence)
# 获取每个字符的拼音
result = []
for char in characters:
# 跳过空格和非汉字字符
if char.isspace() or not self._is_chinese_char(char):
continue
# 获取拼音(数字声调)
py = pinyin(char, style=Style.TONE3)[0][0]
result.append((char, py))
return result
@staticmethod
def _get_similar_tone_pinyin(py):
"""
获取相似声调的拼音
"""
# 检查拼音是否为空或无效
if not py or len(py) < 1:
return py
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
if not py[-1].isdigit():
# 为非数字结尾的拼音添加数字声调1
return py + "1"
base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调
# 处理轻声通常用5表示或无效声调
if tone not in [1, 2, 3, 4]:
return base + str(random.choice([1, 2, 3, 4]))
# 正常处理声调
possible_tones = [1, 2, 3, 4]
possible_tones.remove(tone) # 移除原声调
new_tone = random.choice(possible_tones) # 随机选择一个新声调
return base + str(new_tone)
def _calculate_replacement_probability(self, orig_freq, target_freq):
"""
根据频率差计算替换概率
"""
if target_freq > orig_freq:
return 1.0 # 如果替换字频率更高,保持原有概率
freq_diff = orig_freq - target_freq
if freq_diff > self.max_freq_diff:
return 0.0 # 频率差太大,不替换
# 使用指数衰减函数计算概率
# 频率差为0时概率为1频率差为max_freq_diff时概率接近0
return math.exp(-3 * freq_diff / self.max_freq_diff)
def _get_similar_frequency_chars(self, char, py, num_candidates=5):
"""
获取与给定字频率相近的同音字,可能包含声调错误
"""
homophones = []
# 有一定概率使用错误声调
if random.random() < self.tone_error_rate:
wrong_tone_py = self._get_similar_tone_pinyin(py)
homophones.extend(self.pinyin_dict[wrong_tone_py])
# 添加正确声调的同音字
homophones.extend(self.pinyin_dict[py])
if not homophones:
return None
# 获取原字的频率
orig_freq = self.char_frequency.get(char, 0)
# 计算所有同音字与原字的频率差,并过滤掉低频字
freq_diff = [
(h, self.char_frequency.get(h, 0))
for h in homophones
if h != char and self.char_frequency.get(h, 0) >= self.min_freq
]
if not freq_diff:
return None
# 计算每个候选字的替换概率
candidates_with_prob = []
for h, freq in freq_diff:
prob = self._calculate_replacement_probability(orig_freq, freq)
if prob > 0: # 只保留有效概率的候选字
candidates_with_prob.append((h, prob))
if not candidates_with_prob:
return None
# 根据概率排序
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
# 返回概率最高的几个字
return [char for char, _ in candidates_with_prob[:num_candidates]]
@staticmethod
def _get_word_pinyin(word):
    """
    Return the pinyin of ``word`` as a list, one reading per entry.

    Readings use ``Style.TONE3``; the first reading of each entry is taken.
    """
    readings = pinyin(word, style=Style.TONE3)
    return [entry[0] for entry in readings]
@staticmethod
def _segment_sentence(sentence):
    """
    Split ``sentence`` into words with jieba and return them as a list.
    """
    tokens = jieba.cut(sentence)
    return list(tokens)
def _get_word_homophones(self, word):
"""
获取整个词的同音词,只返回高频的有意义词语
"""
if len(word) == 1:
return []
# 获取词的拼音
word_pinyin = self._get_word_pinyin(word)
# 遍历所有可能的同音字组合
candidates = []
for py in word_pinyin:
chars = self.pinyin_dict.get(py, [])
if not chars:
return []
candidates.append(chars)
# 生成所有可能的组合
import itertools
all_combinations = itertools.product(*candidates)
# 获取jieba词典和词频信息
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
valid_words = {} # 改用字典存储词语及其频率
with open(dict_path, "r", encoding="utf-8") as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 2:
word_text = parts[0]
word_freq = float(parts[1]) # 获取词频
valid_words[word_text] = word_freq
# 获取原词的词频作为参考
original_word_freq = valid_words.get(word, 0)
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
# 过滤和计算频率
homophones = []
for combo in all_combinations:
new_word = "".join(combo)
if new_word != word and new_word in valid_words:
new_word_freq = valid_words[new_word]
# 只保留词频达到阈值的词
if new_word_freq >= min_word_freq:
# 计算词的平均字频(考虑字频和词频)
char_avg_freq = sum(self.char_frequency.get(c, 0) for c in new_word) / len(new_word)
# 综合评分:结合词频和字频
combined_score = new_word_freq * 0.7 + char_avg_freq * 0.3
if combined_score >= self.min_freq:
homophones.append((new_word, combined_score))
# 按综合分数排序并限制返回数量
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
def create_typo_sentence(self, sentence):
    """
    Build a copy of ``sentence`` that contains homophone typos, replacing
    either whole words or single characters.

    Args:
        sentence: the Chinese sentence to process.

    Returns:
        A ``(typo_sentence, correction_suggestion)`` tuple.
        ``typo_sentence`` is the text with the typos applied.
        ``correction_suggestion`` is the correct word/character of one
        randomly chosen typo (word-level typos are preferred over
        character-level ones) or ``None``; a suggestion is only attempted
        when a random draw falls below 0.5.
    """
    result = []
    typo_info = []  # detailed record of each replacement; collected but not returned here
    word_typos = []  # (wrong word, correct word) pairs
    char_typos = []  # (wrong character, correct character) pairs
    current_pos = 0  # running length of the appended text; updated but never read here
    # Segment the sentence into words.
    words = self._segment_sentence(sentence)
    for word in words:
        # Tokens without any Chinese character (punctuation, spaces, ...) are copied as-is.
        if all(not self._is_chinese_char(c) for c in word):
            result.append(word)
            current_pos += len(word)
            continue
        # Pinyin of the word, one entry per reading returned by the helper.
        word_pinyin = self._get_word_pinyin(word)
        # First try to replace the whole word with a homophone word.
        if len(word) > 1 and random.random() < self.word_replace_rate:
            word_homophones = self._get_word_homophones(word)
            if word_homophones:
                typo_word = random.choice(word_homophones)
                # Mean character frequency of the original and the replacement word.
                orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word)
                typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
                # Emit the replacement and record it.
                result.append(typo_word)
                typo_info.append(
                    (
                        word,
                        typo_word,
                        " ".join(word_pinyin),
                        " ".join(self._get_word_pinyin(typo_word)),
                        orig_freq,
                        typo_freq,
                    )
                )
                word_typos.append((typo_word, word))  # (wrong word, correct word)
                current_pos += len(typo_word)
                continue
        # No whole-word replacement happened: fall back to per-character replacement.
        if len(word) == 1:
            char = word
            py = word_pinyin[0]
            if random.random() < self.error_rate:
                similar_chars = self._get_similar_frequency_chars(char, py)
                if similar_chars:
                    typo_char = random.choice(similar_chars)
                    typo_freq = self.char_frequency.get(typo_char, 0)
                    orig_freq = self.char_frequency.get(char, 0)
                    replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq)
                    # Second draw: accept the candidate according to its replacement probability.
                    if random.random() < replace_prob:
                        result.append(typo_char)
                        typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
                        typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
                        char_typos.append((typo_char, char))  # (wrong character, correct character)
                        current_pos += 1
                        continue
            # Reached whenever the character was not replaced.
            result.append(char)
            current_pos += 1
        else:
            # Per-character replacement inside a multi-character word.
            word_result = []
            for _, (char, py) in enumerate(zip(word, word_pinyin)):
                # Characters inside a word are replaced less often: the rate
                # shrinks by a factor of 0.7 for every extra character.
                word_error_rate = self.error_rate * (0.7 ** (len(word) - 1))
                if random.random() < word_error_rate:
                    similar_chars = self._get_similar_frequency_chars(char, py)
                    if similar_chars:
                        typo_char = random.choice(similar_chars)
                        typo_freq = self.char_frequency.get(typo_char, 0)
                        orig_freq = self.char_frequency.get(char, 0)
                        replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq)
                        if random.random() < replace_prob:
                            word_result.append(typo_char)
                            typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
                            typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
                            char_typos.append((typo_char, char))  # (wrong character, correct character)
                            continue
                # Reached whenever the character was not replaced.
                word_result.append(char)
            result.append("".join(word_result))
            current_pos += len(word)
    # Prefer a word-level typo for the suggestion; fall back to a character-level one.
    correction_suggestion = None
    # Only offer a correction when this draw falls below 0.5.
    if random.random() < 0.5:
        if word_typos:
            wrong_word, correct_word = random.choice(word_typos)
            correction_suggestion = correct_word
        elif char_typos:
            wrong_char, correct_char = random.choice(char_typos)
            correction_suggestion = correct_char
    return "".join(result), correction_suggestion
@staticmethod
def format_typo_info(typo_info):
    """Render typo records as one human-readable line per record.

    Args:
        typo_info: sequence of ``(orig, typo, orig_py, typo_py, orig_freq,
            typo_freq)`` tuples.

    Returns:
        The rendered lines joined by newlines, or the fixed placeholder
        string when ``typo_info`` is empty.
    """
    if not typo_info:
        return "未生成错别字"

    def classify(orig_py, typo_py):
        # A space in the original pinyin means several syllables, i.e. a whole word.
        if " " in orig_py:
            return "整词替换"
        # Same syllable but a different trailing tone digit counts as a tone error.
        if orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]:
            return "声调错误"
        return "同音字替换"

    lines = [
        f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> "
        f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{classify(orig_py, typo_py)}]"
        for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info
    ]
    return "\n".join(lines)
def set_params(self, **kwargs):
    """Update existing attributes of this object from keyword arguments.

    Only attributes that already exist are overwritten; unknown names are
    reported with a warning and ignored. Every outcome is printed.

    Tunables named by the original documentation:
        error_rate: probability of replacing a single character
        min_freq: minimum character-frequency threshold
        tone_error_rate: probability of a tone error
        word_replace_rate: probability of replacing a whole word
        max_freq_diff: largest allowed frequency difference
    """
    for name, new_value in kwargs.items():
        if not hasattr(self, name):
            print(f"警告: 参数 {name} 不存在")
            continue
        setattr(self, name, new_value)
        print(f"参数 {name} 已设置为 {new_value}")
def main():
    """Interactive demo: read one sentence, print its typo version and the elapsed time."""
    generator = ChineseTypoGenerator(error_rate=0.03, min_freq=7, tone_error_rate=0.02, word_replace_rate=0.3)

    sentence = input("请输入中文句子:")

    started_at = time.time()
    typo_sentence, correction_suggestion = generator.create_typo_sentence(sentence)

    print("\n原句:", sentence)
    print("错字版:", typo_sentence)

    if correction_suggestion:
        print("\n随机纠正建议:")
        print(f"应该改为:{correction_suggestion}")

    # The clock is stopped after the results above have been printed.
    elapsed = time.time() - started_at
    print(f"\n总耗时:{elapsed:.2f}")


if __name__ == "__main__":
    main()

744
src/chat/utils/utils.py Normal file
View File

@@ -0,0 +1,744 @@
import random
import re
import time
from collections import Counter
import jieba
import numpy as np
from maim_message import UserInfo
from pymongo.errors import PyMongoError
from src.common.logger import get_module_logger
from src.manager.mood_manager import mood_manager
from ..message_receive.message import MessageRecv
from ..models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator
from ...common.database import db
from ...config.config import global_config
logger = get_module_logger("chat_utils")
def is_english_letter(char: str) -> bool:
    """Tell whether ``char``, lower-cased, sorts between "a" and "z" inclusive."""
    lowered = char.lower()
    return lowered >= "a" and lowered <= "z"
def db_message_to_str(message_dict: dict) -> str:
    """Format a stored message as ``[MM-DD HH:MM:SS] <sender>: <text>`` plus a newline.

    The sender is rendered as ``[(user_id)nickname]cardname``. If building
    that label raises, the nickname alone is used, or a generic label built
    from ``user_id`` when the nickname is empty. The input and the result
    are both written to the debug log.
    """
    logger.debug(f"message_dict: {message_dict}")
    timestamp_text = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
    try:
        sender = "[({!s}){!s}]{!s}".format(
            message_dict["user_id"],
            message_dict.get("user_nickname", ""),
            message_dict.get("user_cardname", ""),
        )
    except Exception:
        sender = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
    body = message_dict.get("processed_plain_text", "")
    formatted = f"[{timestamp_text}] {sender}: {body}\n"
    logger.debug(f"result: {formatted}")
    return formatted
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
    """Decide whether ``message`` addresses the bot and how likely a reply should be.

    Resolution order:
      1. If ``additional_config["is_mentioned"]`` is present and converts to
         ``float``, it is returned directly as the reply probability and the
         message is reported as mentioned.
      2. An "@" followed (lazily) by ``id:<BOT_QQ>`` marks an at-mention.
      3. Otherwise the message counts as mentioned when it starts with a
         reply-quote of the bot, or when the bot nickname / any alias occurs
         in the text after "@..." fragments and reply-quote headers have
         been stripped.

    Returns:
        ``(is_mentioned, reply_probability)``. Apart from the explicit
        override in step 1, the probability is 1.0 only when the matching
        ``global_config.at_bot_inevitable_reply`` /
        ``global_config.mentioned_bot_inevitable_reply`` switch is on, and
        0.0 otherwise.
    """
    keywords = [global_config.BOT_NICKNAME]
    nicknames = global_config.BOT_ALIAS_NAMES
    reply_probability = 0.0
    is_mentioned = False

    additional_config = message.message_info.additional_config
    if additional_config is not None and additional_config.get("is_mentioned") is not None:
        try:
            reply_probability = float(additional_config.get("is_mentioned"))
            is_mentioned = True
            return is_mentioned, reply_probability
        except Exception as e:
            # A malformed override is logged and the regular detection below takes over.
            logger.warning(e)
            logger.warning(f"消息中包含不合理的设置 is_mentioned: {additional_config.get('is_mentioned')}")

    text = message.processed_plain_text
    # Escape the id so it is always matched literally inside the patterns;
    # the patterns are raw f-strings so that \s, \[ and \( are real regex
    # escapes instead of invalid string escapes.
    bot_id = re.escape(str(global_config.BOT_QQ))

    # Was the bot @-mentioned?
    is_at = re.search(rf"@[\s\S]*?id:{bot_id}", text) is not None
    if is_at:
        is_mentioned = True

    if is_at and global_config.at_bot_inevitable_reply:
        reply_probability = 1.0
        logger.info("被@回复概率设置为100%")
    else:
        if not is_mentioned:
            # Does the message start with a reply-quote of the bot?
            if re.match(rf"\[回复 [\s\S]*?\({bot_id}\)[\s\S]*?],说:", text):
                is_mentioned = True
            else:
                # Strip "@..." fragments and reply-quote headers, then look
                # for the bot's names in what is left.
                message_content = re.sub(r"@[\s\S]*?(\d+)", "", text)
                message_content = re.sub(r"\[回复 [\s\S]*?\(((\d+)|未知id)\)[\s\S]*?],说:", "", message_content)
                if any(keyword in message_content for keyword in keywords):
                    is_mentioned = True
                if any(nickname in message_content for nickname in nicknames):
                    is_mentioned = True
        if is_mentioned and global_config.mentioned_bot_inevitable_reply:
            reply_probability = 1.0
            logger.info("被提及回复概率设置为100%")
    return is_mentioned, reply_probability
async def get_embedding(text, request_type="embedding"):
    """Request the embedding of ``text`` from the configured embedding model.

    Returns whatever ``LLMRequest.get_embedding`` yields, or ``None`` when the
    request raises (the error is logged).
    """
    llm = LLMRequest(model=global_config.embedding, request_type=request_type)
    try:
        return await llm.get_embedding(text)
    except Exception as e:
        logger.error(f"获取embedding失败: {str(e)}")
        return None
def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, combine=False):
    """Fetch the ``detailed_plain_text`` of the latest messages of a chat, oldest first.

    Args:
        chat_stream_id: value matched against the ``chat_id`` field.
        limit: maximum number of messages to load.
        combine: when true, return one concatenated string instead of a list.

    Returns:
        A list of texts, or a single string when ``combine`` is true. An
        empty list is returned when no message is found, regardless of
        ``combine``.
    """
    projection = {
        "time": 1,
        "chat_id": 1,
        "chat_info": 1,
        "user_info": 1,
        "message_id": 1,
        "detailed_plain_text": 1,
    }
    cursor = db.messages.find({"chat_id": chat_stream_id}, projection).sort("time", -1).limit(limit)
    newest_first = list(cursor)
    if not newest_first:
        return []

    # The query returns newest first; walk it backwards so the latest message comes last.
    if combine:
        return "".join(str(doc["detailed_plain_text"]) for doc in reversed(newest_first))
    return [doc["detailed_plain_text"] for doc in reversed(newest_first)]
def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> list:
    """List up to five distinct recent speakers of a chat, newest message first.

    The message sender (``sender`` is compared against a
    ``(platform, user_id)`` tuple) and the bot itself are left out.

    Returns:
        A list of ``(platform, user_id, user_nickname)`` tuples.
    """
    cursor = db.messages.find({"chat_id": chat_stream_id}, {"user_info": 1}).sort("time", -1).limit(limit)
    recent_docs = list(cursor)
    if not recent_docs:
        return []

    speakers = []
    for doc in recent_docs:
        info = UserInfo.from_dict(doc["user_info"])
        if (info.platform, info.user_id) == sender:
            continue
        if info.user_id == global_config.BOT_QQ:
            continue
        entry = (info.platform, info.user_id, info.user_nickname)
        # Skip duplicates and stop adding once five speakers are collected.
        if entry in speakers or len(speakers) >= 5:
            continue
        speakers.append(entry)
    return speakers
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
    """Split ``text`` into sentence-like pieces, then randomly merge neighbours.

    Steps:
      1. Normalise newlines: collapse blank lines, drop newlines adjacent to
         a separator, and turn a newline between two CJK characters into "。".
      2. Cut at separators (ASCII comma, full-width comma, space, "。", ";")
         unless the separator sits between two English letters. A separator
         is dropped from the output unless its two sides are merged again.
      3. Merge adjacent pieces with a probability that depends on the text
         length (about 0.8 below 12 characters, 0.4 below 32, 0.3 otherwise).

    Texts shorter than 3 characters are returned as a single item, or split
    into individual characters when a random draw falls below 0.01.

    Kaomoji are assumed to have been replaced by placeholders by the caller.

    Args:
        text: the text to split.

    Returns:
        The resulting pieces; pieces that are empty or whitespace-only are
        dropped.
    """
    # Collapse runs of blank lines into a single newline.
    text = re.sub(r"\n\s*\n+", "\n", text)
    # Drop newlines that touch another separator.
    text = re.sub(r"\n\s*([,,。;\s])", r"\1", text)
    text = re.sub(r"([,,。;\s])\s*\n", r"\1", text)
    # A newline between two CJK characters acts as a sentence end.
    text = re.sub(r"([\u4e00-\u9fff])\n([\u4e00-\u9fff])", r"\1。\2", text)

    len_text = len(text)
    if len_text < 3:
        if random.random() < 0.01:
            return list(text)
        return [text]

    # Every entry must be a single character: the scan below tests one
    # character at a time, so an empty string could never match.
    separators = {",", ",", " ", "。", ";"}

    # 1. Cut into (content, separator) tuples.
    segments = []
    current_segment = ""
    last_index = len_text - 1
    for i, char in enumerate(text):
        if char not in separators:
            current_segment += char
            continue
        # A separator surrounded by English letters (e.g. "hello world") stays inside the piece.
        if 0 < i < last_index and is_english_letter(text[i - 1]) and is_english_letter(text[i + 1]):
            current_segment += char
            continue
        if current_segment:
            segments.append((current_segment, char))
        elif char == " ":
            # Keep a lone space as an empty piece so the space can survive a merge.
            segments.append(("", char))
        current_segment = ""

    # Trailing piece without a following separator.
    if current_segment:
        segments.append((current_segment, ""))

    # Nothing was cut out (e.g. the text consists of separators only).
    if not segments:
        return [text] if text else []

    # 2. Probabilistic merge: the merge probability is the complement of the split strength.
    if len_text < 12:
        split_strength = 0.2
    elif len_text < 32:
        split_strength = 0.6
    else:
        split_strength = 0.7
    merge_probability = 1.0 - split_strength

    merged_segments = []
    idx = 0
    while idx < len(segments):
        content, sep = segments[idx]
        # Merge only when there is a next piece, the draw succeeds and this piece has content.
        if idx + 1 < len(segments) and random.random() < merge_probability and content:
            next_content, next_sep = segments[idx + 1]
            if next_content:
                merged_segments.append((content + sep + next_content, next_sep))
            else:
                # The next piece is empty: keep this content and take over its separator.
                merged_segments.append((content, next_sep))
            idx += 2
        else:
            merged_segments.append((content, sep))
            idx += 1

    # Keep only pieces that contain something other than whitespace.
    final_sentences = [content for content, _ in merged_segments if content.strip()]
    logger.debug(f"分割并合并后的句子: {final_sentences}")
    return final_sentences
def random_remove_punctuation(text: str) -> str:
    """Randomly drop or soften punctuation to imitate casual typing.

    * A "。" at the very end of ``text`` is removed 90% of the time.
    * Each full-width comma "," is removed 5% of the time and replaced by a
      space 20% of the time; otherwise it is kept.

    All other characters are copied unchanged.

    Args:
        text: the text to process.

    Returns:
        The processed text.
    """
    kept = []
    last_index = len(text) - 1
    for i, char in enumerate(text):
        if char == "。" and i == last_index:
            if random.random() > 0.1:
                continue
        elif char == ",":
            roll = random.random()
            if roll < 0.05:  # 5%: drop the comma
                continue
            if roll < 0.25:  # next 20%: turn the comma into a space
                kept.append(" ")
                continue
        kept.append(char)
    return "".join(kept)
def process_llm_response(text: str) -> list[str]:
    """Turn a raw LLM reply into the list of chat messages to send.

    Pipeline:
      1. optionally replace kaomoji by placeholders,
      2. remove bracketed spans matched by the pattern below,
      3. reject replies that are mostly non-western and longer than twice
         ``global_config.response_max_length``,
      4. optionally split into sentences,
      5. optionally inject typos (a correction, when produced, is appended
         as its own message),
      6. reject results with more than
         ``global_config.response_max_sentence_num`` messages,
      7. restore the kaomoji.

    Returns:
        The messages, or a single canned message when the reply is empty
        after cleaning, too long, or split into too many parts.
    """
    # Protect kaomoji first so the bracket removal below cannot eat them.
    if global_config.enable_kaomoji_protection:
        protected_text, kaomoji_mapping = protect_kaomoji(text)
        logger.trace(f"保护颜文字后的文本: {protected_text}")
    else:
        protected_text, kaomoji_mapping = text, {}

    # Spans wrapped in () or [] for which the lookahead finds a CJK character.
    bracket_pattern = re.compile(r"[(\[](?=.*[一-鿿]).*?[)\]]")
    _extracted_contents = bracket_pattern.findall(protected_text)  # currently unused
    cleaned_text = bracket_pattern.sub("", protected_text)
    if cleaned_text == "":
        return ["呃呃"]
    logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")

    max_length = global_config.response_max_length * 2
    max_sentence_num = global_config.response_max_sentence_num

    # The length limit only applies to text that is (almost) free of western letters.
    if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length:
        logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
        return ["懒得说"]

    typo_generator = ChineseTypoGenerator(
        error_rate=global_config.chinese_typo_error_rate,
        min_freq=global_config.chinese_typo_min_freq,
        tone_error_rate=global_config.chinese_typo_tone_error_rate,
        word_replace_rate=global_config.chinese_typo_word_replace_rate,
    )

    if global_config.enable_response_splitter:
        split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
    else:
        split_sentences = [cleaned_text]

    sentences = []
    for sentence in split_sentences:
        if not global_config.chinese_typo_enable:
            sentences.append(sentence)
            continue
        typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
        sentences.append(typoed_text)
        if typo_corrections:
            sentences.append(typo_corrections)

    if len(sentences) > max_sentence_num:
        logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
        return [f"{global_config.BOT_NICKNAME}不知道哦"]

    # Restore the kaomoji only after all other processing is done.
    if global_config.enable_kaomoji_protection:
        sentences = recover_kaomoji(sentences, kaomoji_mapping)
    return sentences
def calculate_typing_time(
    input_string: str,
    thinking_start_time: float,
    chinese_time: float = 0.2,
    english_time: float = 0.1,
    is_emoji: bool = False,
) -> float:
    """Estimate how long typing ``input_string`` should appear to take.

    Args:
        input_string: the text to be "typed".
        thinking_start_time: timestamp compared against ``time.time()``; once
            more than 10 seconds have passed since then, the result is 1.
        chinese_time: base seconds per character in U+4E00..U+9FFF.
        english_time: base seconds per any other character.
        is_emoji: when true, the result is 1.

    Both per-character times are divided by ``1.5 ** arousal``, where
    ``arousal`` is ``mood_manager.current_mood.arousal``.

    Special case: a string that is exactly one CJK character (after
    stripping) takes ``3 * chinese_time + 0.3`` and returns immediately. No
    0.3 is added on the regular path.
    """
    arousal = mood_manager.current_mood.arousal
    speed_factor = 1.5**arousal
    chinese_time *= 1 / speed_factor
    english_time *= 1 / speed_factor

    def is_cjk(char):
        return "\u4e00" <= char <= "\u9fff"

    cjk_count = sum(1 for char in input_string if is_cjk(char))
    if cjk_count == 1 and len(input_string.strip()) == 1:
        return chinese_time * 3 + 0.3

    # Accumulate character by character to keep the exact float result.
    total_time = 0.0
    for char in input_string:
        total_time += chinese_time if is_cjk(char) else english_time

    if is_emoji:
        total_time = 1
    if time.time() - thinking_start_time > 10:
        total_time = 1
    return total_time
def cosine_similarity(v1, v2):
    """Return the cosine similarity of ``v1`` and ``v2``; 0 if either has zero norm."""
    numerator = np.dot(v1, v2)
    magnitudes = (np.linalg.norm(v1), np.linalg.norm(v2))
    if 0 in magnitudes:
        return 0
    return numerator / (magnitudes[0] * magnitudes[1])
def text_to_vector(text):
    """Tokenise ``text`` with ``jieba.lcut`` and return a ``Counter`` of the tokens."""
    return Counter(jieba.lcut(text))
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
    """Rank ``topics`` by cosine similarity of their token counts to ``text``.

    Returns:
        Up to ``top_k`` ``(topic, similarity)`` tuples, most similar first.
    """
    query_counts = text_to_vector(text)

    scored_topics = []
    for topic in topics:
        topic_counts = text_to_vector(topic)
        # Both vectors are laid out over the same vocabulary, in the same order.
        vocabulary = set(query_counts.keys()) | set(topic_counts.keys())
        query_vec = [query_counts.get(token, 0) for token in vocabulary]
        topic_vec = [topic_counts.get(token, 0) for token in vocabulary]
        scored_topics.append((topic, cosine_similarity(query_vec, topic_vec)))

    scored_topics.sort(key=lambda pair: pair[1], reverse=True)
    return scored_topics[:top_k]
def truncate_message(message: str, max_length=20) -> str:
    """Cut ``message`` to ``max_length`` characters and append "..." if it was longer."""
    if len(message) <= max_length:
        return message
    return f"{message[:max_length]}..."
def protect_kaomoji(sentence):
    """Replace kaomoji in ``sentence`` by placeholders.

    Two shapes are recognised: a bracketed span that contains at least one
    character that is not CJK, alphanumeric or whitespace, and a run of 2-15
    characters from a fixed symbol set.

    Args:
        sentence (str): the original sentence.

    Returns:
        tuple: ``(sentence with placeholders, {placeholder: kaomoji})``.
        Placeholders are named ``__KAOMOJI_<n>__`` in match order.
    """
    kaomoji_pattern = re.compile(
        r"("
        r"[(\[(【]"  # opening bracket
        r"[^()\[\]()【】]*?"  # non-bracket characters (lazy)
        r"[^一-龥a-zA-Z0-9\s]"  # one character that is not CJK, alphanumeric or whitespace
        r"[^()\[\]()【】]*?"  # non-bracket characters (lazy)
        r"[)\])】"  # closing bracket
        r"]"
        r")"
        r"|"
        r"([▼▽・ᴥω・﹏^><≧≦ ̄`´∀ヮДд︿﹀へ。゚╥╯╰︶︹•⁄]{2,15})"
    )

    # Collect every match before the sentence is modified.
    found = [m.group(1) or m.group(2) for m in kaomoji_pattern.finditer(sentence)]

    placeholder_to_kaomoji = {}
    for number, kaomoji in enumerate(found):
        placeholder = f"__KAOMOJI_{number}__"
        # Only the first remaining occurrence is replaced for each match.
        sentence = sentence.replace(kaomoji, placeholder, 1)
        placeholder_to_kaomoji[placeholder] = kaomoji
    return sentence, placeholder_to_kaomoji
def recover_kaomoji(sentences, placeholder_to_kaomoji):
    """Put the kaomoji back in place of their placeholders.

    Args:
        sentences (list): sentences that may contain placeholders.
        placeholder_to_kaomoji (dict): placeholder -> kaomoji mapping.

    Returns:
        list: the sentences with every placeholder occurrence replaced.
    """

    def restore(sentence):
        for placeholder, kaomoji in placeholder_to_kaomoji.items():
            sentence = sentence.replace(placeholder, kaomoji)
        return sentence

    return [restore(sentence) for sentence in sentences]
def get_western_ratio(paragraph):
    """Return the share of alphanumeric characters that are English letters.

    Only characters for which ``str.isalnum()`` is true are counted;
    punctuation and whitespace are ignored. Each counted character is
    classified with ``is_english_letter``.

    Args:
        paragraph: the text to inspect.

    Returns:
        float: ratio in the range 0.0-1.0, or 0.0 when the text has no
        alphanumeric character.
    """
    considered = 0
    western = 0
    for char in paragraph:
        if not char.isalnum():
            continue
        considered += 1
        if is_english_letter(char):
            western += 1
    if considered == 0:
        return 0.0
    return western / considered
def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]:
    """Count the messages of a chat in ``(start_time, end_time]`` and sum their text length.

    Args:
        start_time (float): lower bound of the ``time`` field (exclusive).
        end_time (float): upper bound of the ``time`` field (inclusive).
        stream_id (str): value matched against the ``chat_id`` field.

    Returns:
        tuple[int, int]: ``(message count, total length of processed_plain_text)``.
        ``(0, 0)`` is returned for an empty range, an empty ``stream_id`` or
        when the query fails (failures are logged).
    """
    if start_time >= end_time:
        return 0, 0
    if not stream_id:
        logger.error("stream_id 不能为空")
        return 0, 0

    query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
    message_count = 0
    text_length = 0
    try:
        for msg in db.messages.find(query):
            message_count += 1
            text_length += len(msg.get("processed_plain_text", ""))
    except PyMongoError as e:
        logger.error(f"查询 stream_id={stream_id} 在 ({start_time}, {end_time}] 范围内的消息时出错: {e}")
        return 0, 0
    except Exception as e:  # catch-all so that counting never raises to the caller
        logger.error(f"计算消息数量时发生意外错误: {e}")
        return 0, 0
    return message_count, text_length
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
    """Convert a Unix timestamp into a human-readable local-time string.

    Args:
        timestamp: the timestamp to convert.
        mode: ``"normal"`` gives ``YYYY-MM-DD HH:MM:SS``. ``"relative"``
            gives a phrase relative to now (seconds, minutes, hours or days
            ago) followed by ``":\\n"``, falling back to the full date once
            two days or more have passed. Any other value gives ``HH:MM:SS``.

    Returns:
        str: the formatted time.
    """
    if mode == "normal":
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))

    if mode != "relative":
        # "lite" and unknown modes: time of day only.
        return time.strftime("%H:%M:%S", time.localtime(timestamp))

    elapsed = time.time() - timestamp
    if elapsed < 20:
        return "刚刚:\n"
    if elapsed < 60:
        return f"{int(elapsed)}秒前:\n"
    if elapsed < 3600:
        return f"{int(elapsed / 60)}分钟前:\n"
    if elapsed < 86400:
        return f"{int(elapsed / 3600)}小时前:\n"
    if elapsed < 86400 * 2:
        return f"{int(elapsed / 86400)}天前:\n"
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":\n"
def parse_text_timestamps(text: str, mode: str = "normal") -> str:
    """Replace ``[<number>]`` timestamps in ``text`` by readable times.

    Args:
        text: text containing timestamps wrapped in square brackets
            (integer or decimal).
        mode: ``"normal"`` converts every timestamp to the full date format.
            Any other value thins the timestamps out: they are grouped into
            15-second buckets, the earliest timestamp of each bucket plus
            the overall first and last timestamps are converted to relative
            times, and all others are removed.

    Returns:
        str: the text after replacement. Text without timestamps is returned
        unchanged.

    Note:
        Replacement works on the timestamp token text (first remaining
        occurrence), not on match positions.
    """
    found = list(re.finditer(r"\[(\d+(?:\.\d+)?)\]", text))
    if not found:
        return text

    if mode == "normal":
        rendered = text
        for match in found:
            readable = translate_timestamp_to_human_readable(float(match.group(1)), "normal")
            rendered = rendered.replace(match.group(0), readable, 1)
        return rendered

    # Thinned-out mode: work on the timestamps in ascending time order.
    stamped = sorted(((float(m.group(1)), m) for m in found), key=lambda pair: pair[0])
    first_ts, first_match = stamped[0]
    last_ts, last_match = stamped[-1]

    # Group by 15-second bucket; dict order follows the ascending timestamps.
    buckets = {}
    for ts, match in stamped:
        buckets.setdefault(int(ts // 15), []).append((ts, match))

    # The earliest timestamp of every bucket is converted.
    chosen = [members[0] for members in buckets.values()]

    # The overall first and last timestamps are always converted.
    chosen_values = {ts for ts, _ in chosen}
    if first_ts not in chosen_values:
        chosen.append((first_ts, first_match))
    if last_ts not in chosen_values:
        chosen.append((last_ts, last_match))

    keep_tokens = {match.group(0) for _, match in chosen}

    # Drop every timestamp that is not converted.
    rendered = text
    for _, match in stamped:
        token = match.group(0)
        if token not in keep_tokens:
            rendered = rendered.replace(token, "", 1)

    # Convert the chosen ones, walking from the end of the text to its start.
    for ts, match in reversed(sorted(chosen, key=lambda pair: pair[1].start())):
        readable = translate_timestamp_to_human_readable(ts, "relative")
        rendered = rendered.replace(match.group(0), readable, 1)
    return rendered

View File

@@ -0,0 +1,379 @@
import base64
import os
import time
import hashlib
from typing import Optional
from PIL import Image
import io
import numpy as np
from ...common.database import db
from ...config.config import global_config
from ..models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_image")
class ImageManager:
    """Singleton that describes images and emoji through a vision LLM.

    Descriptions are cached in the ``image_descriptions`` collection keyed by
    ``(hash, type)``. Depending on configuration, the image files are also
    written below ``IMAGE_DIR`` and indexed in the ``images`` collection.
    """

    _instance = None
    IMAGE_DIR = "data"  # root directory for stored images

    def __new__(cls):
        # Classic singleton: every ImageManager() call yields the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every ImageManager() call; set up only once.
        # NOTE(review): indentation was lost in the reviewed copy; the LLM
        # client is assumed to be created inside this guard - confirm.
        if not self._initialized:
            self._ensure_image_collection()
            self._ensure_description_collection()
            self._ensure_image_dir()
            self._initialized = True
            self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")

    def _ensure_image_dir(self):
        """Make sure the image storage directory exists."""
        os.makedirs(self.IMAGE_DIR, exist_ok=True)

    @staticmethod
    def _ensure_image_collection():
        """Make sure the ``images`` collection exists and set up its indexes."""
        # NOTE(review): indentation was lost in the reviewed copy; the index
        # setup is assumed to belong to the "collection was just created"
        # branch - confirm.
        if "images" not in db.list_collection_names():
            db.create_collection("images")
            # Remove old indexes.
            db.images.drop_indexes()
            # Create the compound index and the lookup indexes.
            db.images.create_index([("hash", 1), ("type", 1)], unique=True)
            db.images.create_index([("url", 1)])
            db.images.create_index([("path", 1)])

    @staticmethod
    def _ensure_description_collection():
        """Make sure the ``image_descriptions`` collection exists and set up its index."""
        # NOTE(review): same indentation assumption as in _ensure_image_collection.
        if "image_descriptions" not in db.list_collection_names():
            db.create_collection("image_descriptions")
            # Remove old indexes.
            db.image_descriptions.drop_indexes()
            # Create the compound index.
            db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)

    @staticmethod
    def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
        """Load a cached description.

        Args:
            image_hash: hash of the image bytes.
            description_type: kind of description ('emoji' or 'image').

        Returns:
            Optional[str]: the description text, or None when nothing is cached.
        """
        result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
        return result["description"] if result else None

    @staticmethod
    def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
        """Store (upsert) a description; failures are logged, not raised.

        Args:
            image_hash: hash of the image bytes.
            description: description text.
            description_type: kind of description ('emoji' or 'image').
        """
        try:
            db.image_descriptions.update_one(
                {"hash": image_hash, "type": description_type},
                {
                    "$set": {
                        "description": description,
                        "timestamp": int(time.time()),
                        "hash": image_hash,  # make sure the hash field exists
                        "type": description_type,  # make sure the type field exists
                    }
                },
                upsert=True,
            )
        except Exception as e:
            logger.error(f"保存描述到数据库失败: {str(e)}")

    async def get_emoji_description(self, image_base64: str) -> str:
        """Describe an emoji/sticker, using and filling the description cache.

        Args:
            image_base64: base64-encoded image data.

        Returns:
            A bracketed text containing the description, or the bare
            "[表情包]" marker when no description could be produced.
        """
        try:
            # Hash the image so it can be looked up in the cache.
            image_bytes = base64.b64decode(image_base64)
            image_hash = hashlib.md5(image_bytes).hexdigest()
            image_format = Image.open(io.BytesIO(image_bytes)).format.lower()

            cached_description = self._get_description_from_db(image_hash, "emoji")
            if cached_description:
                return f"[表情包,含义看起来是:{cached_description}]"

            # Ask the model for a description.
            if image_format == "gif":
                # Animated images are flattened into one strip of frames first.
                stitched_base64 = self.transform_gif(image_base64)
                if stitched_base64 is None:
                    # Never send an empty image to the model.
                    logger.warning("GIF转换失败无法获取表情包描述")
                    return "[表情包]"
                prompt = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明使用1-2个词描述一下表情包表达的情感和内容简短一些"
                description, _ = await self._llm.generate_response_for_image(prompt, stitched_base64, "jpg")
            else:
                prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
                description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)

            # Another task may have cached a description while we were waiting.
            cached_description = self._get_description_from_db(image_hash, "emoji")
            if cached_description:
                logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
                return f"[表情包,含义看起来是:{cached_description}]"

            if description is None:
                # Do not cache or return a missing description.
                logger.warning("AI未能生成表情包描述")
                return "[表情包]"

            # Save the file if configured to do so.
            if global_config.save_emoji:
                timestamp = int(time.time())
                filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
                os.makedirs(os.path.join(self.IMAGE_DIR, "emoji"), exist_ok=True)
                file_path = os.path.join(self.IMAGE_DIR, "emoji", filename)
                try:
                    with open(file_path, "wb") as f:
                        f.write(image_bytes)
                    image_doc = {
                        "hash": image_hash,
                        "path": file_path,
                        "type": "emoji",
                        "description": description,
                        "timestamp": timestamp,
                    }
                    db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
                    logger.trace(f"保存表情包: {file_path}")
                except Exception as e:
                    # Saving is best effort; the description is still returned.
                    logger.error(f"保存表情包文件失败: {str(e)}")

            self._save_description_to_db(image_hash, description, "emoji")
            return f"[表情包:{description}]"
        except Exception as e:
            logger.error(f"获取表情包描述失败: {str(e)}")
            return "[表情包]"

    async def get_image_description(self, image_base64: str) -> str:
        """Describe a regular image, using and filling the description cache.

        Args:
            image_base64: base64-encoded image data.

        Returns:
            "[图片:<description>]", or the bare "[图片]" marker when no
            description could be produced.
        """
        try:
            # Hash the image so it can be looked up in the cache.
            image_bytes = base64.b64decode(image_base64)
            image_hash = hashlib.md5(image_bytes).hexdigest()
            image_format = Image.open(io.BytesIO(image_bytes)).format.lower()

            cached_description = self._get_description_from_db(image_hash, "image")
            if cached_description:
                logger.debug(f"图片描述缓存中 {cached_description}")
                return f"[图片:{cached_description}]"

            # Ask the model for a description.
            prompt = (
                "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多100个字。"
            )
            description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)

            # Another task may have cached a description while we were waiting.
            cached_description = self._get_description_from_db(image_hash, "image")
            if cached_description:
                logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
                return f"[图片:{cached_description}]"

            logger.debug(f"描述是{description}")
            if description is None:
                logger.warning("AI未能生成图片描述")
                return "[图片]"

            # Save the file if configured to do so.
            if global_config.save_pic:
                timestamp = int(time.time())
                filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
                os.makedirs(os.path.join(self.IMAGE_DIR, "image"), exist_ok=True)
                file_path = os.path.join(self.IMAGE_DIR, "image", filename)
                try:
                    with open(file_path, "wb") as f:
                        f.write(image_bytes)
                    image_doc = {
                        "hash": image_hash,
                        "path": file_path,
                        "type": "image",
                        "description": description,
                        "timestamp": timestamp,
                    }
                    db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
                    logger.trace(f"保存图片: {file_path}")
                except Exception as e:
                    # Saving is best effort; the description is still returned.
                    logger.error(f"保存图片文件失败: {str(e)}")

            self._save_description_to_db(image_hash, description, "image")
            return f"[图片:{description}]"
        except Exception as e:
            logger.error(f"获取图片描述失败: {str(e)}")
            return "[图片]"

    @staticmethod
    def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
        """Flatten a GIF into one horizontal strip of its distinct frames.

        Args:
            gif_base64: base64-encoded GIF data.
            similarity_threshold: minimum mean squared error against the last
                kept frame for a frame to count as different. Larger values
                keep fewer frames.
            max_frames: maximum number of frames to keep.

        Returns:
            Optional[str]: base64-encoded JPEG of the strip (frames scaled to
            a height of 200 pixels), or None on failure.
        """
        try:
            gif = Image.open(io.BytesIO(base64.b64decode(gif_base64)))

            # Collect every frame, converted to RGB so frames are comparable.
            all_frames = []
            try:
                while True:
                    gif.seek(len(all_frames))
                    frame = gif.convert("RGB")
                    all_frames.append(frame.copy())
            except EOFError:
                pass  # reached the end of the animation

            if not all_frames:
                logger.warning("GIF中没有找到任何帧")
                return None

            # Keep the first frame, then every frame that differs enough from
            # the previously kept one.
            selected_frames = []
            last_selected_frame_np = None
            for i, current_frame in enumerate(all_frames):
                # Use a signed type wider than uint8: with uint8 the
                # difference wraps around and the squares overflow, so the
                # error could never exceed the threshold.
                current_frame_np = np.asarray(current_frame).astype(np.int32)

                if i == 0:
                    selected_frames.append(current_frame)
                    last_selected_frame_np = current_frame_np
                    continue

                if last_selected_frame_np is not None:
                    mse = np.mean((current_frame_np - last_selected_frame_np) ** 2)
                    if mse > similarity_threshold:
                        selected_frames.append(current_frame)
                        last_selected_frame_np = current_frame_np
                        if len(selected_frames) >= max_frames:
                            break

            if not selected_frames:
                logger.warning("处理后没有选中任何帧")
                return None

            # All frames are assumed to share the size of the first kept frame.
            frame_width, frame_height = selected_frames[0].size

            # Scale to a fixed height while keeping the aspect ratio.
            target_height = 200
            if frame_height == 0:
                logger.error("帧高度为0无法计算缩放尺寸")
                return None
            target_width = int((target_height / frame_height) * frame_width)
            if target_width == 0:
                logger.warning(f"计算出的目标宽度为0 (原始尺寸 {frame_width}x{frame_height})调整为1")
                target_width = 1

            resized_frames = [
                frame.resize((target_width, target_height), Image.Resampling.LANCZOS) for frame in selected_frames
            ]

            # At least one frame of width >= 1 exists here, so the strip is never empty.
            total_width = target_width * len(resized_frames)
            combined_image = Image.new("RGB", (total_width, target_height))
            for idx, frame in enumerate(resized_frames):
                combined_image.paste(frame, (idx * target_width, 0))

            buffer = io.BytesIO()
            combined_image.save(buffer, format="JPEG", quality=85)
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
        except MemoryError:
            logger.error("GIF转换失败: 内存不足可能是GIF太大或帧数太多")
            return None
        except Exception as e:
            logger.error(f"GIF转换失败: {str(e)}", exc_info=True)
            return None
# Create the module-level singleton instance.
image_manager = ImageManager()
def image_path_to_base64(image_path: str) -> str:
    """Read an image file and return its content base64-encoded.

    Args:
        image_path: path of the image file.

    Returns:
        str: the base64-encoded file content.

    Raises:
        FileNotFoundError: when the file does not exist.
        IOError: when the file is empty.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"图片文件不存在: {image_path}")

    with open(image_path, "rb") as image_file:
        raw = image_file.read()

    if raw:
        return base64.b64encode(raw).decode("utf-8")
    raise IOError(f"读取图片文件失败: {image_path}")