fix: adjust the directory structure, refine the hfc prompt, remove the schedule feature, and remove the dynamic and LLM-based willing modes
443
src/chat/utils/chat_message_builder.py
Normal file
@@ -0,0 +1,443 @@
|
||||
from src.config.config import global_config
from typing import List, Dict, Any, Tuple  # 确保类型提示被导入
import time  # 导入 time 模块以获取当前时间
import random
import re
from src.common.message_repository import find_messages, count_messages
from src.chat.person_info.person_info import person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
"""
|
||||
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}}
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
"""
|
||||
filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": timestamp_end}}
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
# 直接将 limit_mode 传递给 find_messages
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_chat_users(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
person_ids: list,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取某些特定用户在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
"""
|
||||
filter_query = {
|
||||
"chat_id": chat_id,
|
||||
"time": {"$gt": timestamp_start, "$lt": timestamp_end},
|
||||
"user_id": {"$in": person_ids},
|
||||
}
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_random(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息
|
||||
"""
|
||||
    # 获取该时间范围内的所有消息(完整消息文档,而不仅是 chat_id 字段)
    all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
    if not all_msgs:
        return []
    # 随机选一条
    msg = random.choice(all_msgs)
    chat_id = msg["chat_id"]
    # 用 chat_id 获取该聊天在指定时间戳范围内的消息
    return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, limit_mode)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_users(
|
||||
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
"""
|
||||
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}}
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
filter_query = {"time": {"$lt": timestamp}}
|
||||
sort_order = [("time", 1)]
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
|
||||
sort_order = [("time", 1)]
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}}
|
||||
sort_order = [("time", 1)]
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float = None) -> int:
|
||||
"""
|
||||
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
||||
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
||||
"""
|
||||
# 确定有效的结束时间戳
|
||||
_timestamp_end = timestamp_end if timestamp_end is not None else time.time()
|
||||
|
||||
# 确保 timestamp_start < _timestamp_end
|
||||
if timestamp_start >= _timestamp_end:
|
||||
# logger.warning(f"timestamp_start ({timestamp_start}) must be less than _timestamp_end ({_timestamp_end}). Returning 0.")
|
||||
return 0 # 起始时间大于等于结束时间,没有新消息
|
||||
|
||||
filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}}
|
||||
return count_messages(message_filter=filter_query)
|
||||
|
||||
|
||||
def num_new_messages_since_with_users(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
|
||||
) -> int:
|
||||
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
|
||||
if not person_ids: # 保持空列表检查
|
||||
return 0
|
||||
filter_query = {
|
||||
"chat_id": chat_id,
|
||||
"time": {"$gt": timestamp_start, "$lt": timestamp_end},
|
||||
"user_id": {"$in": person_ids},
|
||||
}
|
||||
return count_messages(message_filter=filter_query)
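A minimal usage sketch (not part of the diff) of the query helpers above; the chat_id value and the time window are purely illustrative:

# Hypothetical usage: fetch recent messages of one chat and count new ones.
import time

from src.chat.utils.chat_message_builder import (
    get_raw_msg_by_timestamp_with_chat,
    num_new_messages_since,
)

chat_id = "example_chat_id"  # placeholder id, for illustration only
one_hour_ago = time.time() - 3600

# Latest 20 messages of this chat within the last hour.
recent_msgs = get_raw_msg_by_timestamp_with_chat(
    chat_id, one_hour_ago, time.time(), limit=20, limit_mode="latest"
)
# How many messages arrived after one_hour_ago (up to now).
new_count = num_new_messages_since(chat_id, timestamp_start=one_hour_ago)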
|
||||
|
||||
|
||||
async def _build_readable_messages_internal(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
"""
|
||||
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
||||
|
||||
Args:
|
||||
messages: 消息字典列表。
|
||||
replace_bot_name: 是否将机器人的 user_id 替换为 "我"。
|
||||
merge_messages: 是否合并来自同一用户的连续消息。
|
||||
timestamp_mode: 时间戳的显示模式 ('relative', 'absolute', etc.)。传递给 translate_timestamp_to_human_readable。
|
||||
truncate: 是否根据消息的新旧程度截断过长的消息内容。
|
||||
|
||||
Returns:
|
||||
包含格式化消息的字符串和原始消息详情列表 (时间戳, 发送者名称, 内容) 的元组。
|
||||
"""
|
||||
if not messages:
|
||||
return "", []
|
||||
|
||||
message_details_raw: List[Tuple[float, str, str]] = []
|
||||
|
||||
# 1 & 2: 获取发送者信息并提取消息组件
|
||||
for msg in messages:
|
||||
user_info = msg.get("user_info", {})
|
||||
platform = user_info.get("platform")
|
||||
user_id = user_info.get("user_id")
|
||||
|
||||
user_nickname = user_info.get("user_nickname")
|
||||
user_cardname = user_info.get("user_cardname")
|
||||
|
||||
timestamp = msg.get("time")
|
||||
content = msg.get("processed_plain_text", "") # 默认空字符串
|
||||
|
||||
# 检查必要信息是否存在
|
||||
if not all([platform, user_id, timestamp is not None]):
|
||||
continue
|
||||
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
||||
if replace_bot_name and user_id == global_config.BOT_QQ:
|
||||
person_name = f"{global_config.BOT_NICKNAME}(你)"
|
||||
else:
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
||||
if not person_name:
|
||||
if user_cardname:
|
||||
person_name = f"昵称:{user_cardname}"
|
||||
elif user_nickname:
|
||||
person_name = f"{user_nickname}"
|
||||
else:
|
||||
person_name = "某人"
|
||||
|
||||
# 检查是否有 回复<aaa:bbb> 字段
|
||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
||||
match = re.search(reply_pattern, content)
|
||||
if match:
|
||||
aaa = match.group(1)
|
||||
bbb = match.group(2)
|
||||
reply_person_id = person_info_manager.get_person_id(platform, bbb)
|
||||
reply_person_name = await person_info_manager.get_value(reply_person_id, "person_name")
|
||||
if not reply_person_name:
|
||||
reply_person_name = aaa
|
||||
# 在内容前加上回复信息
|
||||
content = re.sub(reply_pattern, f"回复 {reply_person_name}", content, count=1)
|
||||
|
||||
# 检查是否有 @<aaa:bbb> 字段 @<{member_info.get('nickname')}:{member_info.get('user_id')}>
|
||||
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
|
||||
at_matches = list(re.finditer(at_pattern, content))
|
||||
if at_matches:
|
||||
new_content = ""
|
||||
last_end = 0
|
||||
for m in at_matches:
|
||||
new_content += content[last_end : m.start()]
|
||||
aaa = m.group(1)
|
||||
bbb = m.group(2)
|
||||
at_person_id = person_info_manager.get_person_id(platform, bbb)
|
||||
at_person_name = await person_info_manager.get_value(at_person_id, "person_name")
|
||||
if not at_person_name:
|
||||
at_person_name = aaa
|
||||
new_content += f"@{at_person_name}"
|
||||
last_end = m.end()
|
||||
new_content += content[last_end:]
|
||||
content = new_content
|
||||
|
||||
target_str = "这是QQ的一个功能,用于提及某人,但没那么明显"
|
||||
if target_str in content:
|
||||
if random.random() < 0.6:
|
||||
content = content.replace(target_str, "")
|
||||
|
||||
if content != "":
|
||||
message_details_raw.append((timestamp, person_name, content))
|
||||
|
||||
if not message_details_raw:
|
||||
return "", []
|
||||
|
||||
message_details_raw.sort(key=lambda x: x[0]) # 按时间戳(第一个元素)升序排序,越早的消息排在前面
|
||||
|
||||
# 应用截断逻辑 (如果 truncate 为 True)
|
||||
message_details: List[Tuple[float, str, str]] = []
|
||||
n_messages = len(message_details_raw)
|
||||
if truncate and n_messages > 0:
|
||||
for i, (timestamp, name, content) in enumerate(message_details_raw):
|
||||
percentile = i / n_messages # 计算消息在列表中的位置百分比 (0 <= percentile < 1)
|
||||
original_len = len(content)
|
||||
limit = -1 # 默认不截断
|
||||
|
||||
            if percentile < 0.2:  # 最旧的 20% 的消息
                limit = 50
                replace_content = "......(记不清了)"
            elif percentile < 0.5:  # 20% 到 50% 之间的消息
                limit = 100
                replace_content = "......(有点记不清了)"
            elif percentile < 0.7:  # 50% 到 70% 之间的消息
                limit = 200
                replace_content = "......(内容太长了)"
            elif percentile < 1.0:  # 最新的 30% 的消息
                limit = 300
                replace_content = "......(太长了)"
|
||||
|
||||
truncated_content = content
|
||||
if 0 < limit < original_len:
|
||||
truncated_content = f"{content[:limit]}{replace_content}"
|
||||
|
||||
message_details.append((timestamp, name, truncated_content))
|
||||
else:
|
||||
# 如果不截断,直接使用原始列表
|
||||
message_details = message_details_raw
|
||||
|
||||
# 3: 合并连续消息 (如果 merge_messages 为 True)
|
||||
merged_messages = []
|
||||
if merge_messages and message_details:
|
||||
# 初始化第一个合并块
|
||||
current_merge = {
|
||||
"name": message_details[0][1],
|
||||
"start_time": message_details[0][0],
|
||||
"end_time": message_details[0][0],
|
||||
"content": [message_details[0][2]],
|
||||
}
|
||||
|
||||
for i in range(1, len(message_details)):
|
||||
timestamp, name, content = message_details[i]
|
||||
# 如果是同一个人发送的连续消息且时间间隔小于等于60秒
|
||||
if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60):
|
||||
current_merge["content"].append(content)
|
||||
current_merge["end_time"] = timestamp # 更新最后消息时间
|
||||
else:
|
||||
# 保存上一个合并块
|
||||
merged_messages.append(current_merge)
|
||||
# 开始新的合并块
|
||||
current_merge = {"name": name, "start_time": timestamp, "end_time": timestamp, "content": [content]}
|
||||
# 添加最后一个合并块
|
||||
merged_messages.append(current_merge)
|
||||
elif message_details: # 如果不合并消息,则每个消息都是一个独立的块
|
||||
for timestamp, name, content in message_details:
|
||||
merged_messages.append(
|
||||
{
|
||||
"name": name,
|
||||
"start_time": timestamp, # 起始和结束时间相同
|
||||
"end_time": timestamp,
|
||||
"content": [content], # 内容只有一个元素
|
||||
}
|
||||
)
|
||||
|
||||
# 4 & 5: 格式化为字符串
|
||||
output_lines = []
|
||||
for _i, merged in enumerate(merged_messages):
|
||||
# 使用指定的 timestamp_mode 格式化时间
|
||||
readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
|
||||
|
||||
header = f"{readable_time}{merged['name']} 说:"
|
||||
output_lines.append(header)
|
||||
# 将内容合并,并添加缩进
|
||||
for line in merged["content"]:
|
||||
stripped_line = line.strip()
|
||||
if stripped_line: # 过滤空行
|
||||
# 移除末尾句号,添加分号 - 这个逻辑似乎有点奇怪,暂时保留
|
||||
if stripped_line.endswith("。"):
|
||||
stripped_line = stripped_line[:-1]
|
||||
                # 如果内容被截断(结尾是上面的省略提示),则不再添加分号
                if not stripped_line.endswith(("(记不清了)", "(有点记不清了)", "(内容太长了)", "(太长了)")):
                    output_lines.append(f"{stripped_line};")
                else:
                    output_lines.append(stripped_line)  # 直接添加截断后的内容
|
||||
output_lines.append("\n") # 在每个消息块后添加换行,保持可读性
|
||||
|
||||
# 移除可能的多余换行,然后合并
|
||||
formatted_string = "".join(output_lines).strip()
|
||||
|
||||
# 返回格式化后的字符串和 *应用截断后* 的 message_details 列表
|
||||
# 注意:如果外部调用者需要原始未截断的内容,可能需要调整返回策略
|
||||
return formatted_string, message_details
|
||||
|
||||
|
||||
async def build_readable_messages_with_list(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
"""
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
"""
|
||||
formatted_string, details_list = await _build_readable_messages_internal(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
return formatted_string, details_list
|
||||
|
||||
|
||||
async def build_readable_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
read_mark: float = 0.0,
|
||||
truncate: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
将消息列表转换为可读的文本格式。
|
||||
如果提供了 read_mark,则在相应位置插入已读标记。
|
||||
允许通过参数控制格式化行为。
|
||||
"""
|
||||
if read_mark <= 0:
|
||||
# 没有有效的 read_mark,直接格式化所有消息
|
||||
formatted_string, _ = await _build_readable_messages_internal(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
return formatted_string
|
||||
else:
|
||||
# 按 read_mark 分割消息
|
||||
messages_before_mark = [msg for msg in messages if msg.get("time", 0) <= read_mark]
|
||||
messages_after_mark = [msg for msg in messages if msg.get("time", 0) > read_mark]
|
||||
|
||||
# 分别格式化
|
||||
# 注意:这里决定对已读和未读部分都应用相同的 truncate 设置
|
||||
# 如果需要不同的行为(例如只截断已读部分),需要调整这里的调用
|
||||
formatted_before, _ = await _build_readable_messages_internal(
|
||||
messages_before_mark, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
formatted_after, _ = await _build_readable_messages_internal(
|
||||
messages_after_mark,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
timestamp_mode,
|
||||
)
|
||||
|
||||
readable_read_mark = translate_timestamp_to_human_readable(read_mark, mode=timestamp_mode)
|
||||
        read_mark_line = f"\n--- 以上消息是你已经思考过、已读的内容 (标记时间: {readable_read_mark}) ---\n--- 请关注以下未读的新消息 ---\n"
|
||||
|
||||
# 组合结果,确保空部分不引入多余的标记或换行
|
||||
if formatted_before and formatted_after:
|
||||
return f"{formatted_before}{read_mark_line}{formatted_after}"
|
||||
elif formatted_before:
|
||||
return f"{formatted_before}{read_mark_line}"
|
||||
elif formatted_after:
|
||||
return f"{read_mark_line}{formatted_after}"
|
||||
else:
|
||||
# 理论上不应该发生,但作为保险
|
||||
return read_mark_line.strip() # 如果前后都无消息,只返回标记行
|
||||
|
||||
|
||||
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。
|
||||
|
||||
Args:
|
||||
messages: 消息字典列表。
|
||||
|
||||
Returns:
|
||||
一个包含唯一 person_id 的列表。
|
||||
"""
|
||||
person_ids_set = set() # 使用集合来自动去重
|
||||
|
||||
for msg in messages:
|
||||
user_info = msg.get("user_info", {})
|
||||
platform = user_info.get("platform")
|
||||
user_id = user_info.get("user_id")
|
||||
|
||||
# 检查必要信息是否存在 且 不是机器人自己
|
||||
if not all([platform, user_id]) or user_id == global_config.BOT_QQ:
|
||||
continue
|
||||
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
|
||||
# 只有当获取到有效 person_id 时才添加
|
||||
if person_id:
|
||||
person_ids_set.add(person_id)
|
||||
|
||||
return list(person_ids_set) # 将集合转换为列表返回
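A short, hypothetical sketch (not part of the diff) of how the formatting entry point above is meant to be called together with the query helpers at the top of this file; the chat id and time window are placeholders:

# Hypothetical usage of build_readable_messages with a read mark.
import asyncio
import time

from src.chat.utils.chat_message_builder import (
    build_readable_messages,
    get_raw_msg_by_timestamp_with_chat,
)

async def demo(chat_id: str, read_mark: float) -> str:
    msgs = get_raw_msg_by_timestamp_with_chat(chat_id, read_mark - 3600, time.time(), limit=50)
    # Merge consecutive messages, truncate older ones, and insert the read marker.
    return await build_readable_messages(
        msgs, merge_messages=True, timestamp_mode="relative", read_mark=read_mark, truncate=True
    )

# asyncio.run(demo("example_chat_id", time.time() - 600))  # placeholder arguments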
|
||||
234
src/chat/utils/info_catcher.py
Normal file
@@ -0,0 +1,234 @@
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending, Message
|
||||
from src.common.database import db
|
||||
import time
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
|
||||
class InfoCatcher:
|
||||
def __init__(self):
|
||||
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文喵~
|
||||
self.context_length = global_config.observation_context_size
|
||||
self.chat_history_in_thinking = [] # 思考期间的聊天内容喵~
|
||||
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文喵~
|
||||
|
||||
self.chat_id = ""
|
||||
self.trigger_response_text = ""
|
||||
self.response_text = ""
|
||||
|
||||
self.trigger_response_time = 0
|
||||
self.trigger_response_message = None
|
||||
|
||||
self.response_time = 0
|
||||
self.response_messages = []
|
||||
|
||||
# 使用字典来存储 heartflow 模式的数据
|
||||
self.heartflow_data = {
|
||||
"heart_flow_prompt": "",
|
||||
"sub_heartflow_before": "",
|
||||
"sub_heartflow_now": "",
|
||||
"sub_heartflow_after": "",
|
||||
"sub_heartflow_model": "",
|
||||
"prompt": "",
|
||||
"response": "",
|
||||
"model": "",
|
||||
}
|
||||
|
||||
# 使用字典来存储 reasoning 模式的数据喵~
|
||||
self.reasoning_data = {"thinking_log": "", "prompt": "", "response": "", "model": ""}
|
||||
|
||||
# 耗时喵~
|
||||
self.timing_results = {
|
||||
"interested_rate_time": 0,
|
||||
"sub_heartflow_observe_time": 0,
|
||||
"sub_heartflow_step_time": 0,
|
||||
"make_response_time": 0,
|
||||
}
|
||||
|
||||
def catch_decide_to_response(self, message: MessageRecv):
|
||||
# 搜集决定回复时的信息
|
||||
self.trigger_response_message = message
|
||||
self.trigger_response_text = message.detailed_plain_text
|
||||
|
||||
self.trigger_response_time = time.time()
|
||||
|
||||
self.chat_id = message.chat_stream.stream_id
|
||||
|
||||
self.chat_history = self.get_message_from_db_before_msg(message)
|
||||
|
||||
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
|
||||
self.timing_results["sub_heartflow_observe_time"] = obs_duration
|
||||
|
||||
# def catch_shf
|
||||
|
||||
    def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
        self.timing_results["sub_heartflow_step_time"] = step_duration
        # past_mind 为空时没有上一条心流,跳过以避免索引越界
        if past_mind:
            self.heartflow_data["sub_heartflow_before"] = past_mind[-1]
        self.heartflow_data["sub_heartflow_now"] = current_mind
|
||||
|
||||
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
|
||||
# if self.response_mode == "heart_flow": # 条件判断不需要了喵~
|
||||
# self.heartflow_data["prompt"] = prompt
|
||||
# self.heartflow_data["response"] = response
|
||||
# self.heartflow_data["model"] = model_name
|
||||
# elif self.response_mode == "reasoning": # 条件判断不需要了喵~
|
||||
# self.reasoning_data["thinking_log"] = reasoning_content
|
||||
# self.reasoning_data["prompt"] = prompt
|
||||
# self.reasoning_data["response"] = response
|
||||
# self.reasoning_data["model"] = model_name
|
||||
|
||||
# 直接记录信息喵~
|
||||
self.reasoning_data["thinking_log"] = reasoning_content
|
||||
self.reasoning_data["prompt"] = prompt
|
||||
self.reasoning_data["response"] = response
|
||||
self.reasoning_data["model"] = model_name
|
||||
# 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~
|
||||
# self.heartflow_data["prompt"] = prompt
|
||||
# self.heartflow_data["response"] = response
|
||||
# self.heartflow_data["model"] = model_name
|
||||
|
||||
self.response_text = response
|
||||
|
||||
def catch_after_generate_response(self, response_duration: float):
|
||||
self.timing_results["make_response_time"] = response_duration
|
||||
|
||||
def catch_after_response(
|
||||
self, response_duration: float, response_message: List[str], first_bot_msg: MessageSending
|
||||
):
|
||||
self.timing_results["make_response_time"] = response_duration
|
||||
self.response_time = time.time()
|
||||
for msg in response_message:
|
||||
self.response_messages.append(msg)
|
||||
|
||||
self.chat_history_in_thinking = self.get_message_from_db_between_msgs(
|
||||
self.trigger_response_message, first_bot_msg
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
|
||||
try:
|
||||
# 从数据库中获取消息的时间戳
|
||||
time_start = message_start.message_info.time
|
||||
time_end = message_end.message_info.time
|
||||
chat_id = message_start.chat_stream.stream_id
|
||||
|
||||
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
|
||||
|
||||
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据
|
||||
messages_between = db.messages.find(
|
||||
{"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
|
||||
).sort("time", -1)
|
||||
|
||||
result = list(messages_between)
|
||||
print(f"查询结果数量: {len(result)}")
|
||||
if result:
|
||||
print(f"第一条消息时间: {result[0]['time']}")
|
||||
print(f"最后一条消息时间: {result[-1]['time']}")
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"获取消息时出错: {str(e)}")
|
||||
return []
|
||||
|
||||
def get_message_from_db_before_msg(self, message: MessageRecv):
|
||||
# 从数据库中获取消息
|
||||
message_id = message.message_info.message_id
|
||||
chat_id = message.chat_stream.stream_id
|
||||
|
||||
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据
|
||||
messages_before = (
|
||||
db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
|
||||
.sort("time", -1)
|
||||
.limit(self.context_length * 3)
|
||||
) # 获取更多历史信息
|
||||
|
||||
return list(messages_before)
|
||||
|
||||
def message_list_to_dict(self, message_list):
|
||||
# 存储简化的聊天记录
|
||||
result = []
|
||||
for message in message_list:
|
||||
if not isinstance(message, dict):
|
||||
message = self.message_to_dict(message)
|
||||
# print(message)
|
||||
|
||||
lite_message = {
|
||||
"time": message["time"],
|
||||
"user_nickname": message["user_info"]["user_nickname"],
|
||||
"processed_plain_text": message["processed_plain_text"],
|
||||
}
|
||||
result.append(lite_message)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def message_to_dict(message):
|
||||
if not message:
|
||||
return None
|
||||
if isinstance(message, dict):
|
||||
return message
|
||||
return {
|
||||
# "message_id": message.message_info.message_id,
|
||||
"time": message.message_info.time,
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"user_nickname": message.message_info.user_info.user_nickname,
|
||||
"processed_plain_text": message.processed_plain_text,
|
||||
# "detailed_plain_text": message.detailed_plain_text
|
||||
}
|
||||
|
||||
def done_catch(self):
|
||||
"""将收集到的信息存储到数据库的 thinking_log 集合中喵~"""
|
||||
try:
|
||||
# 将消息对象转换为可序列化的字典喵~
|
||||
|
||||
thinking_log_data = {
|
||||
"chat_id": self.chat_id,
|
||||
"trigger_text": self.trigger_response_text,
|
||||
"response_text": self.response_text,
|
||||
"trigger_info": {
|
||||
"time": self.trigger_response_time,
|
||||
"message": self.message_to_dict(self.trigger_response_message),
|
||||
},
|
||||
"response_info": {
|
||||
"time": self.response_time,
|
||||
"message": self.response_messages,
|
||||
},
|
||||
"timing_results": self.timing_results,
|
||||
"chat_history": self.message_list_to_dict(self.chat_history),
|
||||
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
|
||||
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
|
||||
"heartflow_data": self.heartflow_data,
|
||||
"reasoning_data": self.reasoning_data,
|
||||
}
|
||||
|
||||
# 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~
|
||||
# if self.response_mode == "heart_flow":
|
||||
# thinking_log_data["mode_specific_data"] = self.heartflow_data
|
||||
# elif self.response_mode == "reasoning":
|
||||
# thinking_log_data["mode_specific_data"] = self.reasoning_data
|
||||
|
||||
# 将数据插入到 thinking_log 集合中喵~
|
||||
db.thinking_log.insert_one(thinking_log_data)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"存储思考日志时出错: {str(e)} 喵~")
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
||||
class InfoCatcherManager:
|
||||
def __init__(self):
|
||||
self.info_catchers = {}
|
||||
|
||||
def get_info_catcher(self, thinking_id: str) -> InfoCatcher:
|
||||
if thinking_id not in self.info_catchers:
|
||||
self.info_catchers[thinking_id] = InfoCatcher()
|
||||
return self.info_catchers[thinking_id]
|
||||
|
||||
|
||||
info_catcher_manager = InfoCatcherManager()
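A rough, hypothetical sketch (not part of the diff) of the intended call order for an InfoCatcher during one reply cycle; incoming stands for a MessageRecv instance and thinking_id for the id of the current thinking task:

# Hypothetical InfoCatcher lifecycle; all values are placeholders.
from src.chat.utils.info_catcher import info_catcher_manager

def sketch(incoming, thinking_id: str):
    catcher = info_catcher_manager.get_info_catcher(thinking_id)
    catcher.catch_decide_to_response(incoming)  # snapshot context when deciding to reply
    catcher.catch_after_observe(0.12)  # observation duration in seconds
    catcher.catch_after_llm_generated(prompt="...", response="...", model_name="example-model")
    catcher.catch_after_generate_response(0.80)  # generation duration in seconds
    catcher.done_catch()  # persist everything into db.thinking_log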
|
||||
226
src/chat/utils/json_utils.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, TypeVar, List, Union, Tuple
|
||||
import ast
|
||||
|
||||
# 定义类型变量用于泛型类型提示
|
||||
T = TypeVar("T")
|
||||
|
||||
# 获取logger
|
||||
logger = logging.getLogger("json_utils")
|
||||
|
||||
|
||||
def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
|
||||
"""
|
||||
安全地解析JSON字符串,出错时返回默认值
|
||||
现在尝试处理单引号和标准JSON
|
||||
|
||||
参数:
|
||||
json_str: 要解析的JSON字符串
|
||||
default_value: 解析失败时返回的默认值
|
||||
|
||||
返回:
|
||||
解析后的Python对象,或在解析失败时返回default_value
|
||||
"""
|
||||
if not json_str or not isinstance(json_str, str):
|
||||
logger.warning(f"safe_json_loads 接收到非字符串输入: {type(json_str)}, 值: {json_str}")
|
||||
return default_value
|
||||
|
||||
try:
|
||||
# 尝试标准的 JSON 解析
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果标准解析失败,尝试将单引号替换为双引号再解析
|
||||
# (注意:这种替换可能不安全,如果字符串内容本身包含引号)
|
||||
# 更安全的方式是用 ast.literal_eval
|
||||
try:
|
||||
# logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...")
|
||||
result = ast.literal_eval(json_str)
|
||||
# 确保结果是字典(因为我们通常期望参数是字典)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
else:
|
||||
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
|
||||
return default_value
|
||||
except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
|
||||
logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...")
|
||||
return default_value
|
||||
except Exception as e:
|
||||
logger.error(f"使用 ast.literal_eval 解析时发生意外错误: {e}, 字符串: {json_str[:100]}...")
|
||||
return default_value
|
||||
except Exception as e:
|
||||
logger.error(f"JSON解析过程中发生意外错误: {e}, 字符串: {json_str[:100]}...")
|
||||
return default_value
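A small illustration (not part of the diff) of the fallback behaviour described above: standard JSON parses directly, a single-quoted Python-literal dict is recovered via ast.literal_eval, and a non-dict literal falls back to the default value:

# Illustration of safe_json_loads fallback behaviour.
from src.chat.utils.json_utils import safe_json_loads

print(safe_json_loads('{"city": "北京", "limit": 3}'))  # standard JSON -> dict
print(safe_json_loads("{'city': '北京', 'limit': 3}"))  # single quotes -> ast.literal_eval path
print(safe_json_loads("['a', 'b']", default_value={}))  # parses to a list, not a dict -> returns {}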
|
||||
|
||||
|
||||
def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
从LLM工具调用对象中提取参数
|
||||
|
||||
参数:
|
||||
tool_call: 工具调用对象字典
|
||||
default_value: 解析失败时返回的默认值
|
||||
|
||||
返回:
|
||||
解析后的参数字典,或在解析失败时返回default_value
|
||||
"""
|
||||
default_result = default_value or {}
|
||||
|
||||
if not tool_call or not isinstance(tool_call, dict):
|
||||
logger.error(f"无效的工具调用对象: {tool_call}")
|
||||
return default_result
|
||||
|
||||
try:
|
||||
# 提取function参数
|
||||
function_data = tool_call.get("function", {})
|
||||
if not function_data or not isinstance(function_data, dict):
|
||||
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
|
||||
return default_result
|
||||
|
||||
# 提取arguments
|
||||
arguments_str = function_data.get("arguments", "{}")
|
||||
if not arguments_str:
|
||||
return default_result
|
||||
|
||||
# 解析JSON
|
||||
return safe_json_loads(arguments_str, default_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"提取工具调用参数时出错: {e}")
|
||||
return default_result
|
||||
|
||||
|
||||
def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str:
|
||||
"""
|
||||
安全地将Python对象序列化为JSON字符串
|
||||
|
||||
参数:
|
||||
obj: 要序列化的Python对象
|
||||
default_value: 序列化失败时返回的默认值
|
||||
ensure_ascii: 是否确保ASCII编码(默认False,允许中文等非ASCII字符)
|
||||
pretty: 是否美化输出JSON
|
||||
|
||||
返回:
|
||||
序列化后的JSON字符串,或在序列化失败时返回default_value
|
||||
"""
|
||||
try:
|
||||
indent = 2 if pretty else None
|
||||
return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent)
|
||||
except TypeError as e:
|
||||
logger.error(f"JSON序列化失败(类型错误): {e}")
|
||||
return default_value
|
||||
except Exception as e:
|
||||
logger.error(f"JSON序列化过程中发生意外错误: {e}")
|
||||
return default_value
|
||||
|
||||
|
||||
def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]:
|
||||
"""
|
||||
标准化LLM响应格式,将各种格式(如元组)转换为统一的列表格式
|
||||
|
||||
参数:
|
||||
response: 原始LLM响应
|
||||
log_prefix: 日志前缀
|
||||
|
||||
返回:
|
||||
元组 (成功标志, 标准化后的响应列表, 错误消息)
|
||||
"""
|
||||
|
||||
    logger.debug(f"{log_prefix}原始 LLM 响应: {response}")
|
||||
|
||||
# 检查是否为None
|
||||
if response is None:
|
||||
return False, [], "LLM响应为None"
|
||||
|
||||
# 记录原始类型
|
||||
logger.debug(f"{log_prefix}LLM响应原始类型: {type(response).__name__}")
|
||||
|
||||
# 将元组转换为列表
|
||||
if isinstance(response, tuple):
|
||||
logger.debug(f"{log_prefix}将元组响应转换为列表")
|
||||
response = list(response)
|
||||
|
||||
# 确保是列表类型
|
||||
if not isinstance(response, list):
|
||||
return False, [], f"无法处理的LLM响应类型: {type(response).__name__}"
|
||||
|
||||
# 处理工具调用部分(如果存在)
|
||||
if len(response) == 3:
|
||||
content, reasoning, tool_calls = response
|
||||
|
||||
# 将工具调用部分转换为列表(如果是元组)
|
||||
if isinstance(tool_calls, tuple):
|
||||
logger.debug(f"{log_prefix}将工具调用元组转换为列表")
|
||||
tool_calls = list(tool_calls)
|
||||
response[2] = tool_calls
|
||||
|
||||
return True, response, ""
|
||||
|
||||
|
||||
def process_llm_tool_calls(
|
||||
tool_calls: List[Dict[str, Any]], log_prefix: str = ""
|
||||
) -> Tuple[bool, List[Dict[str, Any]], str]:
|
||||
"""
|
||||
处理并验证LLM响应中的工具调用列表
|
||||
|
||||
参数:
|
||||
tool_calls: 从LLM响应中直接获取的工具调用列表
|
||||
log_prefix: 日志前缀
|
||||
|
||||
返回:
|
||||
元组 (成功标志, 验证后的工具调用列表, 错误消息)
|
||||
"""
|
||||
|
||||
# 如果列表为空,表示没有工具调用,这不是错误
|
||||
if not tool_calls:
|
||||
return True, [], "工具调用列表为空"
|
||||
|
||||
# 验证每个工具调用的格式
|
||||
valid_tool_calls = []
|
||||
for i, tool_call in enumerate(tool_calls):
|
||||
if not isinstance(tool_call, dict):
|
||||
logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}, 内容: {tool_call}")
|
||||
continue
|
||||
|
||||
# 检查基本结构
|
||||
if tool_call.get("type") != "function":
|
||||
logger.warning(
|
||||
f"{log_prefix}工具调用[{i}]不是function类型: type={tool_call.get('type', '未定义')}, 内容: {tool_call}"
|
||||
)
|
||||
continue
|
||||
|
||||
if "function" not in tool_call or not isinstance(tool_call.get("function"), dict):
|
||||
logger.warning(f"{log_prefix}工具调用[{i}]缺少'function'字段或其类型不正确: {tool_call}")
|
||||
continue
|
||||
|
||||
func_details = tool_call["function"]
|
||||
if "name" not in func_details or not isinstance(func_details.get("name"), str):
|
||||
logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'name'或类型不正确: {func_details}")
|
||||
continue
|
||||
|
||||
# 验证参数 'arguments'
|
||||
args_value = func_details.get("arguments")
|
||||
|
||||
# 1. 检查 arguments 是否存在且是字符串
|
||||
if args_value is None or not isinstance(args_value, str):
|
||||
logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'arguments'字符串: {func_details}")
|
||||
continue
|
||||
|
||||
# 2. 尝试安全地解析 arguments 字符串
|
||||
parsed_args = safe_json_loads(args_value, None)
|
||||
|
||||
# 3. 检查解析结果是否为字典
|
||||
if parsed_args is None or not isinstance(parsed_args, dict):
|
||||
logger.warning(
|
||||
f"{log_prefix}工具调用[{i}]的'arguments'无法解析为有效的JSON字典, "
|
||||
f"原始字符串: {args_value[:100]}..., 解析结果类型: {type(parsed_args).__name__}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 如果检查通过,将原始的 tool_call 加入有效列表
|
||||
valid_tool_calls.append(tool_call)
|
||||
|
||||
if not valid_tool_calls and tool_calls: # 如果原始列表不为空,但验证后为空
|
||||
return False, [], "所有工具调用格式均无效"
|
||||
|
||||
return True, valid_tool_calls, ""
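A hypothetical end-to-end sketch (not part of the diff) combining the two helpers above on a tool call shaped like the OpenAI-style dict they expect; the tool name and arguments are invented for illustration:

# Hypothetical tool-call validation sketch.
from src.chat.utils.json_utils import extract_tool_call_arguments, process_llm_tool_calls

tool_calls = [
    {
        "type": "function",
        "function": {"name": "search_memory", "arguments": '{"keyword": "天气", "limit": 5}'},
    }
]

ok, valid_calls, err = process_llm_tool_calls(tool_calls, log_prefix="[demo] ")
if ok:
    for call in valid_calls:
        args = extract_tool_call_arguments(call, default_value={})
        # args == {"keyword": "天气", "limit": 5}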
|
||||
88
src/chat/utils/logger_config.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import sys
|
||||
import loguru
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LogClassification(Enum):
|
||||
BASE = "base"
|
||||
MEMORY = "memory"
|
||||
EMOJI = "emoji"
|
||||
CHAT = "chat"
|
||||
PBUILDER = "promptbuilder"
|
||||
|
||||
|
||||
class LogModule:
|
||||
logger = loguru.logger.opt()
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def setup_logger(self, log_type: LogClassification):
|
||||
"""配置日志格式
|
||||
|
||||
Args:
|
||||
            log_type: 日志类型,可选值:BASE(基础日志)、MEMORY(记忆系统日志)、EMOJI(表情包系统日志)、CHAT(聊天日志)、PBUILDER(提示词构建日志)
|
||||
"""
|
||||
# 移除默认日志处理器
|
||||
self.logger.remove()
|
||||
|
||||
# 基础日志格式
|
||||
        base_format = (
            "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | "
            "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
chat_format = (
|
||||
"<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
# 记忆系统日志格式
|
||||
memory_format = (
|
||||
"<green>{time:HH:mm}</green> | <level>{level: <8}</level> | "
|
||||
"<light-magenta>海马体</light-magenta> | <level>{message}</level>"
|
||||
)
|
||||
|
||||
# 表情包系统日志格式
|
||||
emoji_format = (
|
||||
"<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>表情包</yellow> | "
|
||||
"<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
promptbuilder_format = (
|
||||
"<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>Prompt</yellow> | "
|
||||
"<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
# 根据日志类型选择日志格式和输出
|
||||
if log_type == LogClassification.CHAT:
|
||||
self.logger.add(
|
||||
sys.stderr,
|
||||
format=chat_format,
|
||||
# level="INFO"
|
||||
)
|
||||
elif log_type == LogClassification.PBUILDER:
|
||||
self.logger.add(
|
||||
sys.stderr,
|
||||
format=promptbuilder_format,
|
||||
# level="INFO"
|
||||
)
|
||||
elif log_type == LogClassification.MEMORY:
|
||||
# 同时输出到控制台和文件
|
||||
self.logger.add(
|
||||
sys.stderr,
|
||||
format=memory_format,
|
||||
# level="INFO"
|
||||
)
|
||||
self.logger.add("logs/memory.log", format=memory_format, level="INFO", rotation="1 day", retention="7 days")
|
||||
elif log_type == LogClassification.EMOJI:
|
||||
self.logger.add(
|
||||
sys.stderr,
|
||||
format=emoji_format,
|
||||
# level="INFO"
|
||||
)
|
||||
self.logger.add("logs/emoji.log", format=emoji_format, level="INFO", rotation="1 day", retention="7 days")
|
||||
else: # BASE
|
||||
self.logger.add(sys.stderr, format=base_format, level="INFO")
|
||||
|
||||
return self.logger
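A minimal configuration sketch (not part of the diff) for one of the classifications above; the MEMORY type also writes a rotating file under logs/:

# Minimal LogModule configuration sketch.
from src.chat.utils.logger_config import LogClassification, LogModule

log_module = LogModule()
logger = log_module.setup_logger(LogClassification.MEMORY)  # console + logs/memory.log
logger.info("海马体初始化完成")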
|
||||
237
src/chat/utils/prompt_builder.py
Normal file
@@ -0,0 +1,237 @@
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
from src.common.logger import get_module_logger
|
||||
|
||||
# import traceback
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_module_logger("prompt_build")
|
||||
|
||||
|
||||
class PromptContext:
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
self._current_context: Optional[str] = None
|
||||
self._context_lock = asyncio.Lock() # 添加异步锁
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: str):
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
async with self._context_lock:
|
||||
if context_id not in self._context_prompts:
|
||||
self._context_prompts[context_id] = {}
|
||||
|
||||
previous_context = self._current_context
|
||||
self._current_context = context_id
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
async with self._context_lock:
|
||||
self._current_context = previous_context
|
||||
|
||||
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
||||
"""异步获取当前作用域中的提示模板"""
|
||||
async with self._context_lock:
|
||||
if self._current_context and name in self._context_prompts[self._current_context]:
|
||||
return self._context_prompts[self._current_context][name]
|
||||
return None
|
||||
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
target_context = context_id or self._current_context
|
||||
if target_context:
|
||||
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
|
||||
|
||||
|
||||
class PromptManager:
|
||||
def __init__(self):
|
||||
self._prompts = {}
|
||||
self._counter = 0
|
||||
self._context = PromptContext()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: str):
|
||||
"""为消息处理创建异步临时作用域"""
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
|
||||
async def get_prompt_async(self, name: str) -> "Prompt":
|
||||
# 首先尝试从当前上下文获取
|
||||
context_prompt = await self._context.get_prompt_async(name)
|
||||
if context_prompt is not None:
|
||||
return context_prompt
|
||||
# 如果上下文中不存在,则使用全局提示模板
|
||||
async with self._lock:
|
||||
if name not in self._prompts:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
def generate_name(self, template: str) -> str:
|
||||
"""为未命名的prompt生成名称"""
|
||||
self._counter += 1
|
||||
return f"prompt_{self._counter}"
|
||||
|
||||
def register(self, prompt: "Prompt") -> None:
|
||||
"""注册一个prompt"""
|
||||
if not prompt.name:
|
||||
prompt.name = self.generate_name(prompt.template)
|
||||
self._prompts[prompt.name] = prompt
|
||||
|
||||
def add_prompt(self, name: str, fstr: str) -> "Prompt":
|
||||
prompt = Prompt(fstr, name=name)
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
prompt = await self.get_prompt_async(name)
|
||||
return prompt.format(**kwargs)
|
||||
|
||||
|
||||
# 全局单例
|
||||
global_prompt_manager = PromptManager()
|
||||
|
||||
|
||||
class Prompt(str):
|
||||
# 临时标记,作为类常量
|
||||
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
|
||||
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
|
||||
|
||||
@staticmethod
|
||||
def _process_escaped_braces(template: str) -> str:
|
||||
"""处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记"""
|
||||
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
|
||||
|
||||
@staticmethod
|
||||
def _restore_escaped_braces(template: str) -> str:
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
def __new__(cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs):
|
||||
# 如果传入的是元组,转换为列表
|
||||
if isinstance(args, tuple):
|
||||
args = list(args)
|
||||
should_register = kwargs.pop("_should_register", True)
|
||||
|
||||
# 预处理模板中的转义花括号
|
||||
processed_fstr = cls._process_escaped_braces(fstr)
|
||||
|
||||
# 解析模板
|
||||
template_args = []
|
||||
result = re.findall(r"\{(.*?)}", processed_fstr)
|
||||
for expr in result:
|
||||
if expr and expr not in template_args:
|
||||
template_args.append(expr)
|
||||
|
||||
# 如果提供了初始参数,立即格式化
|
||||
if kwargs or args:
|
||||
formatted = cls._format_template(fstr, args=args, kwargs=kwargs)
|
||||
obj = super().__new__(cls, formatted)
|
||||
else:
|
||||
obj = super().__new__(cls, "")
|
||||
|
||||
obj.template = fstr
|
||||
obj.name = name
|
||||
obj.args = template_args
|
||||
obj._args = args or []
|
||||
obj._kwargs = kwargs
|
||||
|
||||
# 修改自动注册逻辑
|
||||
if should_register:
|
||||
if global_prompt_manager._context._current_context:
|
||||
# 如果存在当前上下文,则注册到上下文中
|
||||
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
|
||||
pass
|
||||
else:
|
||||
# 否则注册到全局管理器
|
||||
global_prompt_manager.register(obj)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
async def create_async(
|
||||
cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs
|
||||
):
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = cls(fstr, name, args, **kwargs)
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
|
||||
# 预处理模板中的转义花括号
|
||||
processed_template = cls._process_escaped_braces(template)
|
||||
|
||||
template_args = []
|
||||
result = re.findall(r"\{(.*?)}", processed_template)
|
||||
for expr in result:
|
||||
if expr and expr not in template_args:
|
||||
template_args.append(expr)
|
||||
formatted_args = {}
|
||||
formatted_kwargs = {}
|
||||
|
||||
# 处理位置参数
|
||||
if args:
|
||||
# print(len(template_args), len(args), template_args, args)
|
||||
for i in range(len(args)):
|
||||
if i < len(template_args):
|
||||
arg = args[i]
|
||||
if isinstance(arg, Prompt):
|
||||
formatted_args[template_args[i]] = arg.format(**kwargs)
|
||||
else:
|
||||
formatted_args[template_args[i]] = arg
|
||||
else:
|
||||
logger.error(
|
||||
f"构建提示词模板失败,解析到的参数列表{template_args},长度为{len(template_args)},输入的参数列表为{args},提示词模板为{template}"
|
||||
)
|
||||
raise ValueError("格式化模板失败")
|
||||
|
||||
# 处理关键字参数
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, Prompt):
|
||||
remaining_kwargs = {k: v for k, v in kwargs.items() if k != key}
|
||||
formatted_kwargs[key] = value.format(**remaining_kwargs)
|
||||
else:
|
||||
formatted_kwargs[key] = value
|
||||
|
||||
try:
|
||||
# 先用位置参数格式化
|
||||
if args:
|
||||
processed_template = processed_template.format(**formatted_args)
|
||||
# 再用关键字参数格式化
|
||||
if kwargs:
|
||||
processed_template = processed_template.format(**formatted_kwargs)
|
||||
|
||||
# 将临时标记还原为实际的花括号
|
||||
result = cls._restore_escaped_braces(processed_template)
|
||||
return result
|
||||
except (IndexError, KeyError) as e:
|
||||
raise ValueError(
|
||||
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
|
||||
) from e
|
||||
|
||||
    def format(self, *args, **kwargs) -> "str":
        """支持位置参数和关键字参数的格式化"""
|
||||
ret = type(self)(
|
||||
self.template,
|
||||
self.name,
|
||||
args=list(args) if args else self._args,
|
||||
_should_register=False,
|
||||
**kwargs if kwargs else self._kwargs,
|
||||
)
|
||||
# print(f"prompt build result: {ret} name: {ret.name} ")
|
||||
return str(ret)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self._kwargs or self._args:
|
||||
return super().__str__()
|
||||
return self.template
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Prompt(template='{self.template}', name='{self.name}')"
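A brief sketch (not part of the diff) of the template behaviour implemented above: placeholders are filled on format(), while escaped braces survive as literal braces; the template name "greet" and its contents are invented for illustration:

# Illustrative Prompt usage; registers a named template, then formats it.
import asyncio

from src.chat.utils.prompt_builder import Prompt, global_prompt_manager

Prompt("你好,{name}!今天是{date}。字面花括号:\\{示例\\}", name="greet")

async def demo() -> str:
    # Fills {name}/{date} and restores \{示例\} to {示例}.
    return await global_prompt_manager.format_prompt("greet", name="麦麦", date="周一")

print(asyncio.run(demo()))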
|
||||
760
src/chat/utils/statistic.py
Normal file
@@ -0,0 +1,760 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
|
||||
from src.common.logger import get_module_logger
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
|
||||
from ...common.database import db
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_module_logger("maibot_statistic")
|
||||
|
||||
# 统计数据的键
|
||||
TOTAL_REQ_CNT = "total_requests"
|
||||
TOTAL_COST = "total_cost"
|
||||
REQ_CNT_BY_TYPE = "requests_by_type"
|
||||
REQ_CNT_BY_USER = "requests_by_user"
|
||||
REQ_CNT_BY_MODEL = "requests_by_model"
|
||||
IN_TOK_BY_TYPE = "in_tokens_by_type"
|
||||
IN_TOK_BY_USER = "in_tokens_by_user"
|
||||
IN_TOK_BY_MODEL = "in_tokens_by_model"
|
||||
OUT_TOK_BY_TYPE = "out_tokens_by_type"
|
||||
OUT_TOK_BY_USER = "out_tokens_by_user"
|
||||
OUT_TOK_BY_MODEL = "out_tokens_by_model"
|
||||
TOTAL_TOK_BY_TYPE = "tokens_by_type"
|
||||
TOTAL_TOK_BY_USER = "tokens_by_user"
|
||||
TOTAL_TOK_BY_MODEL = "tokens_by_model"
|
||||
COST_BY_TYPE = "costs_by_type"
|
||||
COST_BY_USER = "costs_by_user"
|
||||
COST_BY_MODEL = "costs_by_model"
|
||||
ONLINE_TIME = "online_time"
|
||||
TOTAL_MSG_CNT = "total_messages"
|
||||
MSG_CNT_BY_CHAT = "messages_by_chat"
|
||||
|
||||
|
||||
class OnlineTimeRecordTask(AsyncTask):
|
||||
"""在线时间记录任务"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(task_name="Online Time Record Task", run_interval=60)
|
||||
|
||||
self.record_id: str | None = None
|
||||
"""记录ID"""
|
||||
|
||||
self._init_database() # 初始化数据库
|
||||
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
"""初始化数据库"""
|
||||
if "online_time" not in db.list_collection_names():
|
||||
# 初始化数据库(在线时长)
|
||||
db.create_collection("online_time")
|
||||
# 创建索引
|
||||
if ("end_timestamp", 1) not in db.online_time.list_indexes():
|
||||
db.online_time.create_index([("end_timestamp", 1)])
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
if self.record_id:
|
||||
# 如果有记录,则更新结束时间
|
||||
db.online_time.update_one(
|
||||
{"_id": self.record_id},
|
||||
{
|
||||
"$set": {
|
||||
"end_timestamp": datetime.now() + timedelta(minutes=1),
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
# 如果没有记录,检查一分钟以内是否已有记录
|
||||
current_time = datetime.now()
|
||||
recent_record = db.online_time.find_one(
|
||||
{"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}}
|
||||
)
|
||||
|
||||
if not recent_record:
|
||||
# 若没有记录,则插入新的在线时间记录
|
||||
self.record_id = db.online_time.insert_one(
|
||||
{
|
||||
"start_timestamp": current_time,
|
||||
"end_timestamp": current_time + timedelta(minutes=1),
|
||||
}
|
||||
).inserted_id
|
||||
else:
|
||||
# 如果有记录,则更新结束时间
|
||||
self.record_id = recent_record["_id"]
|
||||
db.online_time.update_one(
|
||||
{"_id": self.record_id},
|
||||
{
|
||||
"$set": {
|
||||
"end_timestamp": current_time + timedelta(minutes=1),
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("在线时间记录失败")
|
||||
|
||||
|
||||
def _format_online_time(online_seconds: int) -> str:
|
||||
"""
|
||||
格式化在线时间
|
||||
:param online_seconds: 在线时间(秒)
|
||||
:return: 格式化后的在线时间字符串
|
||||
"""
|
||||
    total_online_time = timedelta(seconds=online_seconds)

    days = total_online_time.days
    hours = total_online_time.seconds // 3600
    minutes = (total_online_time.seconds // 60) % 60
    seconds = total_online_time.seconds % 60
    if days > 0:
        # 如果在线时间超过1天,则格式化为"X天X小时X分钟X秒"
        total_online_time_str = f"{days}天{hours}小时{minutes}分钟{seconds}秒"
    elif hours > 0:
        # 如果在线时间超过1小时,则格式化为"X小时X分钟X秒"
        total_online_time_str = f"{hours}小时{minutes}分钟{seconds}秒"
    else:
        # 其他情况格式化为"X分钟X秒"
        total_online_time_str = f"{minutes}分钟{seconds}秒"

    return total_online_time_str
|
||||
|
||||
|
||||
class StatisticOutputTask(AsyncTask):
|
||||
"""统计输出任务"""
|
||||
|
||||
SEP_LINE = "-" * 84
|
||||
|
||||
def __init__(self, record_file_path: str = "maibot_statistics.html"):
|
||||
        # 启动时不延迟,运行间隔300秒
        super().__init__(task_name="Statistics Data Output Task", wait_before_start=0, run_interval=300)
|
||||
|
||||
self.name_mapping: Dict[str, Tuple[str, float]] = {}
|
||||
"""
|
||||
联系人/群聊名称映射 {聊天ID: (联系人/群聊名称, 记录时间(timestamp))}
|
||||
注:设计记录时间的目的是方便更新名称,使联系人/群聊名称保持最新
|
||||
"""
|
||||
|
||||
self.record_file_path: str = record_file_path
|
||||
"""
|
||||
记录文件路径
|
||||
"""
|
||||
|
||||
now = datetime.now()
|
||||
if "deploy_time" in local_storage:
|
||||
# 如果存在部署时间,则使用该时间作为全量统计的起始时间
|
||||
deploy_time = datetime.fromtimestamp(local_storage["deploy_time"])
|
||||
else:
|
||||
# 否则,使用最大时间范围,并记录部署时间为当前时间
|
||||
deploy_time = datetime(2000, 1, 1)
|
||||
local_storage["deploy_time"] = now.timestamp()
|
||||
|
||||
self.stat_period: List[Tuple[str, timedelta, str]] = [
|
||||
("all_time", now - deploy_time, "自部署以来"), # 必须保留"all_time"
|
||||
("last_7_days", timedelta(days=7), "最近7天"),
|
||||
("last_24_hours", timedelta(days=1), "最近24小时"),
|
||||
("last_hour", timedelta(hours=1), "最近1小时"),
|
||||
]
|
||||
"""
|
||||
统计时间段 [(统计名称, 统计时间段, 统计描述), ...]
|
||||
"""
|
||||
|
||||
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
|
||||
"""
|
||||
输出统计数据到控制台
|
||||
:param stats: 统计数据
|
||||
:param now: 基准当前时间
|
||||
"""
|
||||
# 输出最近一小时的统计数据
|
||||
|
||||
output = [
|
||||
self.SEP_LINE,
|
||||
f" 最近1小时的统计数据 (自{now.strftime('%Y-%m-%d %H:%M:%S')}开始,详细信息见文件:{self.record_file_path})",
|
||||
self.SEP_LINE,
|
||||
self._format_total_stat(stats["last_hour"]),
|
||||
"",
|
||||
self._format_model_classified_stat(stats["last_hour"]),
|
||||
"",
|
||||
self._format_chat_stat(stats["last_hour"]),
|
||||
self.SEP_LINE,
|
||||
"",
|
||||
]
|
||||
|
||||
logger.info("\n" + "\n".join(output))
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
now = datetime.now()
|
||||
# 收集统计数据
|
||||
stats = self._collect_all_statistics(now)
|
||||
|
||||
# 输出统计数据到控制台
|
||||
self._statistic_console_output(stats, now)
|
||||
# 输出统计数据到html文件
|
||||
self._generate_html_report(stats, now)
|
||||
except Exception as e:
|
||||
logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}")
|
||||
|
||||
# -- 以下为统计数据收集方法 --
|
||||
|
||||
@staticmethod
|
||||
def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的LLM请求统计数据
|
||||
|
||||
:param collect_period: 统计时间段
|
||||
"""
|
||||
if len(collect_period) <= 0:
|
||||
return {}
|
||||
else:
|
||||
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
|
||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
stats = {
|
||||
period_key: {
|
||||
# 总LLM请求数
|
||||
TOTAL_REQ_CNT: 0,
|
||||
# 请求次数统计
|
||||
REQ_CNT_BY_TYPE: defaultdict(int),
|
||||
REQ_CNT_BY_USER: defaultdict(int),
|
||||
REQ_CNT_BY_MODEL: defaultdict(int),
|
||||
# 输入Token数
|
||||
IN_TOK_BY_TYPE: defaultdict(int),
|
||||
IN_TOK_BY_USER: defaultdict(int),
|
||||
IN_TOK_BY_MODEL: defaultdict(int),
|
||||
# 输出Token数
|
||||
OUT_TOK_BY_TYPE: defaultdict(int),
|
||||
OUT_TOK_BY_USER: defaultdict(int),
|
||||
OUT_TOK_BY_MODEL: defaultdict(int),
|
||||
# 总Token数
|
||||
TOTAL_TOK_BY_TYPE: defaultdict(int),
|
||||
TOTAL_TOK_BY_USER: defaultdict(int),
|
||||
TOTAL_TOK_BY_MODEL: defaultdict(int),
|
||||
# 总开销
|
||||
TOTAL_COST: 0.0,
|
||||
# 请求开销统计
|
||||
COST_BY_TYPE: defaultdict(float),
|
||||
COST_BY_USER: defaultdict(float),
|
||||
COST_BY_MODEL: defaultdict(float),
|
||||
}
|
||||
for period_key, _ in collect_period
|
||||
}
|
||||
|
||||
# 以最早的时间戳为起始时间获取记录
|
||||
for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}):
|
||||
record_timestamp = record.get("timestamp")
|
||||
for idx, (_, period_start) in enumerate(collect_period):
|
||||
if record_timestamp >= period_start:
|
||||
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
|
||||
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_REQ_CNT] += 1
|
||||
|
||||
request_type = record.get("request_type", "unknown") # 请求类型
|
||||
user_id = str(record.get("user_id", "unknown")) # 用户ID
|
||||
model_name = record.get("model_name", "unknown") # 模型名称
|
||||
|
||||
stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
|
||||
stats[period_key][REQ_CNT_BY_USER][user_id] += 1
|
||||
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
|
||||
|
||||
prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数
|
||||
completion_tokens = record.get("completion_tokens", 0) # 输出Token数
|
||||
total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数
|
||||
|
||||
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
|
||||
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
|
||||
stats[period_key][IN_TOK_BY_MODEL][model_name] += prompt_tokens
|
||||
|
||||
stats[period_key][OUT_TOK_BY_TYPE][request_type] += completion_tokens
|
||||
stats[period_key][OUT_TOK_BY_USER][user_id] += completion_tokens
|
||||
stats[period_key][OUT_TOK_BY_MODEL][model_name] += completion_tokens
|
||||
|
||||
stats[period_key][TOTAL_TOK_BY_TYPE][request_type] += total_tokens
|
||||
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
|
||||
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
|
||||
|
||||
cost = record.get("cost", 0.0)
|
||||
stats[period_key][TOTAL_COST] += cost
|
||||
stats[period_key][COST_BY_TYPE][request_type] += cost
|
||||
stats[period_key][COST_BY_USER][user_id] += cost
|
||||
stats[period_key][COST_BY_MODEL][model_name] += cost
|
||||
break # 取消更早时间段的判断
|
||||
|
||||
return stats
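The loop above relies on collect_period being sorted by start time in descending order: each record only needs to find the most recent period it falls into, and is then counted toward that period and every longer (earlier-starting) one before breaking. A standalone sketch (not part of the diff) of the same bucketing idea on plain numbers:

# Standalone sketch of the descending-period bucketing used above.
from collections import defaultdict

def bucket_counts(timestamps, periods):
    """periods: [(key, start_ts), ...], sorted newest-first like collect_period."""
    periods = sorted(periods, key=lambda p: p[1], reverse=True)
    counts = defaultdict(int)
    for ts in timestamps:
        for idx, (_, start) in enumerate(periods):
            if ts >= start:
                # The record belongs to this period and to every longer one after it in the list.
                for key, _ in periods[idx:]:
                    counts[key] += 1
                break
    return dict(counts)

print(bucket_counts([100, 300, 900], [("last_hour", 800), ("last_day", 200), ("all_time", 0)]))
# -> all_time: 3, last_day: 2, last_hour: 1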
|
||||
|
||||
@staticmethod
|
||||
def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的在线时间统计数据
|
||||
|
||||
:param collect_period: 统计时间段
|
||||
"""
|
||||
if len(collect_period) <= 0:
|
||||
return {}
|
||||
else:
|
||||
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
|
||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
stats = {
|
||||
period_key: {
|
||||
# 在线时间统计
|
||||
ONLINE_TIME: 0.0,
|
||||
}
|
||||
for period_key, _ in collect_period
|
||||
}
|
||||
|
||||
# 统计在线时间
|
||||
for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}):
|
||||
end_timestamp: datetime = record.get("end_timestamp")
|
||||
for idx, (_, period_start) in enumerate(collect_period):
|
||||
if end_timestamp >= period_start:
|
||||
# 由于end_timestamp会超前标记时间,所以我们需要判断是否晚于当前时间,如果是,则使用当前时间作为结束时间
|
||||
if end_timestamp > now:
|
||||
end_timestamp = now
|
||||
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
|
||||
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
|
||||
for period_key, _period_start in collect_period[idx:]:
|
||||
start_timestamp: datetime = record.get("start_timestamp")
|
||||
                        if start_timestamp < _period_start:
                            # 如果记录的开始时间早于该统计时段的起始时间,则从时段起始时间开始计算
                            stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds()
                        else:
                            # 否则,从记录的开始时间开始计算
                            stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds()
|
||||
break # 取消更早时间段的判断
|
||||
|
||||
return stats
|
||||
|
||||
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的消息统计数据
|
||||
|
||||
:param collect_period: 统计时间段
|
||||
"""
|
||||
if len(collect_period) <= 0:
|
||||
return {}
|
||||
else:
|
||||
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
|
||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
stats = {
|
||||
period_key: {
|
||||
# 消息统计
|
||||
TOTAL_MSG_CNT: 0,
|
||||
MSG_CNT_BY_CHAT: defaultdict(int),
|
||||
}
|
||||
for period_key, _ in collect_period
|
||||
}
|
||||
|
||||
# 统计消息量
|
||||
for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}):
|
||||
chat_info = message.get("chat_info", None) # 聊天信息
|
||||
user_info = message.get("user_info", None) # 用户信息(消息发送人)
|
||||
message_time = message.get("time", 0) # 消息时间
|
||||
|
||||
group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息
|
||||
if group_info is not None:
|
||||
# 若有群聊信息
|
||||
chat_id = f"g{group_info.get('group_id')}"
|
||||
chat_name = group_info.get("group_name", f"群{group_info.get('group_id')}")
|
||||
elif user_info:
|
||||
# 若没有群聊信息,则尝试获取用户信息
|
||||
chat_id = f"u{user_info['user_id']}"
|
||||
chat_name = user_info["user_nickname"]
|
||||
else:
|
||||
continue # 如果没有群组信息也没有用户信息,则跳过
|
||||
|
||||
if chat_id in self.name_mapping:
|
||||
if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]:
|
||||
# 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称
|
||||
self.name_mapping[chat_id] = (chat_name, message_time)
|
||||
else:
|
||||
self.name_mapping[chat_id] = (chat_name, message_time)
|
||||
|
||||
for idx, (_, period_start) in enumerate(collect_period):
|
||||
if message_time >= period_start.timestamp():
|
||||
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
|
||||
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_MSG_CNT] += 1
|
||||
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
|
||||
break
|
||||
|
||||
return stats
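# Illustrative sketch of the counting above (ids assumed): a group message with
# group_id 12345 is counted under chat_id "g12345", a private message from user 678
# under "u678". Because collect_period is sorted latest-first, a message that falls
# inside the most recent period is added to that period and every longer one via
# collect_period[idx:], and the break then skips re-checking earlier boundaries.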
|
||||
|
||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
收集各时间段的统计数据
|
||||
:param now: 基准当前时间
|
||||
"""
|
||||
|
||||
last_all_time_stat = None
|
||||
|
||||
if "last_full_statistics_timestamp" in local_storage and "last_full_statistics" in local_storage:
|
||||
# 若存有上次完整统计的时间戳,则使用该时间戳作为"所有时间"的起始时间,进行增量统计
|
||||
last_full_stat_ts: float = local_storage["last_full_statistics_timestamp"]
|
||||
last_all_time_stat = local_storage["last_full_statistics"]
|
||||
self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段
|
||||
self.stat_period.append(("all_time", now - datetime.fromtimestamp(last_full_stat_ts), "自部署以来的"))
|
||||
|
||||
stat_start_timestamp = [(period[0], now - period[1]) for period in self.stat_period]
|
||||
|
||||
stat = {item[0]: {} for item in self.stat_period}
|
||||
|
||||
model_req_stat = self._collect_model_request_for_period(stat_start_timestamp)
|
||||
online_time_stat = self._collect_online_time_for_period(stat_start_timestamp, now)
|
||||
message_count_stat = self._collect_message_count_for_period(stat_start_timestamp)
|
||||
|
||||
# 统计数据合并
|
||||
# 合并三类统计数据
|
||||
for period_key, _ in stat_start_timestamp:
|
||||
stat[period_key].update(model_req_stat[period_key])
|
||||
stat[period_key].update(online_time_stat[period_key])
|
||||
stat[period_key].update(message_count_stat[period_key])
|
||||
|
||||
if last_all_time_stat:
|
||||
# 若存在上次完整统计数据,则将其与当前统计数据合并
|
||||
for key, val in last_all_time_stat.items():
|
||||
if isinstance(val, dict):
|
||||
# 是字典类型,则进行合并
|
||||
for sub_key, sub_val in val.items():
|
||||
stat["all_time"][key][sub_key] += sub_val
|
||||
else:
|
||||
# 直接合并
|
||||
stat["all_time"][key] += val
|
||||
|
||||
# 更新上次完整统计数据的时间戳
|
||||
local_storage["last_full_statistics_timestamp"] = now.timestamp()
|
||||
# 更新上次完整统计数据
|
||||
local_storage["last_full_statistics"] = stat["all_time"]
|
||||
|
||||
return stat
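# Rough sketch of the incremental bookkeeping above: after the first full pass,
# local_storage keeps "last_full_statistics_timestamp" and "last_full_statistics",
# so the next run only scans records newer than that timestamp and then folds the
# stored totals back in, adding scalars directly and nested dicts key by key.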
|
||||
|
||||
# -- 以下为统计数据格式化方法 --
|
||||
|
||||
@staticmethod
|
||||
def _format_total_stat(stats: Dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化总统计数据
|
||||
"""
|
||||
|
||||
output = [
|
||||
f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}",
|
||||
f"总消息数: {stats[TOTAL_MSG_CNT]}",
|
||||
f"总请求数: {stats[TOTAL_REQ_CNT]}",
|
||||
f"总花费: {stats[TOTAL_COST]:.4f}¥",
|
||||
"",
|
||||
]
|
||||
|
||||
return "\n".join(output)
|
||||
|
||||
@staticmethod
|
||||
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化按模型分类的统计数据
|
||||
"""
|
||||
if stats[TOTAL_REQ_CNT] > 0:
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.4f}¥"
|
||||
|
||||
output = [
|
||||
"按模型分类统计:",
|
||||
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费",
|
||||
]
|
||||
for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()):
|
||||
name = model_name[:29] + "..." if len(model_name) > 32 else model_name
|
||||
in_tokens = stats[IN_TOK_BY_MODEL][model_name]
|
||||
out_tokens = stats[OUT_TOK_BY_MODEL][model_name]
|
||||
tokens = stats[TOTAL_TOK_BY_MODEL][model_name]
|
||||
cost = stats[COST_BY_MODEL][model_name]
|
||||
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost))
|
||||
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
else:
|
||||
return ""
|
||||
|
||||
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化聊天统计数据
|
||||
"""
|
||||
if stats[TOTAL_MSG_CNT] > 0:
|
||||
output = ["聊天消息统计:", " 联系人/群组名称 消息数量"]
|
||||
for chat_id, count in sorted(stats[MSG_CNT_BY_CHAT].items()):
|
||||
output.append(f"{self.name_mapping[chat_id][0][:32]:<32} {count:>10}")
|
||||
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
else:
|
||||
return ""
|
||||
|
||||
def _generate_html_report(self, stat: dict[str, Any], now: datetime):
|
||||
"""
|
||||
生成HTML格式的统计报告
|
||||
:param stat: 统计数据
|
||||
:param now: 基准当前时间
|
||||
:return: HTML格式的统计报告
|
||||
"""
|
||||
|
||||
tab_list = [
|
||||
f'<button class="tab-link" onclick="showTab(event, \'{period[0]}\')">{period[2]}</button>'
|
||||
for period in self.stat_period
|
||||
]
|
||||
|
||||
def _format_stat_data(stat_data: dict[str, Any], div_id: str, start_time: datetime) -> str:
|
||||
"""
|
||||
格式化一个时间段的统计数据到html div块
|
||||
:param stat_data: 统计数据
|
||||
:param div_id: div的ID
|
||||
:param start_time: 统计时间段开始时间
|
||||
"""
|
||||
# format总在线时间
|
||||
|
||||
# 按模型分类统计
|
||||
model_rows = "\n".join([
|
||||
f"<tr>"
|
||||
f"<td>{model_name}</td>"
|
||||
f"<td>{count}</td>"
|
||||
f"<td>{stat_data[IN_TOK_BY_MODEL][model_name]}</td>"
|
||||
f"<td>{stat_data[OUT_TOK_BY_MODEL][model_name]}</td>"
|
||||
f"<td>{stat_data[TOTAL_TOK_BY_MODEL][model_name]}</td>"
|
||||
f"<td>{stat_data[COST_BY_MODEL][model_name]:.4f} ¥</td>"
|
||||
f"</tr>"
|
||||
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
|
||||
])
|
||||
# 按请求类型分类统计
|
||||
type_rows = "\n".join([
|
||||
f"<tr>"
|
||||
f"<td>{req_type}</td>"
|
||||
f"<td>{count}</td>"
|
||||
f"<td>{stat_data[IN_TOK_BY_TYPE][req_type]}</td>"
|
||||
f"<td>{stat_data[OUT_TOK_BY_TYPE][req_type]}</td>"
|
||||
f"<td>{stat_data[TOTAL_TOK_BY_TYPE][req_type]}</td>"
|
||||
f"<td>{stat_data[COST_BY_TYPE][req_type]:.4f} ¥</td>"
|
||||
f"</tr>"
|
||||
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
|
||||
])
|
||||
# 按用户分类统计
|
||||
user_rows = "\n".join([
|
||||
f"<tr>"
|
||||
f"<td>{user_id}</td>"
|
||||
f"<td>{count}</td>"
|
||||
f"<td>{stat_data[IN_TOK_BY_USER][user_id]}</td>"
|
||||
f"<td>{stat_data[OUT_TOK_BY_USER][user_id]}</td>"
|
||||
f"<td>{stat_data[TOTAL_TOK_BY_USER][user_id]}</td>"
|
||||
f"<td>{stat_data[COST_BY_USER][user_id]:.4f} ¥</td>"
|
||||
f"</tr>"
|
||||
for user_id, count in sorted(stat_data[REQ_CNT_BY_USER].items())
|
||||
])
|
||||
# 聊天消息统计
|
||||
chat_rows = "\n".join([
|
||||
f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>"
|
||||
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
|
||||
])
|
||||
# 生成HTML
|
||||
return f"""
|
||||
<div id=\"{div_id}\" class=\"tab-content\">
|
||||
<p class=\"info-item\">
|
||||
<strong>统计时段: </strong>
|
||||
{start_time.strftime("%Y-%m-%d %H:%M:%S")} ~ {now.strftime("%Y-%m-%d %H:%M:%S")}
|
||||
</p>
|
||||
<p class=\"info-item\"><strong>总在线时间: </strong>{_format_online_time(stat_data[ONLINE_TIME])}</p>
|
||||
<p class=\"info-item\"><strong>总消息数: </strong>{stat_data[TOTAL_MSG_CNT]}</p>
|
||||
<p class=\"info-item\"><strong>总请求数: </strong>{stat_data[TOTAL_REQ_CNT]}</p>
|
||||
<p class=\"info-item\"><strong>总花费: </strong>{stat_data[TOTAL_COST]:.4f} ¥</p>
|
||||
|
||||
<h2>按模型分类统计</h2>
|
||||
<table>
|
||||
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th></tr></thead>
|
||||
<tbody>
|
||||
{model_rows}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<h2>按请求类型分类统计</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{type_rows}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<h2>按用户分类统计</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>用户名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{user_rows}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<h2>聊天消息统计</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>联系人/群组名称</th><th>消息数量</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{chat_rows}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
"""
|
||||
|
||||
tab_content_list = [
|
||||
_format_stat_data(stat[period[0]], period[0], now - period[1])
|
||||
for period in self.stat_period
|
||||
if period[0] != "all_time"
|
||||
]
|
||||
|
||||
tab_content_list.append(
|
||||
_format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"]))
|
||||
)
|
||||
|
||||
joined_tab_list = "\n".join(tab_list)
|
||||
joined_tab_content = "\n".join(tab_content_list)
|
||||
|
||||
html_template = (
|
||||
"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>MaiBot运行统计报告</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background-color: #f4f7f6;
|
||||
color: #333;
|
||||
line-height: 1.6;
|
||||
}
|
||||
.container {
|
||||
max-width: 900px;
|
||||
margin: 20px auto;
|
||||
background-color: #fff;
|
||||
padding: 25px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||
}
|
||||
h1, h2 {
|
||||
color: #2c3e50;
|
||||
border-bottom: 2px solid #3498db;
|
||||
padding-bottom: 10px;
|
||||
margin-top: 0;
|
||||
}
|
||||
h1 {
|
||||
text-align: center;
|
||||
font-size: 2em;
|
||||
}
|
||||
h2 {
|
||||
font-size: 1.5em;
|
||||
margin-top: 30px;
|
||||
}
|
||||
p {
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.info-item {
|
||||
background-color: #ecf0f1;
|
||||
padding: 8px 12px;
|
||||
border-radius: 4px;
|
||||
margin-bottom: 8px;
|
||||
font-size: 0.95em;
|
||||
}
|
||||
.info-item strong {
|
||||
color: #2980b9;
|
||||
}
|
||||
table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin-top: 15px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
th, td {
|
||||
border: 1px solid #ddd;
|
||||
padding: 10px;
|
||||
text-align: left;
|
||||
}
|
||||
th {
|
||||
background-color: #3498db;
|
||||
color: white;
|
||||
font-weight: bold;
|
||||
}
|
||||
tr:nth-child(even) {
|
||||
background-color: #f9f9f9;
|
||||
}
|
||||
.footer {
|
||||
text-align: center;
|
||||
margin-top: 30px;
|
||||
font-size: 0.8em;
|
||||
color: #7f8c8d;
|
||||
}
|
||||
.tabs {
|
||||
overflow: hidden;
|
||||
background: #ecf0f1;
|
||||
display: flex;
|
||||
}
|
||||
.tabs button {
|
||||
background: inherit; border: none; outline: none;
|
||||
padding: 14px 16px; cursor: pointer;
|
||||
transition: 0.3s; font-size: 16px;
|
||||
}
|
||||
.tabs button:hover {
|
||||
background-color: #d4dbdc;
|
||||
}
|
||||
.tabs button.active {
|
||||
background-color: #b3bbbd;
|
||||
}
|
||||
.tab-content {
|
||||
display: none;
|
||||
padding: 20px;
|
||||
background-color: #fff;
|
||||
border: 1px solid #ccc;
|
||||
}
|
||||
.tab-content.active {
|
||||
display: block;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
"""
|
||||
+ f"""
|
||||
<div class="container">
|
||||
<h1>MaiBot运行统计报告</h1>
|
||||
<p class="info-item"><strong>统计截止时间:</strong> {now.strftime("%Y-%m-%d %H:%M:%S")}</p>
|
||||
|
||||
<div class="tabs">
|
||||
{joined_tab_list}
|
||||
</div>
|
||||
|
||||
{joined_tab_content}
|
||||
</div>
|
||||
"""
|
||||
+ """
|
||||
<script>
|
||||
let i, tab_content, tab_links;
|
||||
tab_content = document.getElementsByClassName("tab-content");
|
||||
tab_links = document.getElementsByClassName("tab-link");
|
||||
|
||||
tab_content[0].classList.add("active");
|
||||
tab_links[0].classList.add("active");
|
||||
|
||||
function showTab(evt, tabName) {
|
||||
for (i = 0; i < tab_content.length; i++) tab_content[i].classList.remove("active");
|
||||
for (i = 0; i < tab_links.length; i++) tab_links[i].classList.remove("active");
|
||||
document.getElementById(tabName).classList.add("active");
|
||||
evt.currentTarget.classList.add("active");
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
)
|
||||
|
||||
with open(self.record_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(html_template)
|
||||
155
src/chat/utils/timer_calculator.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from time import perf_counter
|
||||
from functools import wraps
|
||||
from typing import Optional, Dict, Callable
|
||||
import asyncio
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
"""
|
||||
# 更好的计时器
|
||||
|
||||
使用形式:
|
||||
- 上下文
|
||||
- 装饰器
|
||||
- 直接实例化
|
||||
|
||||
使用场景:
|
||||
- 使用Timer:在需要测量代码执行时间时(如性能测试、计时器工具),Timer类是更可靠、高精度的选择。
|
||||
- 使用time.time()的场景:当需要记录实际时间点(如日志、时间戳)时使用,但避免用它测量时间间隔。
|
||||
|
||||
使用方式:
|
||||
|
||||
【装饰器】
|
||||
time_dict = {}
|
||||
@Timer("计数", time_dict)
|
||||
def func():
|
||||
pass
|
||||
print(time_dict)
|
||||
|
||||
【上下文_1】
|
||||
def func():
|
||||
with Timer() as t:
|
||||
pass
|
||||
print(t)
|
||||
print(t.human_readable)
|
||||
|
||||
【上下文_2】
|
||||
def func():
|
||||
time_dict = {}
|
||||
with Timer("计数", time_dict):
|
||||
pass
|
||||
print(time_dict)
|
||||
|
||||
【直接实例化】
|
||||
a = Timer()
|
||||
print(a) # 直接输出当前 perf_counter 值
|
||||
|
||||
参数:
|
||||
- name:计时器的名字,默认为 None
|
||||
- storage:计时器结果存储字典,默认为 None
|
||||
- auto_unit:自动选择单位(毫秒或秒),默认为 True(自动根据时间切换毫秒或秒)
|
||||
- do_type_check:是否进行类型检查,默认为 False(不进行类型检查)
|
||||
|
||||
属性:human_readable
|
||||
|
||||
自定义错误:TimerTypeError
|
||||
"""
|
||||
|
||||
|
||||
class TimerTypeError(TypeError):
|
||||
"""自定义类型错误"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, param, expected_type, actual_type):
|
||||
super().__init__(f"参数 '{param}' 类型错误,期望 {expected_type},实际得到 {actual_type.__name__}")
|
||||
|
||||
|
||||
class Timer:
|
||||
"""
|
||||
Timer 支持三种模式:
|
||||
1. 装饰器模式:用于测量函数/协程运行时间
|
||||
2. 上下文管理器模式:用于 with 语句块内部计时
|
||||
3. 直接实例化:如果不调用 __enter__,打印对象时将显示当前 perf_counter 的值
|
||||
"""
|
||||
|
||||
__slots__ = ("name", "storage", "elapsed", "auto_unit", "start")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
storage: Optional[Dict[str, float]] = None,
|
||||
auto_unit: bool = True,
|
||||
do_type_check: bool = False,
|
||||
):
|
||||
if do_type_check:
|
||||
self._validate_types(name, storage)
|
||||
|
||||
self.name = name
|
||||
self.storage = storage
|
||||
self.elapsed = None
|
||||
|
||||
self.auto_unit = auto_unit
|
||||
self.start = None
|
||||
|
||||
@staticmethod
|
||||
def _validate_types(name, storage):
|
||||
"""类型检查"""
|
||||
if name is not None and not isinstance(name, str):
|
||||
raise TimerTypeError("name", "Optional[str]", type(name))
|
||||
|
||||
if storage is not None and not isinstance(storage, dict):
|
||||
raise TimerTypeError("storage", "Optional[dict]", type(storage))
|
||||
|
||||
def __call__(self, func: Optional[Callable] = None) -> Callable:
|
||||
"""装饰器模式"""
|
||||
if func is None:
|
||||
return lambda f: Timer(name=self.name or f.__name__, storage=self.storage, auto_unit=self.auto_unit)(f)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
with self:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
wrapper.__timer__ = self # 保留计时器引用
|
||||
return wrapper
|
||||
|
||||
def __enter__(self):
|
||||
"""上下文管理器入口"""
|
||||
self.start = perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.elapsed = perf_counter() - self.start
|
||||
self._record_time()
|
||||
return False
|
||||
|
||||
def _record_time(self):
|
||||
"""记录时间"""
|
||||
if self.storage is not None and self.name:
|
||||
self.storage[self.name] = self.elapsed
|
||||
|
||||
@property
|
||||
def human_readable(self) -> str:
|
||||
"""人类可读时间格式"""
|
||||
if self.elapsed is None:
|
||||
return "未计时"
|
||||
|
||||
if self.auto_unit:
|
||||
return f"{self.elapsed * 1000:.2f}毫秒" if self.elapsed < 1 else f"{self.elapsed:.2f}秒"
|
||||
return f"{self.elapsed:.4f}秒"
|
||||
|
||||
def __str__(self):
|
||||
if self.start is not None:
|
||||
if self.elapsed is None:
|
||||
current_elapsed = perf_counter() - self.start
|
||||
return f"<Timer {self.name or '匿名'} [计时中: {current_elapsed:.4f}秒]>"
|
||||
return f"<Timer {self.name or '匿名'} [{self.human_readable}]>"
|
||||
return f"{perf_counter()}"
|
||||
477
src/chat/utils/typo_generator.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
错别字生成器 - 基于拼音和字频的中文错别字生成工具
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import jieba
|
||||
from pypinyin import Style, pinyin
|
||||
|
||||
from src.common.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger("typo_gen")
|
||||
|
||||
|
||||
class ChineseTypoGenerator:
|
||||
def __init__(self, error_rate=0.3, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3, max_freq_diff=200):
|
||||
"""
|
||||
初始化错别字生成器
|
||||
|
||||
参数:
|
||||
error_rate: 单字替换概率
|
||||
min_freq: 最小字频阈值
|
||||
tone_error_rate: 声调错误概率
|
||||
word_replace_rate: 整词替换概率
|
||||
max_freq_diff: 最大允许的频率差异
|
||||
"""
|
||||
self.error_rate = error_rate
|
||||
self.min_freq = min_freq
|
||||
self.tone_error_rate = tone_error_rate
|
||||
self.word_replace_rate = word_replace_rate
|
||||
self.max_freq_diff = max_freq_diff
|
||||
|
||||
# 加载数据
|
||||
# print("正在加载汉字数据库,请稍候...")
|
||||
# logger.info("正在加载汉字数据库,请稍候...")
|
||||
|
||||
self.pinyin_dict = self._create_pinyin_dict()
|
||||
self.char_frequency = self._load_or_create_char_frequency()
|
||||
|
||||
def _load_or_create_char_frequency(self):
|
||||
"""
|
||||
加载或创建汉字频率字典
|
||||
"""
|
||||
cache_file = Path("depends-data/char_frequency.json")
|
||||
|
||||
# 如果缓存文件存在,直接加载
|
||||
if cache_file.exists():
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
# 使用内置的词频文件
|
||||
char_freq = defaultdict(int)
|
||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
|
||||
|
||||
# 读取jieba的词典文件
|
||||
with open(dict_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
word, freq = line.strip().split()[:2]
|
||||
# 对词中的每个字进行频率累加
|
||||
for char in word:
|
||||
if self._is_chinese_char(char):
|
||||
char_freq[char] += int(freq)
|
||||
|
||||
# 归一化频率值
|
||||
max_freq = max(char_freq.values())
|
||||
normalized_freq = {char: freq / max_freq * 1000 for char, freq in char_freq.items()}
|
||||
|
||||
# 保存到缓存文件
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
|
||||
|
||||
return normalized_freq
|
||||
|
||||
@staticmethod
|
||||
def _create_pinyin_dict():
|
||||
"""
|
||||
创建拼音到汉字的映射字典
|
||||
"""
|
||||
# 常用汉字范围
|
||||
chars = [chr(i) for i in range(0x4E00, 0x9FFF)]
|
||||
pinyin_dict = defaultdict(list)
|
||||
|
||||
# 为每个汉字建立拼音映射
|
||||
for char in chars:
|
||||
try:
|
||||
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||
pinyin_dict[py].append(char)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return pinyin_dict
|
||||
|
||||
@staticmethod
|
||||
def _is_chinese_char(char):
|
||||
"""
|
||||
判断是否为汉字
|
||||
"""
|
||||
try:
|
||||
return "\u4e00" <= char <= "\u9fff"
|
||||
except Exception as e:
|
||||
logger.debug(e)
|
||||
return False
|
||||
|
||||
def _get_pinyin(self, sentence):
|
||||
"""
|
||||
将中文句子拆分成单个汉字并获取其拼音
|
||||
"""
|
||||
# 将句子拆分成单个字符
|
||||
characters = list(sentence)
|
||||
|
||||
# 获取每个字符的拼音
|
||||
result = []
|
||||
for char in characters:
|
||||
# 跳过空格和非汉字字符
|
||||
if char.isspace() or not self._is_chinese_char(char):
|
||||
continue
|
||||
# 获取拼音(数字声调)
|
||||
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||
result.append((char, py))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _get_similar_tone_pinyin(py):
|
||||
"""
|
||||
获取相似声调的拼音
|
||||
"""
|
||||
# 检查拼音是否为空或无效
|
||||
if not py or len(py) < 1:
|
||||
return py
|
||||
|
||||
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
|
||||
if not py[-1].isdigit():
|
||||
# 为非数字结尾的拼音添加数字声调1
|
||||
return py + "1"
|
||||
|
||||
base = py[:-1] # 去掉声调
|
||||
tone = int(py[-1]) # 获取声调
|
||||
|
||||
# 处理轻声(通常用5表示)或无效声调
|
||||
if tone not in [1, 2, 3, 4]:
|
||||
return base + str(random.choice([1, 2, 3, 4]))
|
||||
|
||||
# 正常处理声调
|
||||
possible_tones = [1, 2, 3, 4]
|
||||
possible_tones.remove(tone) # 移除原声调
|
||||
new_tone = random.choice(possible_tones) # 随机选择一个新声调
|
||||
return base + str(new_tone)
|
||||
|
||||
def _calculate_replacement_probability(self, orig_freq, target_freq):
|
||||
"""
|
||||
根据频率差计算替换概率
|
||||
"""
|
||||
if target_freq > orig_freq:
|
||||
return 1.0 # 如果替换字频率更高,保持原有概率
|
||||
|
||||
freq_diff = orig_freq - target_freq
|
||||
if freq_diff > self.max_freq_diff:
|
||||
return 0.0 # 频率差太大,不替换
|
||||
|
||||
# 使用指数衰减函数计算概率
|
||||
# 频率差为0时概率为1,频率差为max_freq_diff时概率接近0
|
||||
return math.exp(-3 * freq_diff / self.max_freq_diff)
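# Worked numbers for the decay above, using the default max_freq_diff=200 (illustrative):
# freq_diff=0 -> exp(0) = 1.00, freq_diff=100 -> exp(-1.5) ≈ 0.22,
# freq_diff=200 -> exp(-3) ≈ 0.05; larger differences are rejected before this point.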
|
||||
|
||||
def _get_similar_frequency_chars(self, char, py, num_candidates=5):
|
||||
"""
|
||||
获取与给定字频率相近的同音字,可能包含声调错误
|
||||
"""
|
||||
homophones = []
|
||||
|
||||
# 有一定概率使用错误声调
|
||||
if random.random() < self.tone_error_rate:
|
||||
wrong_tone_py = self._get_similar_tone_pinyin(py)
|
||||
homophones.extend(self.pinyin_dict[wrong_tone_py])
|
||||
|
||||
# 添加正确声调的同音字
|
||||
homophones.extend(self.pinyin_dict[py])
|
||||
|
||||
if not homophones:
|
||||
return None
|
||||
|
||||
# 获取原字的频率
|
||||
orig_freq = self.char_frequency.get(char, 0)
|
||||
|
||||
# 计算所有同音字与原字的频率差,并过滤掉低频字
|
||||
freq_diff = [
|
||||
(h, self.char_frequency.get(h, 0))
|
||||
for h in homophones
|
||||
if h != char and self.char_frequency.get(h, 0) >= self.min_freq
|
||||
]
|
||||
|
||||
if not freq_diff:
|
||||
return None
|
||||
|
||||
# 计算每个候选字的替换概率
|
||||
candidates_with_prob = []
|
||||
for h, freq in freq_diff:
|
||||
prob = self._calculate_replacement_probability(orig_freq, freq)
|
||||
if prob > 0: # 只保留有效概率的候选字
|
||||
candidates_with_prob.append((h, prob))
|
||||
|
||||
if not candidates_with_prob:
|
||||
return None
|
||||
|
||||
# 根据概率排序
|
||||
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 返回概率最高的几个字
|
||||
return [char for char, _ in candidates_with_prob[:num_candidates]]
|
||||
|
||||
@staticmethod
|
||||
def _get_word_pinyin(word):
|
||||
"""
|
||||
获取词语的拼音列表
|
||||
"""
|
||||
return [py[0] for py in pinyin(word, style=Style.TONE3)]
|
||||
|
||||
@staticmethod
|
||||
def _segment_sentence(sentence):
|
||||
"""
|
||||
使用jieba分词,返回词语列表
|
||||
"""
|
||||
return list(jieba.cut(sentence))
|
||||
|
||||
def _get_word_homophones(self, word):
|
||||
"""
|
||||
获取整个词的同音词,只返回高频的有意义词语
|
||||
"""
|
||||
if len(word) == 1:
|
||||
return []
|
||||
|
||||
# 获取词的拼音
|
||||
word_pinyin = self._get_word_pinyin(word)
|
||||
|
||||
# 遍历所有可能的同音字组合
|
||||
candidates = []
|
||||
for py in word_pinyin:
|
||||
chars = self.pinyin_dict.get(py, [])
|
||||
if not chars:
|
||||
return []
|
||||
candidates.append(chars)
|
||||
|
||||
# 生成所有可能的组合
|
||||
import itertools
|
||||
|
||||
all_combinations = itertools.product(*candidates)
|
||||
|
||||
# 获取jieba词典和词频信息
|
||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
|
||||
valid_words = {} # 改用字典存储词语及其频率
|
||||
with open(dict_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
parts = line.strip().split()
|
||||
if len(parts) >= 2:
|
||||
word_text = parts[0]
|
||||
word_freq = float(parts[1]) # 获取词频
|
||||
valid_words[word_text] = word_freq
|
||||
|
||||
# 获取原词的词频作为参考
|
||||
original_word_freq = valid_words.get(word, 0)
|
||||
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
|
||||
|
||||
# 过滤和计算频率
|
||||
homophones = []
|
||||
for combo in all_combinations:
|
||||
new_word = "".join(combo)
|
||||
if new_word != word and new_word in valid_words:
|
||||
new_word_freq = valid_words[new_word]
|
||||
# 只保留词频达到阈值的词
|
||||
if new_word_freq >= min_word_freq:
|
||||
# 计算词的平均字频(考虑字频和词频)
|
||||
char_avg_freq = sum(self.char_frequency.get(c, 0) for c in new_word) / len(new_word)
|
||||
# 综合评分:结合词频和字频
|
||||
combined_score = new_word_freq * 0.7 + char_avg_freq * 0.3
|
||||
if combined_score >= self.min_freq:
|
||||
homophones.append((new_word, combined_score))
|
||||
|
||||
# 按综合分数排序并限制返回数量
|
||||
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
|
||||
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
|
||||
|
||||
def create_typo_sentence(self, sentence):
|
||||
"""
|
||||
创建包含同音字错误的句子,支持词语级别和字级别的替换
|
||||
|
||||
参数:
|
||||
sentence: 输入的中文句子
|
||||
|
||||
返回:
|
||||
typo_sentence: 包含错别字的句子
|
||||
correction_suggestion: 随机选择的一个纠正建议,返回正确的字/词
|
||||
"""
|
||||
result = []
|
||||
typo_info = []
|
||||
word_typos = [] # 记录词语错误对(错词,正确词)
|
||||
char_typos = [] # 记录单字错误对(错字,正确字)
|
||||
current_pos = 0
|
||||
|
||||
# 分词
|
||||
words = self._segment_sentence(sentence)
|
||||
|
||||
for word in words:
|
||||
# 如果是标点符号或空格,直接添加
|
||||
if all(not self._is_chinese_char(c) for c in word):
|
||||
result.append(word)
|
||||
current_pos += len(word)
|
||||
continue
|
||||
|
||||
# 获取词语的拼音
|
||||
word_pinyin = self._get_word_pinyin(word)
|
||||
|
||||
# 尝试整词替换
|
||||
if len(word) > 1 and random.random() < self.word_replace_rate:
|
||||
word_homophones = self._get_word_homophones(word)
|
||||
if word_homophones:
|
||||
typo_word = random.choice(word_homophones)
|
||||
# 计算词的平均频率
|
||||
orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word)
|
||||
typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
|
||||
|
||||
# 添加到结果中
|
||||
result.append(typo_word)
|
||||
typo_info.append(
|
||||
(
|
||||
word,
|
||||
typo_word,
|
||||
" ".join(word_pinyin),
|
||||
" ".join(self._get_word_pinyin(typo_word)),
|
||||
orig_freq,
|
||||
typo_freq,
|
||||
)
|
||||
)
|
||||
word_typos.append((typo_word, word)) # 记录(错词,正确词)对
|
||||
current_pos += len(typo_word)
|
||||
continue
|
||||
|
||||
# 如果不进行整词替换,则进行单字替换
|
||||
if len(word) == 1:
|
||||
char = word
|
||||
py = word_pinyin[0]
|
||||
if random.random() < self.error_rate:
|
||||
similar_chars = self._get_similar_frequency_chars(char, py)
|
||||
if similar_chars:
|
||||
typo_char = random.choice(similar_chars)
|
||||
typo_freq = self.char_frequency.get(typo_char, 0)
|
||||
orig_freq = self.char_frequency.get(char, 0)
|
||||
replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq)
|
||||
if random.random() < replace_prob:
|
||||
result.append(typo_char)
|
||||
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||
char_typos.append((typo_char, char)) # 记录(错字,正确字)对
|
||||
current_pos += 1
|
||||
continue
|
||||
result.append(char)
|
||||
current_pos += 1
|
||||
else:
|
||||
# 处理多字词的单字替换
|
||||
word_result = []
|
||||
for _, (char, py) in enumerate(zip(word, word_pinyin)):
|
||||
# 词中的字替换概率降低
|
||||
word_error_rate = self.error_rate * (0.7 ** (len(word) - 1))
|
||||
|
||||
if random.random() < word_error_rate:
|
||||
similar_chars = self._get_similar_frequency_chars(char, py)
|
||||
if similar_chars:
|
||||
typo_char = random.choice(similar_chars)
|
||||
typo_freq = self.char_frequency.get(typo_char, 0)
|
||||
orig_freq = self.char_frequency.get(char, 0)
|
||||
replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq)
|
||||
if random.random() < replace_prob:
|
||||
word_result.append(typo_char)
|
||||
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||
char_typos.append((typo_char, char)) # 记录(错字,正确字)对
|
||||
continue
|
||||
word_result.append(char)
|
||||
result.append("".join(word_result))
|
||||
current_pos += len(word)
|
||||
|
||||
# 优先从词语错误中选择,如果没有则从单字错误中选择
|
||||
correction_suggestion = None
|
||||
# 50%概率返回纠正建议
|
||||
if random.random() < 0.5:
|
||||
if word_typos:
|
||||
wrong_word, correct_word = random.choice(word_typos)
|
||||
correction_suggestion = correct_word
|
||||
elif char_typos:
|
||||
wrong_char, correct_char = random.choice(char_typos)
|
||||
correction_suggestion = correct_char
|
||||
|
||||
return "".join(result), correction_suggestion
|
||||
|
||||
@staticmethod
|
||||
def format_typo_info(typo_info):
|
||||
"""
|
||||
格式化错别字信息
|
||||
|
||||
参数:
|
||||
typo_info: 错别字信息列表
|
||||
|
||||
返回:
|
||||
格式化后的错别字信息字符串
|
||||
"""
|
||||
if not typo_info:
|
||||
return "未生成错别字"
|
||||
|
||||
result = []
|
||||
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
|
||||
# 判断是否为词语替换
|
||||
is_word = " " in orig_py
|
||||
if is_word:
|
||||
error_type = "整词替换"
|
||||
else:
|
||||
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
|
||||
error_type = "声调错误" if tone_error else "同音字替换"
|
||||
|
||||
result.append(
|
||||
f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> "
|
||||
f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]"
|
||||
)
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
def set_params(self, **kwargs):
|
||||
"""
|
||||
设置参数
|
||||
|
||||
可设置参数:
|
||||
error_rate: 单字替换概率
|
||||
min_freq: 最小字频阈值
|
||||
tone_error_rate: 声调错误概率
|
||||
word_replace_rate: 整词替换概率
|
||||
max_freq_diff: 最大允许的频率差异
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
print(f"参数 {key} 已设置为 {value}")
|
||||
else:
|
||||
print(f"警告: 参数 {key} 不存在")
|
||||
|
||||
|
||||
def main():
|
||||
# 创建错别字生成器实例
|
||||
typo_generator = ChineseTypoGenerator(error_rate=0.03, min_freq=7, tone_error_rate=0.02, word_replace_rate=0.3)
|
||||
|
||||
# 获取用户输入
|
||||
sentence = input("请输入中文句子:")
|
||||
|
||||
# 创建包含错别字的句子
|
||||
start_time = time.time()
|
||||
typo_sentence, correction_suggestion = typo_generator.create_typo_sentence(sentence)
|
||||
|
||||
# 打印结果
|
||||
print("\n原句:", sentence)
|
||||
print("错字版:", typo_sentence)
|
||||
|
||||
# 打印纠正建议
|
||||
if correction_suggestion:
|
||||
print("\n随机纠正建议:")
|
||||
print(f"应该改为:{correction_suggestion}")
|
||||
|
||||
# 计算并打印总耗时
|
||||
end_time = time.time()
|
||||
total_time = end_time - start_time
|
||||
print(f"\n总耗时:{total_time:.2f}秒")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
744
src/chat/utils/utils.py
Normal file
@@ -0,0 +1,744 @@
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from collections import Counter
|
||||
|
||||
import jieba
|
||||
import numpy as np
|
||||
from maim_message import UserInfo
|
||||
from pymongo.errors import PyMongoError
|
||||
|
||||
from src.common.logger import get_module_logger
|
||||
from src.manager.mood_manager import mood_manager
|
||||
from ..message_receive.message import MessageRecv
|
||||
from ..models.utils_model import LLMRequest
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
from ...common.database import db
|
||||
from ...config.config import global_config
|
||||
|
||||
logger = get_module_logger("chat_utils")
|
||||
|
||||
|
||||
def is_english_letter(char: str) -> bool:
|
||||
"""检查字符是否为英文字母(忽略大小写)"""
|
||||
return "a" <= char.lower() <= "z"
|
||||
|
||||
|
||||
def db_message_to_str(message_dict: dict) -> str:
|
||||
logger.debug(f"message_dict: {message_dict}")
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
||||
try:
|
||||
name = "[(%s)%s]%s" % (
|
||||
message_dict["user_id"],
|
||||
message_dict.get("user_nickname", ""),
|
||||
message_dict.get("user_cardname", ""),
|
||||
)
|
||||
except Exception:
|
||||
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
||||
content = message_dict.get("processed_plain_text", "")
|
||||
result = f"[{time_str}] {name}: {content}\n"
|
||||
logger.debug(f"result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
"""检查消息是否提到了机器人"""
|
||||
keywords = [global_config.BOT_NICKNAME]
|
||||
nicknames = global_config.BOT_ALIAS_NAMES
|
||||
reply_probability = 0.0
|
||||
is_at = False
|
||||
is_mentioned = False
|
||||
|
||||
if (
|
||||
message.message_info.additional_config is not None
|
||||
and message.message_info.additional_config.get("is_mentioned") is not None
|
||||
):
|
||||
try:
|
||||
reply_probability = float(message.message_info.additional_config.get("is_mentioned"))
|
||||
is_mentioned = True
|
||||
return is_mentioned, reply_probability
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
logger.warning(
|
||||
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
|
||||
)
|
||||
|
||||
# 判断是否被@
|
||||
if re.search(f"@[\s\S]*?(id:{global_config.BOT_QQ})", message.processed_plain_text):
|
||||
is_at = True
|
||||
is_mentioned = True
|
||||
|
||||
if is_at and global_config.at_bot_inevitable_reply:
|
||||
reply_probability = 1.0
|
||||
logger.info("被@,回复概率设置为100%")
|
||||
else:
|
||||
if not is_mentioned:
|
||||
# 判断是否被回复
|
||||
if re.match(
|
||||
f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\):[\s\S]*?],说:", message.processed_plain_text
|
||||
):
|
||||
is_mentioned = True
|
||||
else:
|
||||
# 判断内容中是否被提及
|
||||
message_content = re.sub(r"@[\s\S]*?((\d+))", "", message.processed_plain_text)
|
||||
message_content = re.sub(r"\[回复 [\s\S]*?\(((\d+)|未知id)\):[\s\S]*?],说:", "", message_content)
|
||||
for keyword in keywords:
|
||||
if keyword in message_content:
|
||||
is_mentioned = True
|
||||
for nickname in nicknames:
|
||||
if nickname in message_content:
|
||||
is_mentioned = True
|
||||
if is_mentioned and global_config.mentioned_bot_inevitable_reply:
|
||||
reply_probability = 1.0
|
||||
logger.info("被提及,回复概率设置为100%")
|
||||
return is_mentioned, reply_probability
|
||||
|
||||
|
||||
async def get_embedding(text, request_type="embedding"):
|
||||
"""获取文本的embedding向量"""
|
||||
llm = LLMRequest(model=global_config.embedding, request_type=request_type)
|
||||
# return llm.get_embedding_sync(text)
|
||||
try:
|
||||
embedding = await llm.get_embedding(text)
|
||||
except Exception as e:
|
||||
logger.error(f"获取embedding失败: {str(e)}")
|
||||
embedding = None
|
||||
return embedding
|
||||
|
||||
|
||||
def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, combine=False):
|
||||
recent_messages = list(
|
||||
db.messages.find(
|
||||
{"chat_id": chat_stream_id},
|
||||
{
|
||||
"time": 1, # 返回时间字段
|
||||
"chat_id": 1,
|
||||
"chat_info": 1,
|
||||
"user_info": 1,
|
||||
"message_id": 1, # 返回消息ID字段
|
||||
"detailed_plain_text": 1, # 返回处理后的文本字段
|
||||
},
|
||||
)
|
||||
.sort("time", -1)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if not recent_messages:
|
||||
return []
|
||||
|
||||
message_detailed_plain_text = ""
|
||||
message_detailed_plain_text_list = []
|
||||
|
||||
# 反转消息列表,使最新的消息在最后
|
||||
recent_messages.reverse()
|
||||
|
||||
if combine:
|
||||
for msg_db_data in recent_messages:
|
||||
message_detailed_plain_text += str(msg_db_data["detailed_plain_text"])
|
||||
return message_detailed_plain_text
|
||||
else:
|
||||
for msg_db_data in recent_messages:
|
||||
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
|
||||
return message_detailed_plain_text_list
|
||||
|
||||
|
||||
def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> list:
|
||||
# 获取当前群聊记录内发言的人
|
||||
recent_messages = list(
|
||||
db.messages.find(
|
||||
{"chat_id": chat_stream_id},
|
||||
{
|
||||
"user_info": 1,
|
||||
},
|
||||
)
|
||||
.sort("time", -1)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if not recent_messages:
|
||||
return []
|
||||
|
||||
who_chat_in_group = []
|
||||
for msg_db_data in recent_messages:
|
||||
user_info = UserInfo.from_dict(msg_db_data["user_info"])
|
||||
if (
|
||||
(user_info.platform, user_info.user_id) != sender
|
||||
and user_info.user_id != global_config.BOT_QQ
|
||||
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
||||
and len(who_chat_in_group) < 5
|
||||
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||
who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
|
||||
|
||||
return who_chat_in_group
|
||||
|
||||
|
||||
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
"""将文本分割成句子,并根据概率合并
|
||||
1. 识别分割点(, , 。 ; 空格),但如果分割点左右都是英文字母则不分割。
|
||||
2. 将文本分割成 (内容, 分隔符) 的元组。
|
||||
3. 根据原始文本长度计算合并概率,概率性地合并相邻段落。
|
||||
注意:此函数假定颜文字已在上层被保护。
|
||||
Args:
|
||||
text: 要分割的文本字符串 (假定颜文字已被保护)
|
||||
Returns:
|
||||
List[str]: 分割和合并后的句子列表
|
||||
"""
|
||||
# 预处理:处理多余的换行符
|
||||
# 1. 将连续的换行符替换为单个换行符
|
||||
text = re.sub(r"\n\s*\n+", "\n", text)
|
||||
# 2. 处理换行符和其他分隔符的组合
|
||||
text = re.sub(r"\n\s*([,,。;\s])", r"\1", text)
|
||||
text = re.sub(r"([,,。;\s])\s*\n", r"\1", text)
|
||||
|
||||
# 处理两个汉字中间的换行符
|
||||
text = re.sub(r"([\u4e00-\u9fff])\n([\u4e00-\u9fff])", r"\1。\2", text)
|
||||
|
||||
len_text = len(text)
|
||||
if len_text < 3:
|
||||
if random.random() < 0.01:
|
||||
return list(text) # 如果文本很短且触发随机条件,直接按字符分割
|
||||
else:
|
||||
return [text]
|
||||
|
||||
# 定义分隔符
|
||||
separators = {",", ",", " ", "。", ";"}
|
||||
segments = []
|
||||
current_segment = ""
|
||||
|
||||
# 1. 分割成 (内容, 分隔符) 元组
|
||||
i = 0
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
if char in separators:
|
||||
# 检查分割条件:如果分隔符左右都是英文字母,则不分割
|
||||
can_split = True
|
||||
if 0 < i < len(text) - 1:
|
||||
prev_char = text[i - 1]
|
||||
next_char = text[i + 1]
|
||||
# if is_english_letter(prev_char) and is_english_letter(next_char) and char == ' ': # 原计划只对空格应用此规则,现应用于所有分隔符
|
||||
if is_english_letter(prev_char) and is_english_letter(next_char):
|
||||
can_split = False
|
||||
|
||||
if can_split:
|
||||
# 只有当当前段不为空时才添加
|
||||
if current_segment:
|
||||
segments.append((current_segment, char))
|
||||
# 如果当前段为空,但分隔符是空格,则也添加一个空段(保留空格)
|
||||
elif char == " ":
|
||||
segments.append(("", char))
|
||||
current_segment = ""
|
||||
else:
|
||||
# 不分割,将分隔符加入当前段
|
||||
current_segment += char
|
||||
else:
|
||||
current_segment += char
|
||||
i += 1
|
||||
|
||||
# 添加最后一个段(没有后续分隔符)
|
||||
if current_segment:
|
||||
segments.append((current_segment, ""))
|
||||
|
||||
# 过滤掉完全空的段(内容和分隔符都为空)
|
||||
segments = [(content, sep) for content, sep in segments if content or sep]
|
||||
|
||||
# 如果分割后为空(例如,输入全是分隔符且不满足保留条件),恢复颜文字并返回
|
||||
if not segments:
|
||||
# recovered_text = recover_kaomoji([text], mapping) # 恢复原文本中的颜文字 - 已移至上层处理
|
||||
# return [s for s in recovered_text if s] # 返回非空结果
|
||||
return [text] if text else [] # 如果原始文本非空,则返回原始文本(可能只包含未被分割的字符或颜文字占位符)
|
||||
|
||||
# 2. 概率合并
|
||||
if len_text < 12:
|
||||
split_strength = 0.2
|
||||
elif len_text < 32:
|
||||
split_strength = 0.6
|
||||
else:
|
||||
split_strength = 0.7
|
||||
# 合并概率与分割强度相反
|
||||
merge_probability = 1.0 - split_strength
|
||||
|
||||
merged_segments = []
|
||||
idx = 0
|
||||
while idx < len(segments):
|
||||
current_content, current_sep = segments[idx]
|
||||
|
||||
# 检查是否可以与下一段合并
|
||||
# 条件:不是最后一段,且随机数小于合并概率,且当前段有内容(避免合并空段)
|
||||
if idx + 1 < len(segments) and random.random() < merge_probability and current_content:
|
||||
next_content, next_sep = segments[idx + 1]
|
||||
# 合并: (内容1 + 分隔符1 + 内容2, 分隔符2)
|
||||
# 只有当下一段也有内容时才合并文本,否则只传递分隔符
|
||||
if next_content:
|
||||
merged_content = current_content + current_sep + next_content
|
||||
merged_segments.append((merged_content, next_sep))
|
||||
else: # 下一段内容为空,只保留当前内容和下一段的分隔符
|
||||
merged_segments.append((current_content, next_sep))
|
||||
|
||||
idx += 2 # 跳过下一段,因为它已被合并
|
||||
else:
|
||||
# 不合并,直接添加当前段
|
||||
merged_segments.append((current_content, current_sep))
|
||||
idx += 1
|
||||
|
||||
# 提取最终的句子内容
|
||||
final_sentences = [content for content, sep in merged_segments if content] # 只保留有内容的段
|
||||
|
||||
# 清理可能引入的空字符串和仅包含空白的字符串
|
||||
final_sentences = [
|
||||
s for s in final_sentences if s.strip()
|
||||
] # 过滤掉空字符串以及仅包含空白(如换行符、空格)的字符串
|
||||
|
||||
logger.debug(f"分割并合并后的句子: {final_sentences}")
|
||||
return final_sentences
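# Example of the behaviour above (merging is random, so these are only possible outcomes):
# split_into_sentences_w_remove_punctuation("今天下雨了,记得带伞。出门小心") may return
# ["今天下雨了", "记得带伞。出门小心"] or ["今天下雨了,记得带伞", "出门小心"], while
# "hello,world" stays in one piece because the comma sits between two English letters.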
|
||||
|
||||
|
||||
def random_remove_punctuation(text: str) -> str:
|
||||
"""随机处理标点符号,模拟人类打字习惯
|
||||
|
||||
Args:
|
||||
text: 要处理的文本
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
result = ""
|
||||
text_len = len(text)
|
||||
|
||||
for i, char in enumerate(text):
|
||||
if char == "。" and i == text_len - 1: # 结尾的句号
|
||||
if random.random() > 0.1: # 90%概率删除结尾句号
|
||||
continue
|
||||
elif char == ",":
|
||||
rand = random.random()
|
||||
if rand < 0.05:  # 5%概率删除逗号
|
||||
continue
|
||||
elif rand < 0.25: # 20%概率把逗号变成空格
|
||||
result += " "
|
||||
continue
|
||||
result += char
|
||||
return result
|
||||
|
||||
|
||||
def process_llm_response(text: str) -> list[str]:
|
||||
# 先保护颜文字
|
||||
if global_config.enable_kaomoji_protection:
|
||||
protected_text, kaomoji_mapping = protect_kaomoji(text)
|
||||
logger.trace(f"保护颜文字后的文本: {protected_text}")
|
||||
else:
|
||||
protected_text = text
|
||||
kaomoji_mapping = {}
|
||||
# 提取被 () 或 [] 包裹且包含中文的内容
|
||||
pattern = re.compile(r"[(\[(](?=.*[一-鿿]).*?[)\])]")
|
||||
# _extracted_contents = pattern.findall(text)
|
||||
_extracted_contents = pattern.findall(protected_text) # 在保护后的文本上查找
|
||||
# 去除 () 和 [] 及其包裹的内容
|
||||
cleaned_text = pattern.sub("", protected_text)
|
||||
|
||||
if cleaned_text == "":
|
||||
return ["呃呃"]
|
||||
|
||||
logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
|
||||
|
||||
# 对清理后的文本进行进一步处理
|
||||
max_length = global_config.response_max_length * 2
|
||||
max_sentence_num = global_config.response_max_sentence_num
|
||||
# 如果基本上是中文,则进行长度过滤
|
||||
if get_western_ratio(cleaned_text) < 0.1:
|
||||
if len(cleaned_text) > max_length:
|
||||
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
|
||||
return ["懒得说"]
|
||||
|
||||
typo_generator = ChineseTypoGenerator(
|
||||
error_rate=global_config.chinese_typo_error_rate,
|
||||
min_freq=global_config.chinese_typo_min_freq,
|
||||
tone_error_rate=global_config.chinese_typo_tone_error_rate,
|
||||
word_replace_rate=global_config.chinese_typo_word_replace_rate,
|
||||
)
|
||||
|
||||
if global_config.enable_response_splitter:
|
||||
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
||||
else:
|
||||
split_sentences = [cleaned_text]
|
||||
|
||||
sentences = []
|
||||
for sentence in split_sentences:
|
||||
if global_config.chinese_typo_enable:
|
||||
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
|
||||
sentences.append(typoed_text)
|
||||
if typo_corrections:
|
||||
sentences.append(typo_corrections)
|
||||
else:
|
||||
sentences.append(sentence)
|
||||
|
||||
if len(sentences) > max_sentence_num:
|
||||
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
||||
return [f"{global_config.BOT_NICKNAME}不知道哦"]
|
||||
|
||||
# if extracted_contents:
|
||||
# for content in extracted_contents:
|
||||
# sentences.append(content)
|
||||
|
||||
# 在所有句子处理完毕后,对包含占位符的列表进行恢复
|
||||
if global_config.enable_kaomoji_protection:
|
||||
sentences = recover_kaomoji(sentences, kaomoji_mapping)
|
||||
|
||||
return sentences
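# End-to-end sketch of the pipeline above (config-dependent and partly random):
# protect kaomoji -> strip bracketed Chinese asides -> length guard -> optional sentence
# splitting -> optional typo injection (plus an occasional correction message) ->
# recover kaomoji. For example, process_llm_response("好的(点头)马上来") would
# typically yield something like ["好的马上来"] once the bracketed aside is removed.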
|
||||
|
||||
|
||||
def calculate_typing_time(
|
||||
input_string: str,
|
||||
thinking_start_time: float,
|
||||
chinese_time: float = 0.2,
|
||||
english_time: float = 0.1,
|
||||
is_emoji: bool = False,
|
||||
) -> float:
|
||||
"""
|
||||
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
|
||||
input_string (str): 输入的字符串
|
||||
chinese_time (float): 中文字符的输入时间,默认为0.2秒
|
||||
english_time (float): 英文字符的输入时间,默认为0.1秒
|
||||
is_emoji (bool): 是否为emoji,默认为False
|
||||
|
||||
特殊情况:
|
||||
- 如果只有一个中文字符,将使用3倍的中文输入时间
|
||||
- 仅在单个中文字符的特殊情况下额外加上回车时间0.3秒
|
||||
- 如果is_emoji为True,将使用固定1秒的输入时间
|
||||
"""
|
||||
# 将0-1的唤醒度映射到-1到1
|
||||
mood_arousal = mood_manager.current_mood.arousal
|
||||
# 映射到0.5到2倍的速度系数
|
||||
typing_speed_multiplier = 1.5**mood_arousal # 唤醒度为1时速度翻倍,为-1时速度减半
|
||||
chinese_time *= 1 / typing_speed_multiplier
|
||||
english_time *= 1 / typing_speed_multiplier
|
||||
# 计算中文字符数
|
||||
chinese_chars = sum(1 for char in input_string if "\u4e00" <= char <= "\u9fff")
|
||||
|
||||
# 如果只有一个中文字符,使用3倍时间
|
||||
if chinese_chars == 1 and len(input_string.strip()) == 1:
|
||||
return chinese_time * 3 + 0.3 # 加上回车时间
|
||||
|
||||
# 正常计算所有字符的输入时间
|
||||
total_time = 0.0
|
||||
for char in input_string:
|
||||
if "\u4e00" <= char <= "\u9fff": # 判断是否为中文字符
|
||||
total_time += chinese_time
|
||||
else: # 其他字符(如英文)
|
||||
total_time += english_time
|
||||
|
||||
if is_emoji:
|
||||
total_time = 1
|
||||
|
||||
if time.time() - thinking_start_time > 10:
|
||||
total_time = 1
|
||||
|
||||
# print(f"thinking_start_time:{thinking_start_time}")
|
||||
# print(f"nowtime:{time.time()}")
|
||||
# print(f"nowtime - thinking_start_time:{time.time() - thinking_start_time}")
|
||||
# print(f"{total_time}")
|
||||
|
||||
return total_time
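# Worked example (assuming a neutral mood, i.e. arousal 0, so the speed multiplier is 1):
# "你好world" -> 2 Chinese chars * 0.2s + 5 ASCII chars * 0.1s = 0.9s; if the thinking
# phase started more than 10 seconds ago (or is_emoji is True) the result is overridden to 1s.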
|
||||
|
||||
|
||||
def cosine_similarity(v1, v2):
|
||||
"""计算余弦相似度"""
|
||||
dot_product = np.dot(v1, v2)
|
||||
norm1 = np.linalg.norm(v1)
|
||||
norm2 = np.linalg.norm(v2)
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0
|
||||
return dot_product / (norm1 * norm2)
|
||||
|
||||
|
||||
def text_to_vector(text):
|
||||
"""将文本转换为词频向量"""
|
||||
# 分词
|
||||
words = jieba.lcut(text)
|
||||
# 统计词频
|
||||
word_freq = Counter(words)
|
||||
return word_freq
|
||||
|
||||
|
||||
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
|
||||
"""使用简单的余弦相似度计算文本相似度"""
|
||||
# 将输入文本转换为词频向量
|
||||
text_vector = text_to_vector(text)
|
||||
|
||||
# 计算每个主题的相似度
|
||||
similarities = []
|
||||
for topic in topics:
|
||||
topic_vector = text_to_vector(topic)
|
||||
# 获取所有唯一词
|
||||
all_words = set(text_vector.keys()) | set(topic_vector.keys())
|
||||
# 构建向量
|
||||
v1 = [text_vector.get(word, 0) for word in all_words]
|
||||
v2 = [topic_vector.get(word, 0) for word in all_words]
|
||||
# 计算相似度
|
||||
similarity = cosine_similarity(v1, v2)
|
||||
similarities.append((topic, similarity))
|
||||
|
||||
# 按相似度降序排序并返回前k个
|
||||
return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
|
||||
|
||||
|
||||
def truncate_message(message: str, max_length=20) -> str:
|
||||
"""截断消息,使其不超过指定长度"""
|
||||
if len(message) > max_length:
|
||||
return message[:max_length] + "..."
|
||||
return message
|
||||
|
||||
|
||||
def protect_kaomoji(sentence):
|
||||
""" "
|
||||
识别并保护句子中的颜文字(含括号与无括号),将其替换为占位符,
|
||||
并返回替换后的句子和占位符到颜文字的映射表。
|
||||
Args:
|
||||
sentence (str): 输入的原始句子
|
||||
Returns:
|
||||
tuple: (处理后的句子, {占位符: 颜文字})
|
||||
"""
|
||||
kaomoji_pattern = re.compile(
|
||||
r"("
|
||||
r"[(\[(【]" # 左括号
|
||||
r"[^()\[\]()【】]*?" # 非括号字符(惰性匹配)
|
||||
r"[^一-龥a-zA-Z0-9\s]" # 非中文、非英文、非数字、非空格字符(必须包含至少一个)
|
||||
r"[^()\[\]()【】]*?" # 非括号字符(惰性匹配)
|
||||
r"[)\])】" # 右括号
|
||||
r"]"
|
||||
r")"
|
||||
r"|"
|
||||
r"([▼▽・ᴥω・﹏^><≧≦ ̄`´∀ヮДд︿﹀へ。゚╥╯╰︶︹•⁄]{2,15})"
|
||||
)
|
||||
|
||||
kaomoji_matches = kaomoji_pattern.findall(sentence)
|
||||
placeholder_to_kaomoji = {}
|
||||
|
||||
for idx, match in enumerate(kaomoji_matches):
|
||||
kaomoji = match[0] if match[0] else match[1]
|
||||
placeholder = f"__KAOMOJI_{idx}__"
|
||||
sentence = sentence.replace(kaomoji, placeholder, 1)
|
||||
placeholder_to_kaomoji[placeholder] = kaomoji
|
||||
|
||||
return sentence, placeholder_to_kaomoji
|
||||
|
||||
|
||||
def recover_kaomoji(sentences, placeholder_to_kaomoji):
|
||||
"""
|
||||
根据映射表恢复句子中的颜文字。
|
||||
Args:
|
||||
sentences (list): 含有占位符的句子列表
|
||||
placeholder_to_kaomoji (dict): 占位符到颜文字的映射表
|
||||
Returns:
|
||||
list: 恢复颜文字后的句子列表
|
||||
"""
|
||||
recovered_sentences = []
|
||||
for sentence in sentences:
|
||||
for placeholder, kaomoji in placeholder_to_kaomoji.items():
|
||||
sentence = sentence.replace(placeholder, kaomoji)
|
||||
recovered_sentences.append(sentence)
|
||||
return recovered_sentences
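# Round-trip sketch (one possible result, the regex decides the exact match):
#   protected, mapping = protect_kaomoji("好耶 (≧▽≦) 出发吧")
#   # protected == "好耶 __KAOMOJI_0__ 出发吧", mapping == {"__KAOMOJI_0__": "(≧▽≦)"}
#   recover_kaomoji([protected], mapping)  # -> ["好耶 (≧▽≦) 出发吧"]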
|
||||
|
||||
|
||||
def get_western_ratio(paragraph):
|
||||
"""计算段落中字母数字字符的西文比例
|
||||
原理:检查段落中字母数字字符的西文比例
|
||||
通过is_english_letter函数判断每个字符是否为西文
|
||||
只检查字母数字字符,忽略标点符号和空格等非字母数字字符
|
||||
|
||||
Args:
|
||||
paragraph: 要检查的文本段落
|
||||
|
||||
Returns:
|
||||
float: 西文字符比例(0.0-1.0),如果没有字母数字字符则返回0.0
|
||||
"""
|
||||
alnum_chars = [char for char in paragraph if char.isalnum()]
|
||||
if not alnum_chars:
|
||||
return 0.0
|
||||
|
||||
western_count = sum(1 for char in alnum_chars if is_english_letter(char))
|
||||
return western_count / len(alnum_chars)
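# Quick example: get_western_ratio("hello世界") inspects the 7 alphanumeric characters
# (5 Latin letters, 2 Chinese characters) and returns 5/7 ≈ 0.71; callers such as
# process_llm_response compare this ratio against 0.1 to decide whether the text is
# mostly Chinese and should get the length limit applied.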
|
||||
|
||||
|
||||
def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]:
|
||||
"""计算两个时间点之间的消息数量和文本总长度
|
||||
|
||||
Args:
|
||||
start_time (float): 起始时间戳 (不包含)
|
||||
end_time (float): 结束时间戳 (包含)
|
||||
stream_id (str): 聊天流ID
|
||||
|
||||
Returns:
|
||||
tuple[int, int]: (消息数量, 文本总长度)
|
||||
"""
|
||||
count = 0
|
||||
total_length = 0
|
||||
|
||||
# 参数校验 (可选但推荐)
|
||||
if start_time >= end_time:
|
||||
# logger.debug(f"开始时间 {start_time} 大于或等于结束时间 {end_time},返回 0, 0")
|
||||
return 0, 0
|
||||
if not stream_id:
|
||||
logger.error("stream_id 不能为空")
|
||||
return 0, 0
|
||||
|
||||
# 直接查询时间范围内的消息
|
||||
# time > start_time AND time <= end_time
|
||||
query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
|
||||
|
||||
try:
|
||||
# 执行查询
|
||||
messages_cursor = db.messages.find(query)
|
||||
|
||||
# 遍历结果计算数量和长度
|
||||
for msg in messages_cursor:
|
||||
count += 1
|
||||
total_length += len(msg.get("processed_plain_text", ""))
|
||||
|
||||
# logger.debug(f"查询范围 ({start_time}, {end_time}] 内找到 {count} 条消息,总长度 {total_length}")
|
||||
return count, total_length
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"查询 stream_id={stream_id} 在 ({start_time}, {end_time}] 范围内的消息时出错: {e}")
|
||||
return 0, 0
|
||||
except Exception as e: # 保留一个通用异常捕获以防万一
|
||||
logger.error(f"计算消息数量时发生意外错误: {e}")
|
||||
return 0, 0
|
||||
|
||||
|
||||
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
||||
"""将时间戳转换为人类可读的时间格式
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
mode: 转换模式,"normal"为标准格式,"relative"为相对时间格式
|
||||
|
||||
Returns:
|
||||
str: 格式化后的时间字符串
|
||||
"""
|
||||
if mode == "normal":
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
|
||||
elif mode == "relative":
|
||||
now = time.time()
|
||||
diff = now - timestamp
|
||||
|
||||
if diff < 20:
|
||||
return "刚刚:\n"
|
||||
elif diff < 60:
|
||||
return f"{int(diff)}秒前:\n"
|
||||
elif diff < 3600:
|
||||
return f"{int(diff / 60)}分钟前:\n"
|
||||
elif diff < 86400:
|
||||
return f"{int(diff / 3600)}小时前:\n"
|
||||
elif diff < 86400 * 2:
|
||||
return f"{int(diff / 86400)}天前:\n"
|
||||
else:
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":\n"
|
||||
else: # mode = "lite" or unknown
|
||||
# 只返回时分秒格式,喵~
|
||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||
|
||||
|
||||
def parse_text_timestamps(text: str, mode: str = "normal") -> str:
|
||||
"""解析文本中的时间戳并转换为可读时间格式
|
||||
|
||||
Args:
|
||||
text: 包含时间戳的文本,时间戳应以[]包裹
|
||||
mode: 转换模式,传递给translate_timestamp_to_human_readable,"normal"或"relative"
|
||||
|
||||
Returns:
|
||||
str: 替换后的文本
|
||||
|
||||
转换规则:
|
||||
- normal模式: 将文本中所有时间戳转换为可读格式
|
||||
- lite模式:
|
||||
- 第一个和最后一个时间戳必须转换
|
||||
- 以15秒为间隔划分时间段,每段最多转换一个时间戳
|
||||
- 不转换的时间戳替换为空字符串
|
||||
"""
|
||||
# 匹配[数字]或[数字.数字]格式的时间戳
|
||||
pattern = r"\[(\d+(?:\.\d+)?)\]"
|
||||
|
||||
# 找出所有匹配的时间戳
|
||||
matches = list(re.finditer(pattern, text))
|
||||
|
||||
if not matches:
|
||||
return text
|
||||
|
||||
# normal模式: 直接转换所有时间戳
|
||||
if mode == "normal":
|
||||
result_text = text
|
||||
for match in matches:
|
||||
timestamp = float(match.group(1))
|
||||
readable_time = translate_timestamp_to_human_readable(timestamp, "normal")
|
||||
# 由于替换会改变文本长度,需要使用正则替换而非直接替换
|
||||
pattern_instance = re.escape(match.group(0))
|
||||
result_text = re.sub(pattern_instance, readable_time, result_text, count=1)
|
||||
return result_text
|
||||
else:
|
||||
# lite模式: 按15秒间隔划分并选择性转换
|
||||
result_text = text
|
||||
|
||||
# 提取所有时间戳及其位置
|
||||
timestamps = [(float(m.group(1)), m) for m in matches]
|
||||
timestamps.sort(key=lambda x: x[0]) # 按时间戳升序排序
|
||||
|
||||
if not timestamps:
|
||||
return text
|
||||
|
||||
# 获取第一个和最后一个时间戳
|
||||
first_timestamp, first_match = timestamps[0]
|
||||
last_timestamp, last_match = timestamps[-1]
|
||||
|
||||
# 将时间范围划分成15秒间隔的时间段
|
||||
time_segments = {}
|
||||
|
||||
# 对所有时间戳按15秒间隔分组
|
||||
for ts, match in timestamps:
|
||||
segment_key = int(ts // 15) # 将时间戳除以15取整,作为时间段的键
|
||||
if segment_key not in time_segments:
|
||||
time_segments[segment_key] = []
|
||||
time_segments[segment_key].append((ts, match))
|
||||
|
||||
# 记录需要转换的时间戳
|
||||
to_convert = []
|
||||
|
||||
# 从每个时间段中选择一个时间戳进行转换
|
||||
for _, segment_timestamps in time_segments.items():
|
||||
# 选择这个时间段中的第一个时间戳
|
||||
to_convert.append(segment_timestamps[0])
|
||||
|
||||
# 确保第一个和最后一个时间戳在转换列表中
|
||||
first_in_list = False
|
||||
last_in_list = False
|
||||
|
||||
for ts, _ in to_convert:
|
||||
if ts == first_timestamp:
|
||||
first_in_list = True
|
||||
if ts == last_timestamp:
|
||||
last_in_list = True
|
||||
|
||||
if not first_in_list:
|
||||
to_convert.append((first_timestamp, first_match))
|
||||
if not last_in_list:
|
||||
to_convert.append((last_timestamp, last_match))
|
||||
|
||||
# 创建需要转换的时间戳集合,用于快速查找
|
||||
to_convert_set = {match.group(0) for _, match in to_convert}
|
||||
|
||||
# 首先替换所有不需要转换的时间戳为空字符串
|
||||
for _, match in timestamps:
|
||||
if match.group(0) not in to_convert_set:
|
||||
pattern_instance = re.escape(match.group(0))
|
||||
result_text = re.sub(pattern_instance, "", result_text, count=1)
|
||||
|
||||
# 按照时间戳原始顺序排序,避免替换时位置错误
|
||||
to_convert.sort(key=lambda x: x[1].start())
|
||||
|
||||
# 执行替换
|
||||
# 由于替换会改变文本长度,从后向前替换
|
||||
to_convert.reverse()
|
||||
for ts, match in to_convert:
|
||||
readable_time = translate_timestamp_to_human_readable(ts, "relative")
|
||||
pattern_instance = re.escape(match.group(0))
|
||||
result_text = re.sub(pattern_instance, readable_time, result_text, count=1)
|
||||
|
||||
return result_text
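# Sketch of the non-normal branch (the rendered strings depend on the current time):
# parse_text_timestamps("[1000.0]a [1003.0]b [1020.0]c", mode="lite") puts 1000.0 and
# 1003.0 into the same 15-second bucket, so [1003.0] is replaced with an empty string,
# while [1000.0] (first) and [1020.0] (last) are rendered via the "relative" mode.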
|
||||
379
src/chat/utils/utils_image.py
Normal file
@@ -0,0 +1,379 @@
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
import io
|
||||
import numpy as np
|
||||
|
||||
|
||||
from ...common.database import db
|
||||
from ...config.config import global_config
|
||||
from ..models.utils_model import LLMRequest
|
||||
|
||||
from src.common.logger_manager import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_image")
|
||||
|
||||
|
||||
class ImageManager:
|
||||
_instance = None
|
||||
IMAGE_DIR = "data" # 图像存储根目录
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self._ensure_image_collection()
|
||||
self._ensure_description_collection()
|
||||
self._ensure_image_dir()
|
||||
self._initialized = True
|
||||
self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
|
||||
|
||||
def _ensure_image_dir(self):
|
||||
"""确保图像存储目录存在"""
|
||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_image_collection():
|
||||
"""确保images集合存在并创建索引"""
|
||||
if "images" not in db.list_collection_names():
|
||||
db.create_collection("images")
|
||||
|
||||
# 删除旧索引
|
||||
db.images.drop_indexes()
|
||||
# 创建新的复合索引
|
||||
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
|
||||
db.images.create_index([("url", 1)])
|
||||
db.images.create_index([("path", 1)])
|
||||
|
||||
@staticmethod
|
||||
def _ensure_description_collection():
|
||||
"""确保image_descriptions集合存在并创建索引"""
|
||||
if "image_descriptions" not in db.list_collection_names():
|
||||
db.create_collection("image_descriptions")
|
||||
|
||||
# 删除旧索引
|
||||
db.image_descriptions.drop_indexes()
|
||||
# 创建新的复合索引
|
||||
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
||||
"""从数据库获取图片描述
|
||||
|
||||
Args:
|
||||
image_hash: 图片哈希值
|
||||
description_type: 描述类型 ('emoji' 或 'image')
|
||||
|
||||
Returns:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
|
||||
return result["description"] if result else None
|
||||
|
||||
    @staticmethod
    def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
        """Save an image description to the database.

        Args:
            image_hash: Hash of the image
            description: Description text
            description_type: Description type ('emoji' or 'image')
        """
        try:
            db.image_descriptions.update_one(
                {"hash": image_hash, "type": description_type},
                {
                    "$set": {
                        "description": description,
                        "timestamp": int(time.time()),
                        "hash": image_hash,  # make sure the hash field is present
                        "type": description_type,  # make sure the type field is present
                    }
                },
                upsert=True,
            )
        except Exception as e:
            logger.error(f"保存描述到数据库失败: {str(e)}")

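    # Illustrative sketch (hypothetical values): after the upsert above, a cached lookup
    # hits a document shaped roughly like
    #   {"hash": "9e107d9d...", "type": "emoji", "description": "开心", "timestamp": 1714000000}
    # so _get_description_from_db("9e107d9d...", "emoji") would return "开心".
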
    async def get_emoji_description(self, image_base64: str) -> str:
        """Get a description for a sticker image, with dedup lookup and optional saving."""
        try:
            # Compute the image hash
            image_bytes = base64.b64decode(image_base64)
            image_hash = hashlib.md5(image_bytes).hexdigest()
            image_format = Image.open(io.BytesIO(image_bytes)).format.lower()

            # Look up a cached description
            cached_description = self._get_description_from_db(image_hash, "emoji")
            if cached_description:
                # logger.debug(f"缓存表情包描述: {cached_description}")
                return f"[表情包,含义看起来是:{cached_description}]"

            # Ask the vision model for a description
            if image_format == "gif" or image_format == "GIF":
                image_base64 = self.transform_gif(image_base64)
                prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些"
                description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg")
            else:
                prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
                description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)

            # Re-check the cache: another task may have stored a description in the meantime
            cached_description = self._get_description_from_db(image_hash, "emoji")
            if cached_description:
                logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
                return f"[表情包,含义看起来是:{cached_description}]"

            # Decide from config whether to save the image itself
            if global_config.save_emoji:
                # Build the file name and path
                timestamp = int(time.time())
                filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
                if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")):
                    os.makedirs(os.path.join(self.IMAGE_DIR, "emoji"))
                file_path = os.path.join(self.IMAGE_DIR, "emoji", filename)

                try:
                    # Save the file
                    with open(file_path, "wb") as f:
                        f.write(image_bytes)

                    # Save the record to the database
                    image_doc = {
                        "hash": image_hash,
                        "path": file_path,
                        "type": "emoji",
                        "description": description,
                        "timestamp": timestamp,
                    }
                    db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
                    logger.trace(f"保存表情包: {file_path}")
                except Exception as e:
                    logger.error(f"保存表情包文件失败: {str(e)}")

            # Save the description to the database
            self._save_description_to_db(image_hash, description, "emoji")

            return f"[表情包:{description}]"
        except Exception as e:
            logger.error(f"获取表情包描述失败: {str(e)}")
            return "[表情包]"

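    # Usage sketch (illustrative only; assumes a running event loop, a reachable db and a
    # configured vision model; the file path is hypothetical):
    #   b64 = image_path_to_base64("data/emoji/sample.gif")
    #   text = await image_manager.get_emoji_description(b64)
    #   # -> "[表情包:开心]" on the first call, served from the cache afterwards
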
    async def get_image_description(self, image_base64: str) -> str:
        """Get a description for an ordinary image, with dedup lookup and optional saving."""
        try:
            # Compute the image hash
            image_bytes = base64.b64decode(image_base64)
            image_hash = hashlib.md5(image_bytes).hexdigest()
            image_format = Image.open(io.BytesIO(image_bytes)).format.lower()

            # Look up a cached description
            cached_description = self._get_description_from_db(image_hash, "image")
            if cached_description:
                logger.debug(f"图片描述缓存中 {cached_description}")
                return f"[图片:{cached_description}]"

            # Ask the vision model for a description
            prompt = (
                "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多100个字。"
            )
            description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)

            # Re-check the cache: another task may have stored a description in the meantime
            cached_description = self._get_description_from_db(image_hash, "image")
            if cached_description:
                logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
                return f"[图片:{cached_description}]"

            logger.debug(f"描述是{description}")

            if description is None:
                logger.warning("AI未能生成图片描述")
                return "[图片]"

            # Decide from config whether to save the image itself
            if global_config.save_pic:
                # Build the file name and path
                timestamp = int(time.time())
                filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
                if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")):
                    os.makedirs(os.path.join(self.IMAGE_DIR, "image"))
                file_path = os.path.join(self.IMAGE_DIR, "image", filename)

                try:
                    # Save the file
                    with open(file_path, "wb") as f:
                        f.write(image_bytes)

                    # Save the record to the database
                    image_doc = {
                        "hash": image_hash,
                        "path": file_path,
                        "type": "image",
                        "description": description,
                        "timestamp": timestamp,
                    }
                    db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
                    logger.trace(f"保存图片: {file_path}")
                except Exception as e:
                    logger.error(f"保存图片文件失败: {str(e)}")

            # Save the description to the database
            self._save_description_to_db(image_hash, description, "image")

            return f"[图片:{description}]"
        except Exception as e:
            logger.error(f"获取图片描述失败: {str(e)}")
            return "[图片]"

    @staticmethod
    def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
        """Convert a GIF into a single horizontally concatenated static image, skipping similar frames.

        Args:
            gif_base64: Base64-encoded GIF data
            similarity_threshold: MSE threshold for frame similarity; a frame is kept only when its
                MSE against the last kept frame exceeds this value, so larger values skip more frames.
                Default 1000.0
            max_frames: Maximum number of frames to extract, default 15

        Returns:
            Optional[str]: Base64-encoded JPG of the concatenated image, or None on failure
        """
        try:
            # Decode the base64 data
            gif_data = base64.b64decode(gif_base64)
            gif = Image.open(io.BytesIO(gif_data))

            # Collect all frames
            all_frames = []
            try:
                while True:
                    gif.seek(len(all_frames))
                    # Convert to RGB so frames can be compared directly
                    frame = gif.convert("RGB")
                    all_frames.append(frame.copy())
            except EOFError:
                pass  # reached the last frame

            if not all_frames:
                logger.warning("GIF中没有找到任何帧")
                return None  # an empty GIF yields None

            # --- Frame selection logic ---
            selected_frames = []
            last_selected_frame_np = None

            for i, current_frame in enumerate(all_frames):
                current_frame_np = np.array(current_frame)

                # Always keep the first frame
                if i == 0:
                    selected_frames.append(current_frame)
                    last_selected_frame_np = current_frame_np
                    continue

                # Difference against the last selected frame (mean squared error, MSE)
                if last_selected_frame_np is not None:
                    # Cast to a signed dtype first: uint8 subtraction would wrap around
                    mse = np.mean(
                        (current_frame_np.astype(np.int32) - last_selected_frame_np.astype(np.int32)) ** 2
                    )
                    # logger.trace(f"帧 {i} 与上一选中帧的 MSE: {mse}")  # uncomment to inspect the values

                    # Keep the frame only if it differs enough
                    if mse > similarity_threshold:
                        selected_frames.append(current_frame)
                        last_selected_frame_np = current_frame_np
                        # Stop once enough frames have been selected
                        if len(selected_frames) >= max_frames:
                            # logger.debug(f"已选够 {max_frames} 帧,停止选择。")
                            break
                    # otherwise the frame is too similar and gets skipped

            # --- End of frame selection ---

            # If nothing was selected (or the GIF had no usable frames), also return None
            if not selected_frames:
                logger.warning("处理后没有选中任何帧")
                return None

            # logger.debug(f"总帧数: {len(all_frames)}, 选中帧数: {len(selected_frames)}")

            # Use the first selected frame's size (all frames are assumed to share it)
            frame_width, frame_height = selected_frames[0].size

            # Compute the target size while keeping the aspect ratio
            target_height = 200  # fixed height
            # Guard against division by zero
            if frame_height == 0:
                logger.error("帧高度为0,无法计算缩放尺寸")
                return None
            target_width = int((target_height / frame_height) * frame_width)
            # The width must not be zero either
            if target_width == 0:
                logger.warning(f"计算出的目标宽度为0 (原始尺寸 {frame_width}x{frame_height}),调整为1")
                target_width = 1

            # Resize every selected frame
            resized_frames = [
                frame.resize((target_width, target_height), Image.Resampling.LANCZOS) for frame in selected_frames
            ]

            # Create the concatenated image
            total_width = target_width * len(resized_frames)
            # Guard against a zero total width
            if total_width == 0 and len(resized_frames) > 0:
                logger.warning("计算出的总宽度为0,但有选中帧,可能目标宽度太小")
                # give it at least some width
                total_width = len(resized_frames)
            elif total_width == 0:
                logger.error("计算出的总宽度为0且无选中帧")
                return None

            combined_image = Image.new("RGB", (total_width, target_height))

            # Paste the frames side by side
            for idx, frame in enumerate(resized_frames):
                combined_image.paste(frame, (idx * target_width, 0))

            # Encode the result as base64
            buffer = io.BytesIO()
            combined_image.save(buffer, format="JPEG", quality=85)  # save as JPEG
            result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

            return result_base64

        except MemoryError:
            logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多")
            return None  # ran out of memory
        except Exception as e:
            logger.error(f"GIF转换失败: {str(e)}", exc_info=True)  # log full details
            return None  # any other error also yields None

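    # Usage sketch (illustrative only): transform_gif is a plain static method, so a
    # hypothetical caller could do
    #   strip_b64 = ImageManager.transform_gif(gif_b64, similarity_threshold=500.0, max_frames=10)
    # With identical RGB frames the MSE is 0, so only the first frame is kept; raising
    # similarity_threshold keeps fewer, more distinct frames in the strip.
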
# Module-level singleton
image_manager = ImageManager()


def image_path_to_base64(image_path: str) -> str:
    """Convert an image file path to its base64 encoding.

    Args:
        image_path: Path to the image file

    Returns:
        str: Base64-encoded image data

    Raises:
        FileNotFoundError: If the image file does not exist
        IOError: If reading the image file fails
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"图片文件不存在: {image_path}")

    with open(image_path, "rb") as f:
        image_data = f.read()
        if not image_data:
            raise IOError(f"读取图片文件失败: {image_path}")
        return base64.b64encode(image_data).decode("utf-8")
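

# Usage sketch (illustrative only; not part of the module). Assumes an async context,
# a reachable MongoDB behind `db`, a configured vision model, and a hypothetical file path:
#
#   import asyncio
#
#   async def _demo():
#       b64 = image_path_to_base64("data/image/example.jpg")
#       print(await image_manager.get_image_description(b64))  # "[图片:...]"
#
#   asyncio.run(_demo())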