Merge branch 'new-storage' into plugin
This commit is contained in:
@@ -190,8 +190,8 @@ async def _build_readable_messages_internal(
|
||||
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
||||
if replace_bot_name and user_id == global_config.BOT_QQ:
|
||||
person_name = f"{global_config.BOT_NICKNAME}(你)"
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
@@ -427,7 +427,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
output_lines = []
|
||||
|
||||
def get_anon_name(platform, user_id):
|
||||
if user_id == global_config.BOT_QQ:
|
||||
if user_id == global_config.bot.qq_account:
|
||||
return "SELF"
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
if person_id not in person_map:
|
||||
@@ -454,7 +454,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
def reply_replacer(match, platform=platform):
|
||||
# aaa = match.group(1)
|
||||
bbb = match.group(2)
|
||||
anon_reply = get_anon_name(platform, bbb)
|
||||
anon_reply = get_anon_name(platform, bbb) # noqa
|
||||
return f"回复 {anon_reply}"
|
||||
|
||||
content = re.sub(reply_pattern, reply_replacer, content, count=1)
|
||||
@@ -465,7 +465,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
def at_replacer(match, platform=platform):
|
||||
# aaa = match.group(1)
|
||||
bbb = match.group(2)
|
||||
anon_at = get_anon_name(platform, bbb)
|
||||
anon_at = get_anon_name(platform, bbb) # noqa
|
||||
return f"@{anon_at}"
|
||||
|
||||
content = re.sub(at_pattern, at_replacer, content)
|
||||
@@ -501,7 +501,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
user_id = user_info.get("user_id")
|
||||
|
||||
# 检查必要信息是否存在 且 不是机器人自己
|
||||
if not all([platform, user_id]) or user_id == global_config.BOT_QQ:
|
||||
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
||||
continue
|
||||
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending, Message
|
||||
from src.common.database import db
|
||||
from src.common.database.database_model import Messages, ThinkingLog
|
||||
import time
|
||||
import traceback
|
||||
from typing import List
|
||||
import json
|
||||
|
||||
|
||||
class InfoCatcher:
|
||||
def __init__(self):
|
||||
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文喵~
|
||||
self.context_length = global_config.observation_context_size
|
||||
self.chat_history_in_thinking = [] # 思考期间的聊天内容喵~
|
||||
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文喵~
|
||||
|
||||
@@ -60,8 +60,6 @@ class InfoCatcher:
|
||||
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
|
||||
self.timing_results["sub_heartflow_observe_time"] = obs_duration
|
||||
|
||||
# def catch_shf
|
||||
|
||||
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
|
||||
self.timing_results["sub_heartflow_step_time"] = step_duration
|
||||
if len(past_mind) > 1:
|
||||
@@ -72,25 +70,10 @@ class InfoCatcher:
|
||||
self.heartflow_data["sub_heartflow_now"] = current_mind
|
||||
|
||||
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
|
||||
# if self.response_mode == "heart_flow": # 条件判断不需要了喵~
|
||||
# self.heartflow_data["prompt"] = prompt
|
||||
# self.heartflow_data["response"] = response
|
||||
# self.heartflow_data["model"] = model_name
|
||||
# elif self.response_mode == "reasoning": # 条件判断不需要了喵~
|
||||
# self.reasoning_data["thinking_log"] = reasoning_content
|
||||
# self.reasoning_data["prompt"] = prompt
|
||||
# self.reasoning_data["response"] = response
|
||||
# self.reasoning_data["model"] = model_name
|
||||
|
||||
# 直接记录信息喵~
|
||||
self.reasoning_data["thinking_log"] = reasoning_content
|
||||
self.reasoning_data["prompt"] = prompt
|
||||
self.reasoning_data["response"] = response
|
||||
self.reasoning_data["model"] = model_name
|
||||
# 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~
|
||||
# self.heartflow_data["prompt"] = prompt
|
||||
# self.heartflow_data["response"] = response
|
||||
# self.heartflow_data["model"] = model_name
|
||||
|
||||
self.response_text = response
|
||||
|
||||
@@ -102,6 +85,7 @@ class InfoCatcher:
|
||||
):
|
||||
self.timing_results["make_response_time"] = response_duration
|
||||
self.response_time = time.time()
|
||||
self.response_messages = []
|
||||
for msg in response_message:
|
||||
self.response_messages.append(msg)
|
||||
|
||||
@@ -112,107 +96,112 @@ class InfoCatcher:
|
||||
@staticmethod
|
||||
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
|
||||
try:
|
||||
# 从数据库中获取消息的时间戳
|
||||
time_start = message_start.message_info.time
|
||||
time_end = message_end.message_info.time
|
||||
chat_id = message_start.chat_stream.stream_id
|
||||
|
||||
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
|
||||
|
||||
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据
|
||||
messages_between = db.messages.find(
|
||||
{"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
|
||||
).sort("time", -1)
|
||||
messages_between_query = (
|
||||
Messages.select()
|
||||
.where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end))
|
||||
.order_by(Messages.time.desc())
|
||||
)
|
||||
|
||||
result = list(messages_between)
|
||||
result = list(messages_between_query)
|
||||
print(f"查询结果数量: {len(result)}")
|
||||
if result:
|
||||
print(f"第一条消息时间: {result[0]['time']}")
|
||||
print(f"最后一条消息时间: {result[-1]['time']}")
|
||||
print(f"第一条消息时间: {result[0].time}")
|
||||
print(f"最后一条消息时间: {result[-1].time}")
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"获取消息时出错: {str(e)}")
|
||||
print(traceback.format_exc())
|
||||
return []
|
||||
|
||||
def get_message_from_db_before_msg(self, message: MessageRecv):
|
||||
# 从数据库中获取消息
|
||||
message_id = message.message_info.message_id
|
||||
chat_id = message.chat_stream.stream_id
|
||||
message_id_val = message.message_info.message_id
|
||||
chat_id_val = message.chat_stream.stream_id
|
||||
|
||||
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据
|
||||
messages_before = (
|
||||
db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
|
||||
.sort("time", -1)
|
||||
.limit(self.context_length * 3)
|
||||
) # 获取更多历史信息
|
||||
messages_before_query = (
|
||||
Messages.select()
|
||||
.where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
|
||||
.order_by(Messages.time.desc())
|
||||
.limit(global_config.chat.observation_context_size * 3)
|
||||
)
|
||||
|
||||
return list(messages_before)
|
||||
return list(messages_before_query)
|
||||
|
||||
def message_list_to_dict(self, message_list):
|
||||
# 存储简化的聊天记录
|
||||
result = []
|
||||
for message in message_list:
|
||||
if not isinstance(message, dict):
|
||||
message = self.message_to_dict(message)
|
||||
# print(message)
|
||||
for msg_item in message_list:
|
||||
processed_msg_item = msg_item
|
||||
if not isinstance(msg_item, dict):
|
||||
processed_msg_item = self.message_to_dict(msg_item)
|
||||
|
||||
if not processed_msg_item:
|
||||
continue
|
||||
|
||||
lite_message = {
|
||||
"time": message["time"],
|
||||
"user_nickname": message["user_info"]["user_nickname"],
|
||||
"processed_plain_text": message["processed_plain_text"],
|
||||
"time": processed_msg_item.get("time"),
|
||||
"user_nickname": processed_msg_item.get("user_nickname"),
|
||||
"processed_plain_text": processed_msg_item.get("processed_plain_text"),
|
||||
}
|
||||
result.append(lite_message)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def message_to_dict(message):
|
||||
if not message:
|
||||
def message_to_dict(msg_obj):
|
||||
if not msg_obj:
|
||||
return None
|
||||
if isinstance(message, dict):
|
||||
return message
|
||||
return {
|
||||
# "message_id": message.message_info.message_id,
|
||||
"time": message.message_info.time,
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"user_nickname": message.message_info.user_info.user_nickname,
|
||||
"processed_plain_text": message.processed_plain_text,
|
||||
# "detailed_plain_text": message.detailed_plain_text
|
||||
}
|
||||
if isinstance(msg_obj, dict):
|
||||
return msg_obj
|
||||
|
||||
def done_catch(self):
|
||||
"""将收集到的信息存储到数据库的 thinking_log 集合中喵~"""
|
||||
try:
|
||||
# 将消息对象转换为可序列化的字典喵~
|
||||
|
||||
thinking_log_data = {
|
||||
"chat_id": self.chat_id,
|
||||
"trigger_text": self.trigger_response_text,
|
||||
"response_text": self.response_text,
|
||||
"trigger_info": {
|
||||
"time": self.trigger_response_time,
|
||||
"message": self.message_to_dict(self.trigger_response_message),
|
||||
},
|
||||
"response_info": {
|
||||
"time": self.response_time,
|
||||
"message": self.response_messages,
|
||||
},
|
||||
"timing_results": self.timing_results,
|
||||
"chat_history": self.message_list_to_dict(self.chat_history),
|
||||
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
|
||||
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
|
||||
"heartflow_data": self.heartflow_data,
|
||||
"reasoning_data": self.reasoning_data,
|
||||
if isinstance(msg_obj, Messages):
|
||||
return {
|
||||
"time": msg_obj.time,
|
||||
"user_id": msg_obj.user_id,
|
||||
"user_nickname": msg_obj.user_nickname,
|
||||
"processed_plain_text": msg_obj.processed_plain_text,
|
||||
}
|
||||
|
||||
# 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~
|
||||
# if self.response_mode == "heart_flow":
|
||||
# thinking_log_data["mode_specific_data"] = self.heartflow_data
|
||||
# elif self.response_mode == "reasoning":
|
||||
# thinking_log_data["mode_specific_data"] = self.reasoning_data
|
||||
if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"):
|
||||
return {
|
||||
"time": msg_obj.message_info.time,
|
||||
"user_id": msg_obj.message_info.user_info.user_id,
|
||||
"user_nickname": msg_obj.message_info.user_info.user_nickname,
|
||||
"processed_plain_text": msg_obj.processed_plain_text,
|
||||
}
|
||||
|
||||
# 将数据插入到 thinking_log 集合中喵~
|
||||
db.thinking_log.insert_one(thinking_log_data)
|
||||
print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
|
||||
return {}
|
||||
|
||||
def done_catch(self):
|
||||
"""将收集到的信息存储到数据库的 thinking_log 表中喵~"""
|
||||
try:
|
||||
trigger_info_dict = self.message_to_dict(self.trigger_response_message)
|
||||
response_info_dict = {
|
||||
"time": self.response_time,
|
||||
"message": self.response_messages,
|
||||
}
|
||||
chat_history_list = self.message_list_to_dict(self.chat_history)
|
||||
chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking)
|
||||
chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response)
|
||||
|
||||
log_entry = ThinkingLog(
|
||||
chat_id=self.chat_id,
|
||||
trigger_text=self.trigger_response_text,
|
||||
response_text=self.response_text,
|
||||
trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None,
|
||||
response_info_json=json.dumps(response_info_dict),
|
||||
timing_results_json=json.dumps(self.timing_results),
|
||||
chat_history_json=json.dumps(chat_history_list),
|
||||
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
|
||||
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
|
||||
heartflow_data_json=json.dumps(self.heartflow_data),
|
||||
reasoning_data_json=json.dumps(self.reasoning_data),
|
||||
)
|
||||
log_entry.save()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,10 +2,12 @@ from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
|
||||
|
||||
from src.common.logger import get_module_logger
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
|
||||
from ...common.database import db
|
||||
from ...common.database.database import db # This db is the Peewee database instance
|
||||
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_module_logger("maibot_statistic")
|
||||
@@ -39,7 +41,7 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
def __init__(self):
|
||||
super().__init__(task_name="Online Time Record Task", run_interval=60)
|
||||
|
||||
self.record_id: str | None = None
|
||||
self.record_id: int | None = None # Changed to int for Peewee's default ID
|
||||
"""记录ID"""
|
||||
|
||||
self._init_database() # 初始化数据库
|
||||
@@ -47,49 +49,46 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
"""初始化数据库"""
|
||||
if "online_time" not in db.list_collection_names():
|
||||
# 初始化数据库(在线时长)
|
||||
db.create_collection("online_time")
|
||||
# 创建索引
|
||||
if ("end_timestamp", 1) not in db.online_time.list_indexes():
|
||||
db.online_time.create_index([("end_timestamp", 1)])
|
||||
with db.atomic(): # Use atomic operations for schema changes
|
||||
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
extended_end_time = current_time + timedelta(minutes=1)
|
||||
|
||||
if self.record_id:
|
||||
# 如果有记录,则更新结束时间
|
||||
db.online_time.update_one(
|
||||
{"_id": self.record_id},
|
||||
{
|
||||
"$set": {
|
||||
"end_timestamp": datetime.now() + timedelta(minutes=1),
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
|
||||
updated_rows = query.execute()
|
||||
if updated_rows == 0:
|
||||
# Record might have been deleted or ID is stale, try to find/create
|
||||
self.record_id = None # Reset record_id to trigger find/create logic below
|
||||
|
||||
if not self.record_id: # Check again if record_id was reset or initially None
|
||||
# 如果没有记录,检查一分钟以内是否已有记录
|
||||
current_time = datetime.now()
|
||||
if recent_record := db.online_time.find_one(
|
||||
{"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}}
|
||||
):
|
||||
# Look for a record whose end_timestamp is recent enough to be considered ongoing
|
||||
recent_record = (
|
||||
OnlineTime.select()
|
||||
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
|
||||
.order_by(OnlineTime.end_timestamp.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if recent_record:
|
||||
# 如果有记录,则更新结束时间
|
||||
self.record_id = recent_record["_id"]
|
||||
db.online_time.update_one(
|
||||
{"_id": self.record_id},
|
||||
{
|
||||
"$set": {
|
||||
"end_timestamp": current_time + timedelta(minutes=1),
|
||||
}
|
||||
},
|
||||
)
|
||||
self.record_id = recent_record.id
|
||||
recent_record.end_timestamp = extended_end_time
|
||||
recent_record.save()
|
||||
else:
|
||||
# 若没有记录,则插入新的在线时间记录
|
||||
self.record_id = db.online_time.insert_one(
|
||||
{
|
||||
"start_timestamp": current_time,
|
||||
"end_timestamp": current_time + timedelta(minutes=1),
|
||||
}
|
||||
).inserted_id
|
||||
new_record = OnlineTime.create(
|
||||
timestamp=current_time.timestamp(), # 添加此行
|
||||
start_timestamp=current_time,
|
||||
end_timestamp=extended_end_time,
|
||||
duration=5, # 初始时长为5分钟
|
||||
)
|
||||
self.record_id = new_record.id
|
||||
except Exception as e:
|
||||
logger.error(f"在线时间记录失败,错误信息:{e}")
|
||||
|
||||
@@ -201,35 +200,28 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
:param collect_period: 统计时间段
|
||||
"""
|
||||
if len(collect_period) <= 0:
|
||||
if not collect_period:
|
||||
return {}
|
||||
else:
|
||||
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
|
||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
|
||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
stats = {
|
||||
period_key: {
|
||||
# 总LLM请求数
|
||||
TOTAL_REQ_CNT: 0,
|
||||
# 请求次数统计
|
||||
REQ_CNT_BY_TYPE: defaultdict(int),
|
||||
REQ_CNT_BY_USER: defaultdict(int),
|
||||
REQ_CNT_BY_MODEL: defaultdict(int),
|
||||
# 输入Token数
|
||||
IN_TOK_BY_TYPE: defaultdict(int),
|
||||
IN_TOK_BY_USER: defaultdict(int),
|
||||
IN_TOK_BY_MODEL: defaultdict(int),
|
||||
# 输出Token数
|
||||
OUT_TOK_BY_TYPE: defaultdict(int),
|
||||
OUT_TOK_BY_USER: defaultdict(int),
|
||||
OUT_TOK_BY_MODEL: defaultdict(int),
|
||||
# 总Token数
|
||||
TOTAL_TOK_BY_TYPE: defaultdict(int),
|
||||
TOTAL_TOK_BY_USER: defaultdict(int),
|
||||
TOTAL_TOK_BY_MODEL: defaultdict(int),
|
||||
# 总开销
|
||||
TOTAL_COST: 0.0,
|
||||
# 请求开销统计
|
||||
COST_BY_TYPE: defaultdict(float),
|
||||
COST_BY_USER: defaultdict(float),
|
||||
COST_BY_MODEL: defaultdict(float),
|
||||
@@ -238,26 +230,26 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
# 以最早的时间戳为起始时间获取记录
|
||||
for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}):
|
||||
record_timestamp = record.get("timestamp")
|
||||
# Assuming LLMUsage.timestamp is a DateTimeField
|
||||
query_start_time = collect_period[-1][1]
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
|
||||
record_timestamp = record.timestamp # This is already a datetime object
|
||||
for idx, (_, period_start) in enumerate(collect_period):
|
||||
if record_timestamp >= period_start:
|
||||
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
|
||||
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_REQ_CNT] += 1
|
||||
|
||||
request_type = record.get("request_type", "unknown") # 请求类型
|
||||
user_id = str(record.get("user_id", "unknown")) # 用户ID
|
||||
model_name = record.get("model_name", "unknown") # 模型名称
|
||||
request_type = record.request_type or "unknown"
|
||||
user_id = record.user_id or "unknown" # user_id is TextField, already string
|
||||
model_name = record.model_name or "unknown"
|
||||
|
||||
stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
|
||||
stats[period_key][REQ_CNT_BY_USER][user_id] += 1
|
||||
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
|
||||
|
||||
prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数
|
||||
completion_tokens = record.get("completion_tokens", 0) # 输出Token数
|
||||
total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数
|
||||
prompt_tokens = record.prompt_tokens or 0
|
||||
completion_tokens = record.completion_tokens or 0
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
|
||||
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
|
||||
@@ -271,13 +263,12 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
|
||||
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
|
||||
|
||||
cost = record.get("cost", 0.0)
|
||||
cost = record.cost or 0.0
|
||||
stats[period_key][TOTAL_COST] += cost
|
||||
stats[period_key][COST_BY_TYPE][request_type] += cost
|
||||
stats[period_key][COST_BY_USER][user_id] += cost
|
||||
stats[period_key][COST_BY_MODEL][model_name] += cost
|
||||
break # 取消更早时间段的判断
|
||||
|
||||
break
|
||||
return stats
|
||||
|
||||
@staticmethod
|
||||
@@ -287,39 +278,38 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
:param collect_period: 统计时间段
|
||||
"""
|
||||
if len(collect_period) <= 0:
|
||||
if not collect_period:
|
||||
return {}
|
||||
else:
|
||||
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
|
||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
stats = {
|
||||
period_key: {
|
||||
# 在线时间统计
|
||||
ONLINE_TIME: 0.0,
|
||||
}
|
||||
for period_key, _ in collect_period
|
||||
}
|
||||
|
||||
# 统计在线时间
|
||||
for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}):
|
||||
end_timestamp: datetime = record.get("end_timestamp")
|
||||
for idx, (_, period_start) in enumerate(collect_period):
|
||||
if end_timestamp >= period_start:
|
||||
# 由于end_timestamp会超前标记时间,所以我们需要判断是否晚于当前时间,如果是,则使用当前时间作为结束时间
|
||||
end_timestamp = min(end_timestamp, now)
|
||||
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
|
||||
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
|
||||
for period_key, _period_start in collect_period[idx:]:
|
||||
start_timestamp: datetime = record.get("start_timestamp")
|
||||
if start_timestamp < _period_start:
|
||||
# 如果开始时间在查询边界之前,则使用开始时间
|
||||
stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds()
|
||||
else:
|
||||
# 否则,使用开始时间
|
||||
stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds()
|
||||
break # 取消更早时间段的判断
|
||||
query_start_time = collect_period[-1][1]
|
||||
# Assuming OnlineTime.end_timestamp is a DateTimeField
|
||||
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
|
||||
# record.end_timestamp and record.start_timestamp are datetime objects
|
||||
record_end_timestamp = record.end_timestamp
|
||||
record_start_timestamp = record.start_timestamp
|
||||
|
||||
for idx, (_, period_boundary_start) in enumerate(collect_period):
|
||||
if record_end_timestamp >= period_boundary_start:
|
||||
# Calculate effective end time for this record in relation to 'now'
|
||||
effective_end_time = min(record_end_timestamp, now)
|
||||
|
||||
for period_key, current_period_start_time in collect_period[idx:]:
|
||||
# Determine the portion of the record that falls within this specific statistical period
|
||||
overlap_start = max(record_start_timestamp, current_period_start_time)
|
||||
overlap_end = effective_end_time # Already capped by 'now' and record's own end
|
||||
|
||||
if overlap_end > overlap_start:
|
||||
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
|
||||
break
|
||||
return stats
|
||||
|
||||
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
@@ -328,55 +318,57 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
:param collect_period: 统计时间段
|
||||
"""
|
||||
if len(collect_period) <= 0:
|
||||
if not collect_period:
|
||||
return {}
|
||||
else:
|
||||
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
|
||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
stats = {
|
||||
period_key: {
|
||||
# 消息统计
|
||||
TOTAL_MSG_CNT: 0,
|
||||
MSG_CNT_BY_CHAT: defaultdict(int),
|
||||
}
|
||||
for period_key, _ in collect_period
|
||||
}
|
||||
|
||||
# 统计消息量
|
||||
for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}):
|
||||
chat_info = message.get("chat_info", None) # 聊天信息
|
||||
user_info = message.get("user_info", None) # 用户信息(消息发送人)
|
||||
message_time = message.get("time", 0) # 消息时间
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp):
|
||||
message_time_ts = message.time # This is a float timestamp
|
||||
|
||||
group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息
|
||||
if group_info is not None:
|
||||
# 若有群聊信息
|
||||
chat_id = f"g{group_info.get('group_id')}"
|
||||
chat_name = group_info.get("group_name", f"群{group_info.get('group_id')}")
|
||||
elif user_info:
|
||||
# 若没有群聊信息,则尝试获取用户信息
|
||||
chat_id = f"u{user_info['user_id']}"
|
||||
chat_name = user_info["user_nickname"]
|
||||
chat_id = None
|
||||
chat_name = None
|
||||
|
||||
# Logic based on Peewee model structure, aiming to replicate original intent
|
||||
if message.chat_info_group_id:
|
||||
chat_id = f"g{message.chat_info_group_id}"
|
||||
chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}"
|
||||
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
# This uses the message SENDER's ID as per original logic's fallback
|
||||
chat_id = f"u{message.user_id}" # SENDER's user_id
|
||||
chat_name = message.user_nickname # SENDER's nickname
|
||||
else:
|
||||
continue # 如果没有群组信息也没有用户信息,则跳过
|
||||
# If neither group_id nor sender_id is available for chat identification
|
||||
logger.warning(
|
||||
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
|
||||
)
|
||||
continue
|
||||
|
||||
if not chat_id: # Should not happen if above logic is correct
|
||||
continue
|
||||
|
||||
# Update name_mapping
|
||||
if chat_id in self.name_mapping:
|
||||
if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]:
|
||||
# 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称
|
||||
self.name_mapping[chat_id] = (chat_name, message_time)
|
||||
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
else:
|
||||
self.name_mapping[chat_id] = (chat_name, message_time)
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
|
||||
for idx, (_, period_start) in enumerate(collect_period):
|
||||
if message_time >= period_start.timestamp():
|
||||
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
|
||||
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
|
||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
if message_time_ts >= period_start_dt.timestamp():
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_MSG_CNT] += 1
|
||||
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
|
||||
break
|
||||
|
||||
return stats
|
||||
|
||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
|
||||
@@ -13,7 +13,7 @@ from src.manager.mood_manager import mood_manager
|
||||
from ..message_receive.message import MessageRecv
|
||||
from ..models.utils_model import LLMRequest
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
from ...common.database import db
|
||||
from ...common.database.database import db
|
||||
from ...config.config import global_config
|
||||
|
||||
logger = get_module_logger("chat_utils")
|
||||
@@ -43,8 +43,8 @@ def db_message_to_str(message_dict: dict) -> str:
|
||||
|
||||
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
"""检查消息是否提到了机器人"""
|
||||
keywords = [global_config.BOT_NICKNAME]
|
||||
nicknames = global_config.BOT_ALIAS_NAMES
|
||||
keywords = [global_config.bot.nickname]
|
||||
nicknames = global_config.bot.alias_names
|
||||
reply_probability = 0.0
|
||||
is_at = False
|
||||
is_mentioned = False
|
||||
@@ -64,18 +64,18 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
)
|
||||
|
||||
# 判断是否被@
|
||||
if re.search(f"@[\s\S]*?(id:{global_config.BOT_QQ})", message.processed_plain_text):
|
||||
if re.search(f"@[\s\S]*?(id:{global_config.bot.qq_account})", message.processed_plain_text):
|
||||
is_at = True
|
||||
is_mentioned = True
|
||||
|
||||
if is_at and global_config.at_bot_inevitable_reply:
|
||||
if is_at and global_config.normal_chat.at_bot_inevitable_reply:
|
||||
reply_probability = 1.0
|
||||
logger.info("被@,回复概率设置为100%")
|
||||
else:
|
||||
if not is_mentioned:
|
||||
# 判断是否被回复
|
||||
if re.match(
|
||||
f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\):[\s\S]*?],说:", message.processed_plain_text
|
||||
f"\[回复 [\s\S]*?\({str(global_config.bot.qq_account)}\):[\s\S]*?],说:", message.processed_plain_text
|
||||
):
|
||||
is_mentioned = True
|
||||
else:
|
||||
@@ -88,7 +88,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
for nickname in nicknames:
|
||||
if nickname in message_content:
|
||||
is_mentioned = True
|
||||
if is_mentioned and global_config.mentioned_bot_inevitable_reply:
|
||||
if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply:
|
||||
reply_probability = 1.0
|
||||
logger.info("被提及,回复概率设置为100%")
|
||||
return is_mentioned, reply_probability
|
||||
@@ -96,7 +96,8 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
|
||||
async def get_embedding(text, request_type="embedding"):
|
||||
"""获取文本的embedding向量"""
|
||||
llm = LLMRequest(model=global_config.embedding, request_type=request_type)
|
||||
# TODO: API-Adapter修改标记
|
||||
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
|
||||
# return llm.get_embedding_sync(text)
|
||||
try:
|
||||
embedding = await llm.get_embedding(text)
|
||||
@@ -163,7 +164,7 @@ def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> li
|
||||
user_info = UserInfo.from_dict(msg_db_data["user_info"])
|
||||
if (
|
||||
(user_info.platform, user_info.user_id) != sender
|
||||
and user_info.user_id != global_config.BOT_QQ
|
||||
and user_info.user_id != global_config.bot.qq_account
|
||||
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
||||
and len(who_chat_in_group) < 5
|
||||
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||
@@ -321,7 +322,7 @@ def random_remove_punctuation(text: str) -> str:
|
||||
|
||||
def process_llm_response(text: str) -> list[str]:
|
||||
# 先保护颜文字
|
||||
if global_config.enable_kaomoji_protection:
|
||||
if global_config.response_splitter.enable_kaomoji_protection:
|
||||
protected_text, kaomoji_mapping = protect_kaomoji(text)
|
||||
logger.trace(f"保护颜文字后的文本: {protected_text}")
|
||||
else:
|
||||
@@ -340,8 +341,8 @@ def process_llm_response(text: str) -> list[str]:
|
||||
logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
|
||||
|
||||
# 对清理后的文本进行进一步处理
|
||||
max_length = global_config.response_max_length * 2
|
||||
max_sentence_num = global_config.response_max_sentence_num
|
||||
max_length = global_config.response_splitter.max_length * 2
|
||||
max_sentence_num = global_config.response_splitter.max_sentence_num
|
||||
# 如果基本上是中文,则进行长度过滤
|
||||
if get_western_ratio(cleaned_text) < 0.1:
|
||||
if len(cleaned_text) > max_length:
|
||||
@@ -349,20 +350,20 @@ def process_llm_response(text: str) -> list[str]:
|
||||
return ["懒得说"]
|
||||
|
||||
typo_generator = ChineseTypoGenerator(
|
||||
error_rate=global_config.chinese_typo_error_rate,
|
||||
min_freq=global_config.chinese_typo_min_freq,
|
||||
tone_error_rate=global_config.chinese_typo_tone_error_rate,
|
||||
word_replace_rate=global_config.chinese_typo_word_replace_rate,
|
||||
error_rate=global_config.chinese_typo.error_rate,
|
||||
min_freq=global_config.chinese_typo.min_freq,
|
||||
tone_error_rate=global_config.chinese_typo.tone_error_rate,
|
||||
word_replace_rate=global_config.chinese_typo.word_replace_rate,
|
||||
)
|
||||
|
||||
if global_config.enable_response_splitter:
|
||||
if global_config.response_splitter.enable:
|
||||
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
||||
else:
|
||||
split_sentences = [cleaned_text]
|
||||
|
||||
sentences = []
|
||||
for sentence in split_sentences:
|
||||
if global_config.chinese_typo_enable:
|
||||
if global_config.chinese_typo.enable:
|
||||
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
|
||||
sentences.append(typoed_text)
|
||||
if typo_corrections:
|
||||
@@ -372,7 +373,7 @@ def process_llm_response(text: str) -> list[str]:
|
||||
|
||||
if len(sentences) > max_sentence_num:
|
||||
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
||||
return [f"{global_config.BOT_NICKNAME}不知道哦"]
|
||||
return [f"{global_config.bot.nickname}不知道哦"]
|
||||
|
||||
# if extracted_contents:
|
||||
# for content in extracted_contents:
|
||||
|
||||
@@ -8,7 +8,8 @@ import io
|
||||
import numpy as np
|
||||
|
||||
|
||||
from ...common.database import db
|
||||
from ...common.database.database import db
|
||||
from ...common.database.database_model import Images, ImageDescriptions
|
||||
from ...config.config import global_config
|
||||
from ..models.utils_model import LLMRequest
|
||||
|
||||
@@ -32,40 +33,23 @@ class ImageManager:
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self._ensure_image_collection()
|
||||
self._ensure_description_collection()
|
||||
self._ensure_image_dir()
|
||||
|
||||
self._initialized = True
|
||||
self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
|
||||
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
db.create_tables([Images, ImageDescriptions], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或表创建失败: {e}")
|
||||
|
||||
self._initialized = True
|
||||
self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
|
||||
|
||||
def _ensure_image_dir(self):
|
||||
"""确保图像存储目录存在"""
|
||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_image_collection():
|
||||
"""确保images集合存在并创建索引"""
|
||||
if "images" not in db.list_collection_names():
|
||||
db.create_collection("images")
|
||||
|
||||
# 删除旧索引
|
||||
db.images.drop_indexes()
|
||||
# 创建新的复合索引
|
||||
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
|
||||
db.images.create_index([("url", 1)])
|
||||
db.images.create_index([("path", 1)])
|
||||
|
||||
@staticmethod
|
||||
def _ensure_description_collection():
|
||||
"""确保image_descriptions集合存在并创建索引"""
|
||||
if "image_descriptions" not in db.list_collection_names():
|
||||
db.create_collection("image_descriptions")
|
||||
|
||||
# 删除旧索引
|
||||
db.image_descriptions.drop_indexes()
|
||||
# 创建新的复合索引
|
||||
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
||||
"""从数据库获取图片描述
|
||||
@@ -77,8 +61,14 @@ class ImageManager:
|
||||
Returns:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
|
||||
return result["description"] if result else None
|
||||
try:
|
||||
record = ImageDescriptions.get_or_none(
|
||||
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
|
||||
)
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
|
||||
@@ -90,20 +80,17 @@ class ImageManager:
|
||||
description_type: 描述类型 ('emoji' 或 'image')
|
||||
"""
|
||||
try:
|
||||
db.image_descriptions.update_one(
|
||||
{"hash": image_hash, "type": description_type},
|
||||
{
|
||||
"$set": {
|
||||
"description": description,
|
||||
"timestamp": int(time.time()),
|
||||
"hash": image_hash, # 确保hash字段存在
|
||||
"type": description_type, # 确保type字段存在
|
||||
}
|
||||
},
|
||||
upsert=True,
|
||||
current_timestamp = time.time()
|
||||
defaults = {"description": description, "timestamp": current_timestamp}
|
||||
desc_obj, created = ImageDescriptions.get_or_create(
|
||||
hash=image_hash, type=description_type, defaults=defaults
|
||||
)
|
||||
if not created: # 如果记录已存在,则更新
|
||||
desc_obj.description = description
|
||||
desc_obj.timestamp = current_timestamp
|
||||
desc_obj.save()
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败: {str(e)}")
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
|
||||
async def get_emoji_description(self, image_base64: str) -> str:
|
||||
"""获取表情包描述,带查重和保存功能"""
|
||||
@@ -116,51 +103,64 @@ class ImageManager:
|
||||
# 查询缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
if cached_description:
|
||||
# logger.debug(f"缓存表情包描述: {cached_description}")
|
||||
return f"[表情包,含义看起来是:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
if image_format == "gif" or image_format == "GIF":
|
||||
image_base64 = self.transform_gif(image_base64)
|
||||
image_base64_processed = self.transform_gif(image_base64)
|
||||
if image_base64_processed is None:
|
||||
logger.warning("GIF转换失败,无法获取描述")
|
||||
return "[表情包(GIF处理失败)]"
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg")
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
|
||||
else:
|
||||
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成表情包描述")
|
||||
return "[表情包(描述生成失败)]"
|
||||
|
||||
# 再次检查缓存,防止并发写入时重复生成
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
if cached_description:
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||
return f"[表情包,含义看起来是:{cached_description}]"
|
||||
|
||||
# 根据配置决定是否保存图片
|
||||
if global_config.save_emoji:
|
||||
if global_config.emoji.save_emoji:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")):
|
||||
os.makedirs(os.path.join(self.IMAGE_DIR, "emoji"))
|
||||
file_path = os.path.join(self.IMAGE_DIR, "emoji", filename)
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
|
||||
os.makedirs(emoji_dir, exist_ok=True)
|
||||
file_path = os.path.join(emoji_dir, filename)
|
||||
|
||||
try:
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
image_doc = {
|
||||
"hash": image_hash,
|
||||
"path": file_path,
|
||||
"type": "emoji",
|
||||
"description": description,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
|
||||
logger.trace(f"保存表情包: {file_path}")
|
||||
# 保存到数据库 (Images表)
|
||||
try:
|
||||
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
img_obj.path = file_path
|
||||
img_obj.description = description
|
||||
img_obj.timestamp = current_timestamp
|
||||
img_obj.save()
|
||||
except Images.DoesNotExist:
|
||||
Images.create(
|
||||
hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=description,
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
logger.trace(f"保存表情包元数据: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件失败: {str(e)}")
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
|
||||
# 保存描述到数据库
|
||||
# 保存描述到数据库 (ImageDescriptions表)
|
||||
self._save_description_to_db(image_hash, description, "emoji")
|
||||
|
||||
return f"[表情包:{description}]"
|
||||
@@ -188,6 +188,11 @@ class ImageManager:
|
||||
)
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
return "[图片(描述生成失败)]"
|
||||
|
||||
# 再次检查缓存
|
||||
cached_description = self._get_description_from_db(image_hash, "image")
|
||||
if cached_description:
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
|
||||
@@ -195,38 +200,40 @@ class ImageManager:
|
||||
|
||||
logger.debug(f"描述是{description}")
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
return "[图片]"
|
||||
|
||||
# 根据配置决定是否保存图片
|
||||
if global_config.save_pic:
|
||||
if global_config.emoji.save_pic:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")):
|
||||
os.makedirs(os.path.join(self.IMAGE_DIR, "image"))
|
||||
file_path = os.path.join(self.IMAGE_DIR, "image", filename)
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
image_dir = os.path.join(self.IMAGE_DIR, "image")
|
||||
os.makedirs(image_dir, exist_ok=True)
|
||||
file_path = os.path.join(image_dir, filename)
|
||||
|
||||
try:
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
image_doc = {
|
||||
"hash": image_hash,
|
||||
"path": file_path,
|
||||
"type": "image",
|
||||
"description": description,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
|
||||
logger.trace(f"保存图片: {file_path}")
|
||||
# 保存到数据库 (Images表)
|
||||
try:
|
||||
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "image"))
|
||||
img_obj.path = file_path
|
||||
img_obj.description = description
|
||||
img_obj.timestamp = current_timestamp
|
||||
img_obj.save()
|
||||
except Images.DoesNotExist:
|
||||
Images.create(
|
||||
hash=image_hash,
|
||||
path=file_path,
|
||||
type="image",
|
||||
description=description,
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
logger.trace(f"保存图片元数据: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件失败: {str(e)}")
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
|
||||
# 保存描述到数据库
|
||||
# 保存描述到数据库 (ImageDescriptions表)
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
return f"[图片:{description}]"
|
||||
|
||||
Reference in New Issue
Block a user