初始化
This commit is contained in:
@@ -7,13 +7,14 @@ from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.common.database.database_model import ActionRecords
|
||||
from src.common.database.database_model import Images
|
||||
from src.person_info.person_info import Person,get_person_id
|
||||
from src.common.database.sqlalchemy_models import ActionRecords, Images
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
session = get_session()
|
||||
|
||||
def replace_user_references_sync(
|
||||
content: str,
|
||||
@@ -254,50 +255,90 @@ def get_actions_by_timestamp_with_chat(
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
||||
query = ActionRecords.select().where(
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
& (ActionRecords.time > timestamp_start) # type: ignore
|
||||
& (ActionRecords.time < timestamp_end) # type: ignore
|
||||
)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
)
|
||||
))
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.desc()).limit(limit))
|
||||
# 获取后需要反转列表,以保持最终输出为时间升序
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in reversed(actions)]
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()).limit(limit))
|
||||
else:
|
||||
query = query.order_by(ActionRecords.time.asc())
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()))
|
||||
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in actions]
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in actions]
|
||||
|
||||
|
||||
def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
query = ActionRecords.select().where(
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
& (ActionRecords.time >= timestamp_start) # type: ignore
|
||||
& (ActionRecords.time <= timestamp_end) # type: ignore
|
||||
)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
))
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.desc()).limit(limit))
|
||||
# 获取后需要反转列表,以保持最终输出为时间升序
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in reversed(actions)]
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()).limit(limit))
|
||||
else:
|
||||
query = query.order_by(ActionRecords.time.asc())
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()))
|
||||
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in actions]
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in actions]
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_random(
|
||||
@@ -700,7 +741,7 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# 从数据库中获取图片描述
|
||||
description = "内容正在阅读,请稍等"
|
||||
try:
|
||||
image = Images.get_or_none(Images.image_id == pic_id)
|
||||
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar()
|
||||
if image and image.description:
|
||||
description = image.description
|
||||
except Exception:
|
||||
@@ -813,7 +854,7 @@ def build_readable_messages(
|
||||
timestamp_mode: str = "relative",
|
||||
read_mark: float = 0.0,
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_actions: bool = True,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> str: # sourcery skip: extract-method
|
||||
@@ -846,21 +887,21 @@ def build_readable_messages(
|
||||
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
|
||||
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
actions_in_range = (
|
||||
ActionRecords.select()
|
||||
.where(
|
||||
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
|
||||
actions_in_range = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.time >= min_time,
|
||||
ActionRecords.time <= max_time,
|
||||
ActionRecords.chat_id == chat_id
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
)
|
||||
).order_by(ActionRecords.time)).scalars()
|
||||
|
||||
# 获取最新消息之后的第一个动作记录
|
||||
action_after_latest = (
|
||||
ActionRecords.select()
|
||||
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
)
|
||||
action_after_latest = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.time > max_time,
|
||||
ActionRecords.chat_id == chat_id
|
||||
)
|
||||
).order_by(ActionRecords.time).limit(1)).scalars()
|
||||
|
||||
# 合并两部分动作记录
|
||||
actions = list(actions_in_range) + list(action_after_latest)
|
||||
|
||||
@@ -6,13 +6,52 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session, db_query, db_save, db_get
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("maibot_statistic")
|
||||
|
||||
# 同步包装器函数,用于在非异步环境中调用异步数据库API
|
||||
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
|
||||
"""同步版本的db_get,用于在线程池中调用"""
|
||||
import asyncio
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# 如果事件循环正在运行,创建新的事件循环
|
||||
import threading
|
||||
result = None
|
||||
exception = None
|
||||
|
||||
def run_in_thread():
|
||||
nonlocal result, exception
|
||||
try:
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
result = new_loop.run_until_complete(
|
||||
db_get(model_class, filters, limit, order_by, single_result)
|
||||
)
|
||||
new_loop.close()
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
thread = threading.Thread(target=run_in_thread)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
if exception:
|
||||
raise exception
|
||||
return result
|
||||
else:
|
||||
return loop.run_until_complete(
|
||||
db_get(model_class, filters, limit, order_by, single_result)
|
||||
)
|
||||
except RuntimeError:
|
||||
# 没有事件循环,创建一个新的
|
||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
||||
|
||||
# 统计数据的键
|
||||
TOTAL_REQ_CNT = "total_requests"
|
||||
TOTAL_COST = "total_cost"
|
||||
@@ -59,17 +98,9 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
def __init__(self):
|
||||
super().__init__(task_name="Online Time Record Task", run_interval=60)
|
||||
|
||||
self.record_id: int | None = None # Changed to int for Peewee's default ID
|
||||
self.record_id: int | None = None
|
||||
"""记录ID"""
|
||||
|
||||
self._init_database() # 初始化数据库
|
||||
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
"""初始化数据库"""
|
||||
with db.atomic(): # Use atomic operations for schema changes
|
||||
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
|
||||
|
||||
async def run(self): # sourcery skip: use-named-expression
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
@@ -77,36 +108,50 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
|
||||
if self.record_id:
|
||||
# 如果有记录,则更新结束时间
|
||||
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore
|
||||
updated_rows = query.execute()
|
||||
updated_rows = await db_query(
|
||||
model_class=OnlineTime,
|
||||
query_type="update",
|
||||
filters={"id": self.record_id},
|
||||
data={"end_timestamp": extended_end_time}
|
||||
)
|
||||
if updated_rows == 0:
|
||||
# Record might have been deleted or ID is stale, try to find/create
|
||||
self.record_id = None # Reset record_id to trigger find/create logic below
|
||||
self.record_id = None
|
||||
|
||||
if not self.record_id: # Check again if record_id was reset or initially None
|
||||
# 如果没有记录,检查一分钟以内是否已有记录
|
||||
# Look for a record whose end_timestamp is recent enough to be considered ongoing
|
||||
recent_record = (
|
||||
OnlineTime.select()
|
||||
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore
|
||||
.order_by(OnlineTime.end_timestamp.desc())
|
||||
.first()
|
||||
if not self.record_id:
|
||||
# 查找最近一分钟内的记录
|
||||
recent_threshold = current_time - timedelta(minutes=1)
|
||||
recent_records = await db_get(
|
||||
model_class=OnlineTime,
|
||||
filters={"end_timestamp": {"$gte": recent_threshold}},
|
||||
order_by="-end_timestamp",
|
||||
limit=1,
|
||||
single_result=True
|
||||
)
|
||||
|
||||
if recent_record:
|
||||
# 如果有记录,则更新结束时间
|
||||
self.record_id = recent_record.id
|
||||
recent_record.end_timestamp = extended_end_time
|
||||
recent_record.save()
|
||||
else:
|
||||
# 若没有记录,则插入新的在线时间记录
|
||||
new_record = OnlineTime.create(
|
||||
timestamp=current_time.timestamp(), # 添加此行
|
||||
start_timestamp=current_time,
|
||||
end_timestamp=extended_end_time,
|
||||
duration=5, # 初始时长为5分钟
|
||||
|
||||
if recent_records:
|
||||
# 找到近期记录,更新它
|
||||
self.record_id = recent_records['id']
|
||||
await db_query(
|
||||
model_class=OnlineTime,
|
||||
query_type="update",
|
||||
filters={"id": self.record_id},
|
||||
data={"end_timestamp": extended_end_time}
|
||||
)
|
||||
self.record_id = new_record.id
|
||||
else:
|
||||
# 创建新记录
|
||||
new_record = await db_save(
|
||||
model_class=OnlineTime,
|
||||
data={
|
||||
"timestamp": str(current_time),
|
||||
"duration": 5, # 初始时长为5分钟
|
||||
"start_timestamp": current_time,
|
||||
"end_timestamp": extended_end_time,
|
||||
}
|
||||
)
|
||||
if new_record:
|
||||
self.record_id = new_record['id']
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"在线时间记录失败,错误信息:{e}")
|
||||
|
||||
@@ -322,18 +367,23 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
# 以最早的时间戳为起始时间获取记录
|
||||
# Assuming LLMUsage.timestamp is a DateTimeField
|
||||
query_start_time = collect_period[-1][1]
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||
record_timestamp = record.timestamp # This is already a datetime object
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record_timestamp = record['timestamp'] # 从字典中获取
|
||||
for idx, (_, period_start) in enumerate(collect_period):
|
||||
if record_timestamp >= period_start:
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_REQ_CNT] += 1
|
||||
|
||||
request_type = record.request_type or "unknown"
|
||||
user_id = record.user_id or "unknown" # user_id is TextField, already string
|
||||
model_name = record.model_name or "unknown"
|
||||
request_type = record.get('request_type') or "unknown"
|
||||
user_id = record.get('user_id') or "unknown"
|
||||
model_name = record.get('model_name') or "unknown"
|
||||
|
||||
# 提取模块名:如果请求类型包含".",取第一个"."之前的部分
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
@@ -343,8 +393,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
|
||||
stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
|
||||
|
||||
prompt_tokens = record.prompt_tokens or 0
|
||||
completion_tokens = record.completion_tokens or 0
|
||||
prompt_tokens = record.get('prompt_tokens') or 0
|
||||
completion_tokens = record.get('completion_tokens') or 0
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
|
||||
@@ -362,7 +412,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
|
||||
stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
|
||||
|
||||
cost = record.cost or 0.0
|
||||
cost = record.get('cost') or 0.0
|
||||
stats[period_key][TOTAL_COST] += cost
|
||||
stats[period_key][COST_BY_TYPE][request_type] += cost
|
||||
stats[period_key][COST_BY_USER][user_id] += cost
|
||||
@@ -425,11 +475,15 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_time = collect_period[-1][1]
|
||||
# Assuming OnlineTime.end_timestamp is a DateTimeField
|
||||
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore
|
||||
# record.end_timestamp and record.start_timestamp are datetime objects
|
||||
record_end_timestamp = record.end_timestamp
|
||||
record_start_timestamp = record.start_timestamp
|
||||
records = _sync_db_get(
|
||||
model_class=OnlineTime,
|
||||
filters={"end_timestamp": {"$gte": query_start_time}},
|
||||
order_by="-end_timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record_end_timestamp = record['end_timestamp']
|
||||
record_start_timestamp = record['start_timestamp']
|
||||
|
||||
for idx, (_, period_boundary_start) in enumerate(collect_period):
|
||||
if record_end_timestamp >= period_boundary_start:
|
||||
@@ -466,24 +520,30 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time # This is a float timestamp
|
||||
records = _sync_db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time"
|
||||
)
|
||||
|
||||
for message in records:
|
||||
message_time_ts = message['time'] # This is a float timestamp
|
||||
|
||||
chat_id = None
|
||||
chat_name = None
|
||||
|
||||
# Logic based on Peewee model structure, aiming to replicate original intent
|
||||
if message.chat_info_group_id:
|
||||
chat_id = f"g{message.chat_info_group_id}"
|
||||
chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}"
|
||||
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
# Logic based on SQLAlchemy model structure, aiming to replicate original intent
|
||||
if message.get('chat_info_group_id'):
|
||||
chat_id = f"g{message['chat_info_group_id']}"
|
||||
chat_name = message.get('chat_info_group_name') or f"群{message['chat_info_group_id']}"
|
||||
elif message.get('user_id'): # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
# This uses the message SENDER's ID as per original logic's fallback
|
||||
chat_id = f"u{message.user_id}" # SENDER's user_id
|
||||
chat_name = message.user_nickname # SENDER's nickname
|
||||
chat_id = f"u{message['user_id']}" # SENDER's user_id
|
||||
chat_name = message.get('user_nickname') # SENDER's nickname
|
||||
else:
|
||||
# If neither group_id nor sender_id is available for chat identification
|
||||
logger.warning(
|
||||
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
|
||||
f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats."
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -1025,8 +1085,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询LLM使用记录
|
||||
query_start_time = start_time
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||
record_time = record.timestamp
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record_time = record['timestamp']
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
time_diff = (record_time - start_time).total_seconds()
|
||||
@@ -1034,17 +1100,17 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 累加总花费数据
|
||||
cost = record.cost or 0.0
|
||||
cost = record.get('cost') or 0.0
|
||||
total_cost_data[interval_index] += cost # type: ignore
|
||||
|
||||
# 累加按模型分类的花费
|
||||
model_name = record.model_name or "unknown"
|
||||
model_name = record.get('model_name') or "unknown"
|
||||
if model_name not in cost_by_model:
|
||||
cost_by_model[model_name] = [0] * len(time_points)
|
||||
cost_by_model[model_name][interval_index] += cost
|
||||
|
||||
# 累加按模块分类的花费
|
||||
request_type = record.request_type or "unknown"
|
||||
request_type = record.get('request_type') or "unknown"
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
if module_name not in cost_by_module:
|
||||
cost_by_module[module_name] = [0] * len(time_points)
|
||||
@@ -1052,8 +1118,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time
|
||||
records = _sync_db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time"
|
||||
)
|
||||
|
||||
for message in records:
|
||||
message_time_ts = message['time']
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
time_diff = message_time_ts - query_start_timestamp
|
||||
@@ -1062,10 +1134,10 @@ class StatisticOutputTask(AsyncTask):
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 确定聊天流名称
|
||||
chat_name = None
|
||||
if message.chat_info_group_id:
|
||||
chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}"
|
||||
elif message.user_id:
|
||||
chat_name = message.user_nickname or f"用户{message.user_id}"
|
||||
if message.get('chat_info_group_id'):
|
||||
chat_name = message.get('chat_info_group_name') or f"群{message['chat_info_group_id']}"
|
||||
elif message.get('user_id'):
|
||||
chat_name = message.get('user_nickname') or f"用户{message['user_id']}"
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
@@ -13,10 +13,12 @@ from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import Images, ImageDescriptions
|
||||
from src.common.database.sqlalchemy_models import Images, ImageDescriptions
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_image")
|
||||
@@ -41,9 +43,10 @@ class ImageManager:
|
||||
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
db.create_tables([Images, ImageDescriptions], safe=True)
|
||||
# 使用SQLAlchemy创建表已在初始化时完成
|
||||
logger.debug("使用SQLAlchemy进行表管理")
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或表创建失败: {e}")
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
|
||||
self._initialized = True
|
||||
|
||||
@@ -63,12 +66,13 @@ class ImageManager:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
record = ImageDescriptions.get_or_none(
|
||||
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
|
||||
)
|
||||
return record.description if record else None
|
||||
with get_db_session() as session:
|
||||
record = session.execute(select(ImageDescriptions).where(
|
||||
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
|
||||
)).scalar()
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
|
||||
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@@ -82,16 +86,28 @@ class ImageManager:
|
||||
"""
|
||||
try:
|
||||
current_timestamp = time.time()
|
||||
defaults = {"description": description, "timestamp": current_timestamp}
|
||||
desc_obj, created = ImageDescriptions.get_or_create(
|
||||
image_description_hash=image_hash, type=description_type, defaults=defaults
|
||||
)
|
||||
if not created: # 如果记录已存在,则更新
|
||||
desc_obj.description = description
|
||||
desc_obj.timestamp = current_timestamp
|
||||
desc_obj.save()
|
||||
with get_db_session() as session:
|
||||
# 查找现有记录
|
||||
existing = session.execute(select(ImageDescriptions).where(
|
||||
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
|
||||
)).scalar()
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
existing.description = description
|
||||
existing.timestamp = current_timestamp
|
||||
else:
|
||||
# 创建新记录
|
||||
new_desc = ImageDescriptions(
|
||||
image_description_hash=image_hash,
|
||||
type=description_type,
|
||||
description=description,
|
||||
timestamp=current_timestamp
|
||||
)
|
||||
session.add(new_desc)
|
||||
# session.commit() 会在上下文管理器中自动调用
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
|
||||
|
||||
async def get_emoji_tag(self, image_base64: str) -> str:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
@@ -214,19 +230,29 @@ class ImageManager:
|
||||
|
||||
# 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程
|
||||
try:
|
||||
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
img_obj.path = file_path
|
||||
img_obj.description = detailed_description # 保存详细描述
|
||||
img_obj.timestamp = current_timestamp
|
||||
img_obj.save()
|
||||
except Images.DoesNotExist: # type: ignore
|
||||
Images.create(
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=detailed_description, # 保存详细描述
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
with get_db_session() as session:
|
||||
existing_img = session.execute(select(Images).where(
|
||||
and_(Images.emoji_hash == image_hash, Images.type == "emoji")
|
||||
)).scalar()
|
||||
|
||||
if existing_img:
|
||||
existing_img.path = file_path
|
||||
existing_img.description = detailed_description # 保存详细描述
|
||||
existing_img.timestamp = current_timestamp
|
||||
else:
|
||||
new_img = Images(
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=detailed_description, # 保存详细描述
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
session.add(new_img)
|
||||
# session.commit() 会在上下文管理器中自动调用
|
||||
except Exception as e:
|
||||
logger.error(f"保存到Images表失败: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
|
||||
@@ -249,19 +275,19 @@ class ImageManager:
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 优先检查Images表中是否已有完整的描述
|
||||
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
|
||||
if existing_image:
|
||||
# 更新计数
|
||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||
existing_image.count += 1
|
||||
else:
|
||||
existing_image.count = 1
|
||||
existing_image.save()
|
||||
with get_db_session() as session:
|
||||
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
|
||||
if existing_image:
|
||||
# 更新计数
|
||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||
existing_image.count += 1
|
||||
else:
|
||||
existing_image.count = 1
|
||||
|
||||
# 如果已有描述,直接返回
|
||||
if existing_image.description:
|
||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
||||
return f"[图片:{existing_image.description}]"
|
||||
# 如果已有描述,直接返回
|
||||
if existing_image.description:
|
||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
||||
return f"[图片:{existing_image.description}]"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||
@@ -300,10 +326,10 @@ class ImageManager:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = True
|
||||
existing_image.save()
|
||||
session.commit()
|
||||
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
||||
else:
|
||||
Images.create(
|
||||
new_img = Images(
|
||||
image_id=str(uuid.uuid4()),
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
@@ -313,6 +339,8 @@ class ImageManager:
|
||||
vlm_processed=True,
|
||||
count=1,
|
||||
)
|
||||
session.add(new_img)
|
||||
session.commit()
|
||||
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
@@ -465,31 +493,32 @@ class ImageManager:
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
with get_db_session() as session:
|
||||
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
|
||||
if existing_image:
|
||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||
if (
|
||||
not hasattr(existing_image, "image_id")
|
||||
or not existing_image.image_id
|
||||
or not hasattr(existing_image, "count")
|
||||
or existing_image.count is None
|
||||
or not hasattr(existing_image, "vlm_processed")
|
||||
or existing_image.vlm_processed is None
|
||||
):
|
||||
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
|
||||
if not existing_image.image_id:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if existing_image.count is None:
|
||||
existing_image.count = 0
|
||||
if existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = False
|
||||
|
||||
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
|
||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||
if (
|
||||
not hasattr(existing_image, "image_id")
|
||||
or not existing_image.image_id
|
||||
or not hasattr(existing_image, "count")
|
||||
or existing_image.count is None
|
||||
or not hasattr(existing_image, "vlm_processed")
|
||||
or existing_image.vlm_processed is None
|
||||
):
|
||||
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
|
||||
if not existing_image.image_id:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if existing_image.count is None:
|
||||
existing_image.count = 0
|
||||
if existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = False
|
||||
existing_image.count += 1
|
||||
session.commit()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
|
||||
existing_image.count += 1
|
||||
existing_image.save()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
else:
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
image_id = str(uuid.uuid4())
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
image_id = str(uuid.uuid4())
|
||||
|
||||
# 保存新图片
|
||||
current_timestamp = time.time()
|
||||
@@ -503,7 +532,7 @@ class ImageManager:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
Images.create(
|
||||
new_img = Images(
|
||||
image_id=image_id,
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
@@ -512,6 +541,8 @@ class ImageManager:
|
||||
vlm_processed=False,
|
||||
count=1,
|
||||
)
|
||||
session.add(new_img)
|
||||
session.commit()
|
||||
|
||||
# 启动异步VLM处理
|
||||
asyncio.create_task(self._process_image_with_vlm(image_id, image_base64))
|
||||
@@ -536,60 +567,64 @@ class ImageManager:
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
with get_db_session() as session:
|
||||
# 获取当前图片记录
|
||||
image = session.execute(select(Images).where(Images.image_id == image_id)).scalar()
|
||||
|
||||
# 获取当前图片记录
|
||||
image = Images.get(Images.image_id == image_id)
|
||||
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
||||
existing_with_description = session.execute(select(Images).where(
|
||||
and_(
|
||||
Images.emoji_hash == image_hash,
|
||||
Images.description.isnot(None),
|
||||
Images.description != "",
|
||||
Images.id != image.id
|
||||
)
|
||||
)).scalar()
|
||||
if existing_with_description:
|
||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||
image.description = existing_with_description.description
|
||||
image.vlm_processed = True
|
||||
session.commit()
|
||||
# 同时保存到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
||||
return
|
||||
|
||||
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
||||
existing_with_description = Images.get_or_none(
|
||||
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
|
||||
)
|
||||
if existing_with_description and existing_with_description.id != image.id:
|
||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||
image.description = existing_with_description.description
|
||||
# 检查ImageDescriptions表的缓存描述
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
session.commit()
|
||||
return
|
||||
|
||||
# 获取图片格式
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# 构建prompt
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
|
||||
# 获取VLM描述
|
||||
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if description is None:
|
||||
logger.warning("VLM未能生成图片描述")
|
||||
description = "无法生成描述"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
|
||||
description = cached_description
|
||||
|
||||
# 更新数据库
|
||||
image.description = description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
# 同时保存到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
||||
return
|
||||
|
||||
# 检查ImageDescriptions表的缓存描述
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
return
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
# 获取图片格式
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# 构建prompt
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
|
||||
# 获取VLM描述
|
||||
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if description is None:
|
||||
logger.warning("VLM未能生成图片描述")
|
||||
description = "无法生成描述"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
|
||||
description = cached_description
|
||||
|
||||
# 更新数据库
|
||||
image.description = description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
|
||||
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"VLM处理图片失败: {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user