Merge branch 'new-storage' into plugin

SengokuCola
2025-05-16 21:14:16 +08:00
63 changed files with 2397 additions and 2008 deletions

View File

@@ -190,8 +190,8 @@ async def _build_readable_messages_internal(
person_id = person_info_manager.get_person_id(platform, user_id)
# decide whether to replace the bot's name based on the replace_bot_name parameter
if replace_bot_name and user_id == global_config.BOT_QQ:
person_name = f"{global_config.BOT_NICKNAME}(你)"
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = await person_info_manager.get_value(person_id, "person_name")
@@ -427,7 +427,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
output_lines = []
def get_anon_name(platform, user_id):
if user_id == global_config.BOT_QQ:
if user_id == global_config.bot.qq_account:
return "SELF"
person_id = person_info_manager.get_person_id(platform, user_id)
if person_id not in person_map:
@@ -454,7 +454,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
def reply_replacer(match, platform=platform):
# aaa = match.group(1)
bbb = match.group(2)
anon_reply = get_anon_name(platform, bbb)
anon_reply = get_anon_name(platform, bbb) # noqa
return f"回复 {anon_reply}"
content = re.sub(reply_pattern, reply_replacer, content, count=1)
@@ -465,7 +465,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
def at_replacer(match, platform=platform):
# aaa = match.group(1)
bbb = match.group(2)
anon_at = get_anon_name(platform, bbb)
anon_at = get_anon_name(platform, bbb) # noqa
return f"@{anon_at}"
content = re.sub(at_pattern, at_replacer, content)
@@ -501,7 +501,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
user_id = user_info.get("user_id")
# check that the required info exists and that the sender is not the bot itself
if not all([platform, user_id]) or user_id == global_config.BOT_QQ:
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue
person_id = person_info_manager.get_person_id(platform, user_id)
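Throughout this branch, flat config attributes (global_config.BOT_QQ, global_config.BOT_NICKNAME, global_config.BOT_ALIAS_NAMES) become namespaced sections (global_config.bot.qq_account, global_config.bot.nickname, global_config.bot.alias_names). A minimal sketch of the shape this implies; the field names below mirror the accesses in the diff, but the real config classes in src/config/config.py may differ:

from dataclasses import dataclass, field
from typing import List

@dataclass
class BotSection:
    qq_account: str = ""        # replaces the flat BOT_QQ
    nickname: str = ""          # replaces the flat BOT_NICKNAME
    alias_names: List[str] = field(default_factory=list)  # replaces BOT_ALIAS_NAMES

@dataclass
class GlobalConfig:
    bot: BotSection = field(default_factory=BotSection)

global_config = GlobalConfig()
# old: global_config.BOT_QQ  ->  new: global_config.bot.qq_account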

View File

@@ -1,15 +1,15 @@
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv, MessageSending, Message
from src.common.database import db
from src.common.database.database_model import Messages, ThinkingLog
import time
import traceback
from typing import List
import json
class InfoCatcher:
def __init__(self):
self.chat_history = [] # chat history; three times the context length in use
self.context_length = global_config.observation_context_size
self.chat_history_in_thinking = [] # messages that arrive while the bot is thinking
self.chat_history_after_response = [] # messages after the reply; one context length
@@ -60,8 +60,6 @@ class InfoCatcher:
def catch_after_observe(self, obs_duration: float): # more information could be captured here
self.timing_results["sub_heartflow_observe_time"] = obs_duration
# def catch_shf
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
self.timing_results["sub_heartflow_step_time"] = step_duration
if len(past_mind) > 1:
@@ -72,25 +70,10 @@ class InfoCatcher:
self.heartflow_data["sub_heartflow_now"] = current_mind
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
# if self.response_mode == "heart_flow": # the conditional is no longer needed
# self.heartflow_data["prompt"] = prompt
# self.heartflow_data["response"] = response
# self.heartflow_data["model"] = model_name
# elif self.response_mode == "reasoning": # the conditional is no longer needed
# self.reasoning_data["thinking_log"] = reasoning_content
# self.reasoning_data["prompt"] = prompt
# self.reasoning_data["response"] = response
# self.reasoning_data["model"] = model_name
# record the info directly
self.reasoning_data["thinking_log"] = reasoning_content
self.reasoning_data["prompt"] = prompt
self.reasoning_data["response"] = response
self.reasoning_data["model"] = model_name
# if the heartflow data also needs these common fields, uncomment below
# self.heartflow_data["prompt"] = prompt
# self.heartflow_data["response"] = response
# self.heartflow_data["model"] = model_name
self.response_text = response
@@ -102,6 +85,7 @@ class InfoCatcher:
):
self.timing_results["make_response_time"] = response_duration
self.response_time = time.time()
self.response_messages = []
for msg in response_message:
self.response_messages.append(msg)
@@ -112,107 +96,112 @@ class InfoCatcher:
@staticmethod
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
try:
# get the message timestamps from the database
time_start = message_start.message_info.time
time_end = message_end.message_info.time
chat_id = message_start.chat_stream.stream_id
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
# query the database for rows with the same chat_id whose time lies between start and end
messages_between = db.messages.find(
{"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
).sort("time", -1)
messages_between_query = (
Messages.select()
.where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end))
.order_by(Messages.time.desc())
)
result = list(messages_between)
result = list(messages_between_query)
print(f"查询结果数量: {len(result)}")
if result:
print(f"第一条消息时间: {result[0]['time']}")
print(f"最后一条消息时间: {result[-1]['time']}")
print(f"第一条消息时间: {result[0].time}")
print(f"最后一条消息时间: {result[-1].time}")
return result
except Exception as e:
print(f"获取消息时出错: {str(e)}")
print(traceback.format_exc())
return []
def get_message_from_db_before_msg(self, message: MessageRecv):
# fetch messages from the database
message_id = message.message_info.message_id
chat_id = message.chat_stream.stream_id
message_id_val = message.message_info.message_id
chat_id_val = message.chat_stream.stream_id
# query the database for messages in the same chat with message_id smaller than the current one (context_length * 3 rows)
messages_before = (
db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
.sort("time", -1)
.limit(self.context_length * 3)
) # fetch more history
messages_before_query = (
Messages.select()
.where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
.order_by(Messages.time.desc())
.limit(global_config.chat.observation_context_size * 3)
)
return list(messages_before)
return list(messages_before_query)
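Both methods above wrap their Peewee queries in list(...). Peewee queries are lazy: select()/where()/order_by()/limit() only build SQL, and nothing touches the database until the query is iterated, so list() is what actually executes the SELECT and materializes model instances. A standalone illustration (the chat id is hypothetical):

from src.common.database.database_model import Messages  # the Peewee model imported above

query = (
    Messages.select()
    .where(Messages.chat_id == "example_chat")  # hypothetical chat id
    .order_by(Messages.time.desc())
    .limit(30)
)
print(query.sql())  # (sql_string, params) tuple; nothing has executed yet
rows = list(query)  # runs the SELECT and returns Messages instances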
def message_list_to_dict(self, message_list):
# store simplified chat records
result = []
for message in message_list:
if not isinstance(message, dict):
message = self.message_to_dict(message)
# print(message)
for msg_item in message_list:
processed_msg_item = msg_item
if not isinstance(msg_item, dict):
processed_msg_item = self.message_to_dict(msg_item)
if not processed_msg_item:
continue
lite_message = {
"time": message["time"],
"user_nickname": message["user_info"]["user_nickname"],
"processed_plain_text": message["processed_plain_text"],
"time": processed_msg_item.get("time"),
"user_nickname": processed_msg_item.get("user_nickname"),
"processed_plain_text": processed_msg_item.get("processed_plain_text"),
}
result.append(lite_message)
return result
@staticmethod
def message_to_dict(message):
if not message:
def message_to_dict(msg_obj):
if not msg_obj:
return None
if isinstance(message, dict):
return message
return {
# "message_id": message.message_info.message_id,
"time": message.message_info.time,
"user_id": message.message_info.user_info.user_id,
"user_nickname": message.message_info.user_info.user_nickname,
"processed_plain_text": message.processed_plain_text,
# "detailed_plain_text": message.detailed_plain_text
}
if isinstance(msg_obj, dict):
return msg_obj
def done_catch(self):
"""Store the collected info into the thinking_log collection in the database"""
try:
# convert the message objects into serializable dicts
thinking_log_data = {
"chat_id": self.chat_id,
"trigger_text": self.trigger_response_text,
"response_text": self.response_text,
"trigger_info": {
"time": self.trigger_response_time,
"message": self.message_to_dict(self.trigger_response_message),
},
"response_info": {
"time": self.response_time,
"message": self.response_messages,
},
"timing_results": self.timing_results,
"chat_history": self.message_list_to_dict(self.chat_history),
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
"heartflow_data": self.heartflow_data,
"reasoning_data": self.reasoning_data,
if isinstance(msg_obj, Messages):
return {
"time": msg_obj.time,
"user_id": msg_obj.user_id,
"user_nickname": msg_obj.user_nickname,
"processed_plain_text": msg_obj.processed_plain_text,
}
# add mode-specific data per response mode # now we simply attach all of it
# if self.response_mode == "heart_flow":
# thinking_log_data["mode_specific_data"] = self.heartflow_data
# elif self.response_mode == "reasoning":
# thinking_log_data["mode_specific_data"] = self.reasoning_data
if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"):
return {
"time": msg_obj.message_info.time,
"user_id": msg_obj.message_info.user_info.user_id,
"user_nickname": msg_obj.message_info.user_info.user_nickname,
"processed_plain_text": msg_obj.processed_plain_text,
}
# insert the data into the thinking_log collection
db.thinking_log.insert_one(thinking_log_data)
print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
return {}
def done_catch(self):
"""将收集到的信息存储到数据库的 thinking_log 表中喵~"""
try:
trigger_info_dict = self.message_to_dict(self.trigger_response_message)
response_info_dict = {
"time": self.response_time,
"message": self.response_messages,
}
chat_history_list = self.message_list_to_dict(self.chat_history)
chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking)
chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response)
log_entry = ThinkingLog(
chat_id=self.chat_id,
trigger_text=self.trigger_response_text,
response_text=self.response_text,
trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None,
response_info_json=json.dumps(response_info_dict),
timing_results_json=json.dumps(self.timing_results),
chat_history_json=json.dumps(chat_history_list),
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
heartflow_data_json=json.dumps(self.heartflow_data),
reasoning_data_json=json.dumps(self.reasoning_data),
)
log_entry.save()
return True
except Exception as e:
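done_catch() now serializes every nested structure into JSON text columns. A sketch of the ThinkingLog model those keyword arguments imply; the actual definition lives in src/common/database/database_model.py and may use different field types or constraints:

from peewee import Model, SqliteDatabase, TextField

db = SqliteDatabase("MaiBot.db")  # assumption: the shared Peewee database instance

class ThinkingLog(Model):
    chat_id = TextField()
    trigger_text = TextField(null=True)
    response_text = TextField(null=True)
    trigger_info_json = TextField(null=True)
    response_info_json = TextField(null=True)
    timing_results_json = TextField(null=True)
    chat_history_json = TextField(null=True)
    chat_history_in_thinking_json = TextField(null=True)
    chat_history_after_response_json = TextField(null=True)
    heartflow_data_json = TextField(null=True)
    reasoning_data_json = TextField(null=True)

    class Meta:
        database = db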

View File

@@ -2,10 +2,12 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
from src.common.logger import get_module_logger
from src.manager.async_task_manager import AsyncTask
from ...common.database import db
from ...common.database.database import db # This db is the Peewee database instance
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from src.manager.local_store_manager import local_storage
logger = get_module_logger("maibot_statistic")
@@ -39,7 +41,7 @@ class OnlineTimeRecordTask(AsyncTask):
def __init__(self):
super().__init__(task_name="Online Time Record Task", run_interval=60)
self.record_id: str | None = None
self.record_id: int | None = None # Changed to int for Peewee's default ID
"""记录ID"""
self._init_database() # 初始化数据库
@@ -47,49 +49,46 @@ class OnlineTimeRecordTask(AsyncTask):
@staticmethod
def _init_database():
"""初始化数据库"""
if "online_time" not in db.list_collection_names():
# 初始化数据库(在线时长)
db.create_collection("online_time")
# 创建索引
if ("end_timestamp", 1) not in db.online_time.list_indexes():
db.online_time.create_index([("end_timestamp", 1)])
with db.atomic(): # Use atomic operations for schema changes
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
async def run(self):
try:
current_time = datetime.now()
extended_end_time = current_time + timedelta(minutes=1)
if self.record_id:
# a record exists, so update its end time
db.online_time.update_one(
{"_id": self.record_id},
{
"$set": {
"end_timestamp": datetime.now() + timedelta(minutes=1),
}
},
)
else:
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
updated_rows = query.execute()
if updated_rows == 0:
# Record might have been deleted or ID is stale, try to find/create
self.record_id = None # Reset record_id to trigger find/create logic below
if not self.record_id: # Check again if record_id was reset or initially None
# no record yet; check whether one already exists within the last minute
current_time = datetime.now()
if recent_record := db.online_time.find_one(
{"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}}
):
# Look for a record whose end_timestamp is recent enough to be considered ongoing
recent_record = (
OnlineTime.select()
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
.order_by(OnlineTime.end_timestamp.desc())
.first()
)
if recent_record:
# a record exists, so update its end time
self.record_id = recent_record["_id"]
db.online_time.update_one(
{"_id": self.record_id},
{
"$set": {
"end_timestamp": current_time + timedelta(minutes=1),
}
},
)
self.record_id = recent_record.id
recent_record.end_timestamp = extended_end_time
recent_record.save()
else:
# otherwise insert a new online-time record
self.record_id = db.online_time.insert_one(
{
"start_timestamp": current_time,
"end_timestamp": current_time + timedelta(minutes=1),
}
).inserted_id
new_record = OnlineTime.create(
timestamp=current_time.timestamp(), # newly added field
start_timestamp=current_time,
end_timestamp=extended_end_time,
duration=5, # initial duration: 5 minutes
)
self.record_id = new_record.id
except Exception as e:
logger.error(f"在线时间记录失败,错误信息:{e}")
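The rewritten task touches five OnlineTime attributes (id, timestamp, start_timestamp, end_timestamp, duration). A sketch of a model matching those accesses; the real definition in database_model.py may differ:

from peewee import DateTimeField, DoubleField, IntegerField, Model, SqliteDatabase

db = SqliteDatabase("MaiBot.db")  # assumption

class OnlineTime(Model):
    timestamp = DoubleField()                  # epoch seconds at creation
    start_timestamp = DateTimeField()
    end_timestamp = DateTimeField(index=True)  # indexed, like the old Mongo index
    duration = IntegerField(default=5)

    class Meta:
        database = db  # the integer 'id' primary key is implicit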
@@ -201,35 +200,28 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: statistics periods
"""
if len(collect_period) <= 0:
if not collect_period:
return {}
else:
# sort by period start time, descending (latest period first)
collect_period.sort(key=lambda x: x[1], reverse=True)
# sort by period start time, descending (latest period first)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
# total LLM request count
TOTAL_REQ_CNT: 0,
# request count statistics
REQ_CNT_BY_TYPE: defaultdict(int),
REQ_CNT_BY_USER: defaultdict(int),
REQ_CNT_BY_MODEL: defaultdict(int),
# input token counts
IN_TOK_BY_TYPE: defaultdict(int),
IN_TOK_BY_USER: defaultdict(int),
IN_TOK_BY_MODEL: defaultdict(int),
# output token counts
OUT_TOK_BY_TYPE: defaultdict(int),
OUT_TOK_BY_USER: defaultdict(int),
OUT_TOK_BY_MODEL: defaultdict(int),
# total token counts
TOTAL_TOK_BY_TYPE: defaultdict(int),
TOTAL_TOK_BY_USER: defaultdict(int),
TOTAL_TOK_BY_MODEL: defaultdict(int),
# total cost
TOTAL_COST: 0.0,
# request cost statistics
COST_BY_TYPE: defaultdict(float),
COST_BY_USER: defaultdict(float),
COST_BY_MODEL: defaultdict(float),
@@ -238,26 +230,26 @@ class StatisticOutputTask(AsyncTask):
}
# fetch records starting from the earliest period's start time
for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}):
record_timestamp = record.get("timestamp")
# Assuming LLMUsage.timestamp is a DateTimeField
query_start_time = collect_period[-1][1]
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
record_timestamp = record.timestamp # This is already a datetime object
for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
# if the record falls inside the current period, it also falls inside every earlier-starting period,
# so we can skip their boundary checks and update this and all earlier-starting periods at once
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.get("request_type", "unknown") # request type
user_id = str(record.get("user_id", "unknown")) # user ID
model_name = record.get("model_name", "unknown") # model name
request_type = record.request_type or "unknown"
user_id = record.user_id or "unknown" # user_id is TextField, already string
model_name = record.model_name or "unknown"
stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
stats[period_key][REQ_CNT_BY_USER][user_id] += 1
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
prompt_tokens = record.get("prompt_tokens", 0) # input token count
completion_tokens = record.get("completion_tokens", 0) # output token count
total_tokens = prompt_tokens + completion_tokens # total tokens = input + output
prompt_tokens = record.prompt_tokens or 0
completion_tokens = record.completion_tokens or 0
total_tokens = prompt_tokens + completion_tokens
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
@@ -271,13 +263,12 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
cost = record.get("cost", 0.0)
cost = record.cost or 0.0
stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost
stats[period_key][COST_BY_MODEL][model_name] += cost
break # every earlier-starting period is already handled
break
return stats
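The early break relies on the sort order: collect_period is sorted by start time descending, so a record that falls inside the period at index idx automatically falls inside every period after it (their starts are earlier), and one inner pass updates them all. A self-contained toy run with hypothetical period keys:

from datetime import datetime, timedelta

now = datetime(2025, 5, 16, 21, 0)
collect_period = [
    ("1h", now - timedelta(hours=1)),
    ("24h", now - timedelta(hours=24)),
    ("7d", now - timedelta(days=7)),
]  # already sorted: latest start first

counts = {key: 0 for key, _ in collect_period}
record_time = now - timedelta(hours=3)  # inside 24h and 7d, outside 1h

for idx, (_, period_start) in enumerate(collect_period):
    if record_time >= period_start:
        for period_key, _ in collect_period[idx:]:
            counts[period_key] += 1
        break  # every earlier-starting period is already covered

print(counts)  # {'1h': 0, '24h': 1, '7d': 1}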
@staticmethod
@@ -287,39 +278,38 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: statistics periods
"""
if len(collect_period) <= 0:
if not collect_period:
return {}
else:
# sort by period start time, descending (latest period first)
collect_period.sort(key=lambda x: x[1], reverse=True)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
# online time statistics
ONLINE_TIME: 0.0,
}
for period_key, _ in collect_period
}
# tally online time
for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}):
end_timestamp: datetime = record.get("end_timestamp")
for idx, (_, period_start) in enumerate(collect_period):
if end_timestamp >= period_start:
# end_timestamp is marked ahead of real time, so if it is later than now, use now as the end time instead
end_timestamp = min(end_timestamp, now)
# if the record falls inside the current period, it also falls inside every earlier-starting period,
# so we can skip their boundary checks and update this and all earlier-starting periods at once
for period_key, _period_start in collect_period[idx:]:
start_timestamp: datetime = record.get("start_timestamp")
if start_timestamp < _period_start:
# the record starts before the period boundary, so count from the period start
stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds()
else:
# otherwise count from the record's start time
stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds()
break # every earlier-starting period is already handled
query_start_time = collect_period[-1][1]
# Assuming OnlineTime.end_timestamp is a DateTimeField
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
# record.end_timestamp and record.start_timestamp are datetime objects
record_end_timestamp = record.end_timestamp
record_start_timestamp = record.start_timestamp
for idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start:
# Calculate effective end time for this record in relation to 'now'
effective_end_time = min(record_end_timestamp, now)
for period_key, current_period_start_time in collect_period[idx:]:
# Determine the portion of the record that falls within this specific statistical period
overlap_start = max(record_start_timestamp, current_period_start_time)
overlap_end = effective_end_time # Already capped by 'now' and record's own end
if overlap_end > overlap_start:
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
break
return stats
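The rewritten accumulation is plain interval clipping: cap the record's end at now (end_timestamp is deliberately stamped one minute ahead), start from whichever is later of the record start and the period start, and count only positive overlap. The same arithmetic in isolation:

from datetime import datetime, timedelta

now = datetime(2025, 5, 16, 21, 0)
period_start = now - timedelta(hours=1)
record_start = now - timedelta(minutes=90)
record_end = now + timedelta(minutes=1)  # stamped ahead of real time

effective_end = min(record_end, now)
overlap_start = max(record_start, period_start)
if effective_end > overlap_start:
    print((effective_end - overlap_start).total_seconds())  # 3600.0: only the hour inside the period counts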
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
@@ -328,55 +318,57 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: statistics periods
"""
if len(collect_period) <= 0:
if not collect_period:
return {}
else:
# sort by period start time, descending (latest period first)
collect_period.sort(key=lambda x: x[1], reverse=True)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
# message statistics
TOTAL_MSG_CNT: 0,
MSG_CNT_BY_CHAT: defaultdict(int),
}
for period_key, _ in collect_period
}
# tally message counts
for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}):
chat_info = message.get("chat_info", None) # chat info
user_info = message.get("user_info", None) # user info (the message sender)
message_time = message.get("time", 0) # message time
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp):
message_time_ts = message.time # This is a float timestamp
group_info = chat_info.get("group_info") if chat_info else None # try to get group chat info
if group_info is not None:
# there is group info
chat_id = f"g{group_info.get('group_id')}"
chat_name = group_info.get("group_name", f"{group_info.get('group_id')}")
elif user_info:
# no group info, so fall back to the user info
chat_id = f"u{user_info['user_id']}"
chat_name = user_info["user_nickname"]
chat_id = None
chat_name = None
# Logic based on Peewee model structure, aiming to replicate original intent
if message.chat_info_group_id:
chat_id = f"g{message.chat_info_group_id}"
chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
# This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message.user_id}" # SENDER's user_id
chat_name = message.user_nickname # SENDER's nickname
else:
continue # skip if there is neither group info nor user info
# If neither group_id nor sender_id is available for chat identification
logger.warning(
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
)
continue
if not chat_id: # Should not happen if above logic is correct
continue
# Update name_mapping
if chat_id in self.name_mapping:
if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]:
# if the name differs and the new message is later than the recorded time, update the name
self.name_mapping[chat_id] = (chat_name, message_time)
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
self.name_mapping[chat_id] = (chat_name, message_time_ts)
else:
self.name_mapping[chat_id] = (chat_name, message_time)
self.name_mapping[chat_id] = (chat_name, message_time_ts)
for idx, (_, period_start) in enumerate(collect_period):
if message_time >= period_start.timestamp():
# if the record falls inside the current period, it also falls inside every earlier-starting period,
# so we can skip their boundary checks and update this and all earlier-starting periods at once
for idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
break
return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:

View File

@@ -13,7 +13,7 @@ from src.manager.mood_manager import mood_manager
from ..message_receive.message import MessageRecv
from ..models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator
from ...common.database import db
from ...common.database.database import db
from ...config.config import global_config
logger = get_module_logger("chat_utils")
@@ -43,8 +43,8 @@ def db_message_to_str(message_dict: dict) -> str:
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
nicknames = global_config.BOT_ALIAS_NAMES
keywords = [global_config.bot.nickname]
nicknames = global_config.bot.alias_names
reply_probability = 0.0
is_at = False
is_mentioned = False
@@ -64,18 +64,18 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
)
# check whether the bot was @-mentioned
if re.search(f"@[\s\S]*?id:{global_config.BOT_QQ}", message.processed_plain_text):
if re.search(f"@[\s\S]*?id:{global_config.bot.qq_account}", message.processed_plain_text):
is_at = True
is_mentioned = True
if is_at and global_config.at_bot_inevitable_reply:
if is_at and global_config.normal_chat.at_bot_inevitable_reply:
reply_probability = 1.0
logger.info("被@回复概率设置为100%")
else:
if not is_mentioned:
# check whether the message is a reply to the bot
if re.match(
f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\)[\s\S]*?],说:", message.processed_plain_text
f"\[回复 [\s\S]*?\({str(global_config.bot.qq_account)}\)[\s\S]*?],说:", message.processed_plain_text
):
is_mentioned = True
else:
@@ -88,7 +88,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
for nickname in nicknames:
if nickname in message_content:
is_mentioned = True
if is_mentioned and global_config.mentioned_bot_inevitable_reply:
if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply:
reply_probability = 1.0
logger.info("被提及回复概率设置为100%")
return is_mentioned, reply_probability
@@ -96,7 +96,8 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
async def get_embedding(text, request_type="embedding"):
"""获取文本的embedding向量"""
llm = LLMRequest(model=global_config.embedding, request_type=request_type)
# TODO: API-Adapter change marker
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
# return llm.get_embedding_sync(text)
try:
embedding = await llm.get_embedding(text)
@@ -163,7 +164,7 @@ def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> li
user_info = UserInfo.from_dict(msg_db_data["user_info"])
if (
(user_info.platform, user_info.user_id) != sender
and user_info.user_id != global_config.BOT_QQ
and user_info.user_id != global_config.bot.qq_account
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
and len(who_chat_in_group) < 5
): # exclude duplicates, the message sender, and the bot; cap the number of relations loaded
@@ -321,7 +322,7 @@ def random_remove_punctuation(text: str) -> str:
def process_llm_response(text: str) -> list[str]:
# protect kaomoji first
if global_config.enable_kaomoji_protection:
if global_config.response_splitter.enable_kaomoji_protection:
protected_text, kaomoji_mapping = protect_kaomoji(text)
logger.trace(f"保护颜文字后的文本: {protected_text}")
else:
@@ -340,8 +341,8 @@ def process_llm_response(text: str) -> list[str]:
logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
# further process the cleaned text
max_length = global_config.response_max_length * 2
max_sentence_num = global_config.response_max_sentence_num
max_length = global_config.response_splitter.max_length * 2
max_sentence_num = global_config.response_splitter.max_sentence_num
# if the text is mostly Chinese, apply the length filter
if get_western_ratio(cleaned_text) < 0.1:
if len(cleaned_text) > max_length:
@@ -349,20 +350,20 @@ def process_llm_response(text: str) -> list[str]:
return ["懒得说"]
typo_generator = ChineseTypoGenerator(
error_rate=global_config.chinese_typo_error_rate,
min_freq=global_config.chinese_typo_min_freq,
tone_error_rate=global_config.chinese_typo_tone_error_rate,
word_replace_rate=global_config.chinese_typo_word_replace_rate,
error_rate=global_config.chinese_typo.error_rate,
min_freq=global_config.chinese_typo.min_freq,
tone_error_rate=global_config.chinese_typo.tone_error_rate,
word_replace_rate=global_config.chinese_typo.word_replace_rate,
)
if global_config.enable_response_splitter:
if global_config.response_splitter.enable:
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
else:
split_sentences = [cleaned_text]
sentences = []
for sentence in split_sentences:
if global_config.chinese_typo_enable:
if global_config.chinese_typo.enable:
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
sentences.append(typoed_text)
if typo_corrections:
@@ -372,7 +373,7 @@ def process_llm_response(text: str) -> list[str]:
if len(sentences) > max_sentence_num:
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f"{global_config.BOT_NICKNAME}不知道哦"]
return [f"{global_config.bot.nickname}不知道哦"]
# if extracted_contents:
# for content in extracted_contents:

View File

@@ -8,7 +8,8 @@ import io
import numpy as np
from ...common.database import db
from ...common.database.database import db
from ...common.database.database_model import Images, ImageDescriptions
from ...config.config import global_config
from ..models.utils_model import LLMRequest
@@ -32,40 +33,23 @@ class ImageManager:
def __init__(self):
if not self._initialized:
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir()
self._initialized = True
self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions], safe=True)
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
self._initialized = True
self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
def _ensure_image_dir(self):
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
@staticmethod
def _ensure_image_collection():
"""确保images集合存在并创建索引"""
if "images" not in db.list_collection_names():
db.create_collection("images")
# 删除旧索引
db.images.drop_indexes()
# 创建新的复合索引
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)])
@staticmethod
def _ensure_description_collection():
"""确保image_descriptions集合存在并创建索引"""
if "image_descriptions" not in db.list_collection_names():
db.create_collection("image_descriptions")
# 删除旧索引
db.image_descriptions.drop_indexes()
# 创建新的复合索引
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
@staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
@@ -77,8 +61,14 @@ class ImageManager:
Returns:
Optional[str]: the description text, or None if it does not exist
"""
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
return result["description"] if result else None
try:
record = ImageDescriptions.get_or_none(
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
return None
@staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
@@ -90,20 +80,17 @@ class ImageManager:
description_type: the description type ('emoji' or 'image')
"""
try:
db.image_descriptions.update_one(
{"hash": image_hash, "type": description_type},
{
"$set": {
"description": description,
"timestamp": int(time.time()),
"hash": image_hash, # 确保hash字段存在
"type": description_type, # 确保type字段存在
}
},
upsert=True,
current_timestamp = time.time()
defaults = {"description": description, "timestamp": current_timestamp}
desc_obj, created = ImageDescriptions.get_or_create(
image_description_hash=image_hash, type=description_type, defaults=defaults
)
if not created: # the record already exists; update it
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
except Exception as e:
logger.error(f"保存描述到数据库失败: {str(e)}")
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能"""
@@ -116,51 +103,64 @@ class ImageManager:
# look up a cached description
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
# logger.debug(f"缓存表情包描述: {cached_description}")
return f"[表情包,含义看起来是:{cached_description}]"
# call the AI to get a description
if image_format == "gif" or image_format == "GIF":
image_base64 = self.transform_gif(image_base64)
image_base64_processed = self.transform_gif(image_base64)
if image_base64_processed is None:
logger.warning("GIF转换失败无法获取描述")
return "[表情包(GIF处理失败)]"
prompt = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明使用1-2个词描述一下表情包表达的情感和内容简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg")
description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
else:
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成表情包描述")
return "[表情包(描述生成失败)]"
# re-check the cache to avoid duplicate generation from concurrent writes
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
return f"[表情包,含义看起来是:{cached_description}]"
# decide whether to save the image based on config
if global_config.save_emoji:
if global_config.emoji.save_emoji:
# build the filename and path
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")):
os.makedirs(os.path.join(self.IMAGE_DIR, "emoji"))
file_path = os.path.join(self.IMAGE_DIR, "emoji", filename)
current_timestamp = time.time()
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
os.makedirs(emoji_dir, exist_ok=True)
file_path = os.path.join(emoji_dir, filename)
try:
# write the file
with open(file_path, "wb") as f:
f.write(image_bytes)
# save to the database
image_doc = {
"hash": image_hash,
"path": file_path,
"type": "emoji",
"description": description,
"timestamp": timestamp,
}
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
logger.trace(f"保存表情包: {file_path}")
# save to the database (Images table)
try:
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
img_obj.path = file_path
img_obj.description = description
img_obj.timestamp = current_timestamp
img_obj.save()
except Images.DoesNotExist:
Images.create(
emoji_hash=image_hash,
path=file_path,
type="emoji",
description=description,
timestamp=current_timestamp,
)
logger.trace(f"保存表情包元数据: {file_path}")
except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}")
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
# save the description to the database
# save the description to the database (ImageDescriptions table)
self._save_description_to_db(image_hash, description, "emoji")
return f"[表情包:{description}]"
@@ -188,6 +188,11 @@ class ImageManager:
)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片(描述生成失败)]"
# re-check the cache
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
@@ -195,38 +200,40 @@ class ImageManager:
logger.debug(f"描述是{description}")
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片]"
# decide whether to save the image based on config
if global_config.save_pic:
if global_config.emoji.save_pic:
# build the filename and path
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")):
os.makedirs(os.path.join(self.IMAGE_DIR, "image"))
file_path = os.path.join(self.IMAGE_DIR, "image", filename)
current_timestamp = time.time()
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
image_dir = os.path.join(self.IMAGE_DIR, "image")
os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(image_dir, filename)
try:
# write the file
with open(file_path, "wb") as f:
f.write(image_bytes)
# save to the database
image_doc = {
"hash": image_hash,
"path": file_path,
"type": "image",
"description": description,
"timestamp": timestamp,
}
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
logger.trace(f"保存图片: {file_path}")
# save to the database (Images table)
try:
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "image"))
img_obj.path = file_path
img_obj.description = description
img_obj.timestamp = current_timestamp
img_obj.save()
except Images.DoesNotExist:
Images.create(
emoji_hash=image_hash,
path=file_path,
type="image",
description=description,
timestamp=current_timestamp,
)
logger.trace(f"保存图片元数据: {file_path}")
except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}")
logger.error(f"保存图片文件或元数据失败: {str(e)}")
# save the description to the database
# save the description to the database (ImageDescriptions table)
self._save_description_to_db(image_hash, description, "image")
return f"[图片:{description}]"
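Every Mongo update_one(..., upsert=True) in this file becomes the same Peewee idiom: get_or_create with the lookup keys plus defaults, then mutate and save when the row already existed. A compact, runnable sketch on a hypothetical stand-in model:

import time
from peewee import DoubleField, Model, SqliteDatabase, TextField

db = SqliteDatabase(":memory:")

class Description(Model):  # hypothetical stand-in for ImageDescriptions
    content_hash = TextField()
    type = TextField()
    description = TextField(null=True)
    timestamp = DoubleField(null=True)

    class Meta:
        database = db

db.connect()
db.create_tables([Description], safe=True)

def upsert_description(content_hash: str, description: str) -> None:
    obj, created = Description.get_or_create(
        content_hash=content_hash,
        type="emoji",
        defaults={"description": description, "timestamp": time.time()},
    )
    if not created:  # the row already existed: update it in place
        obj.description = description
        obj.timestamp = time.time()
        obj.save()

upsert_description("abc123", "smiling cat")   # inserts a new row
upsert_description("abc123", "grinning cat")  # updates the existing row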