feat: 添加数据库表创建和初始化功能,确保模型表存在

This commit is contained in:
墨梓柒
2025-05-14 23:04:22 +08:00
parent b84cc9240a
commit 2051b011b1
2 changed files with 128 additions and 77 deletions

View File

@@ -6,7 +6,7 @@ from src.common.logger import get_module_logger
from src.manager.async_task_manager import AsyncTask
from ...common.database.database import db # This db is the Peewee database instance
from ...common.database.database_model import OnlineTime # Import the Peewee model
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from src.manager.local_store_manager import local_storage
logger = get_module_logger("maibot_statistic")
@@ -195,35 +195,28 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段
"""
if len(collect_period) <= 0:
if not collect_period:
return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
# 总LLM请求数
TOTAL_REQ_CNT: 0,
# 请求次数统计
REQ_CNT_BY_TYPE: defaultdict(int),
REQ_CNT_BY_USER: defaultdict(int),
REQ_CNT_BY_MODEL: defaultdict(int),
# 输入Token数
IN_TOK_BY_TYPE: defaultdict(int),
IN_TOK_BY_USER: defaultdict(int),
IN_TOK_BY_MODEL: defaultdict(int),
# 输出Token数
OUT_TOK_BY_TYPE: defaultdict(int),
OUT_TOK_BY_USER: defaultdict(int),
OUT_TOK_BY_MODEL: defaultdict(int),
# 总Token数
TOTAL_TOK_BY_TYPE: defaultdict(int),
TOTAL_TOK_BY_USER: defaultdict(int),
TOTAL_TOK_BY_MODEL: defaultdict(int),
# 总开销
TOTAL_COST: 0.0,
# 请求开销统计
COST_BY_TYPE: defaultdict(float),
COST_BY_USER: defaultdict(float),
COST_BY_MODEL: defaultdict(float),
@@ -232,26 +225,26 @@ class StatisticOutputTask(AsyncTask):
}
# 以最早的时间戳为起始时间获取记录
for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}):
record_timestamp = record.get("timestamp")
# Assuming LLMUsage.timestamp is a DateTimeField
query_start_time = collect_period[-1][1]
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
record_timestamp = record.timestamp # This is already a datetime object
for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.get("request_type", "unknown") # 请求类型
user_id = str(record.get("user_id", "unknown")) # 用户ID
model_name = record.get("model_name", "unknown") # 模型名称
request_type = record.request_type or "unknown"
user_id = record.user_id or "unknown" # user_id is TextField, already string
model_name = record.model_name or "unknown"
stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
stats[period_key][REQ_CNT_BY_USER][user_id] += 1
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数
completion_tokens = record.get("completion_tokens", 0) # 输出Token数
total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数
prompt_tokens = record.prompt_tokens or 0
completion_tokens = record.completion_tokens or 0
total_tokens = prompt_tokens + completion_tokens
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
@@ -265,13 +258,12 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
cost = record.get("cost", 0.0)
cost = record.cost or 0.0
stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost
stats[period_key][COST_BY_MODEL][model_name] += cost
break # 取消更早时间段的判断
break
return stats
@staticmethod
@@ -281,39 +273,38 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段
"""
if len(collect_period) <= 0:
if not collect_period:
return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
# 在线时间统计
ONLINE_TIME: 0.0,
}
for period_key, _ in collect_period
}
# 统计在线时间
for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}):
end_timestamp: datetime = record.get("end_timestamp")
for idx, (_, period_start) in enumerate(collect_period):
if end_timestamp >= period_start:
# 由于end_timestamp会超前标记时间所以我们需要判断是否晚于当前时间如果是则使用当前时间作为结束时间
end_timestamp = min(end_timestamp, now)
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _period_start in collect_period[idx:]:
start_timestamp: datetime = record.get("start_timestamp")
if start_timestamp < _period_start:
# 如果开始时间在查询边界之前,则使用开始时间
stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds()
else:
# 否则,使用开始时间
stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds()
break # 取消更早时间段的判断
query_start_time = collect_period[-1][1]
# Assuming OnlineTime.end_timestamp is a DateTimeField
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
# record.end_timestamp and record.start_timestamp are datetime objects
record_end_timestamp = record.end_timestamp
record_start_timestamp = record.start_timestamp
for idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start:
# Calculate effective end time for this record in relation to 'now'
effective_end_time = min(record_end_timestamp, now)
for period_key, current_period_start_time in collect_period[idx:]:
# Determine the portion of the record that falls within this specific statistical period
overlap_start = max(record_start_timestamp, current_period_start_time)
overlap_end = effective_end_time # Already capped by 'now' and record's own end
if overlap_end > overlap_start:
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
break
return stats
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
@@ -322,55 +313,55 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段
"""
if len(collect_period) <= 0:
if not collect_period:
return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
# 消息统计
TOTAL_MSG_CNT: 0,
MSG_CNT_BY_CHAT: defaultdict(int),
}
for period_key, _ in collect_period
}
# 统计消息量
for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}):
chat_info = message.get("chat_info", None) # 聊天信息
user_info = message.get("user_info", None) # 用户信息(消息发送人)
message_time = message.get("time", 0) # 消息时间
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp):
message_time_ts = message.time # This is a float timestamp
group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息
if group_info is not None:
# 若有群聊信息
chat_id = f"g{group_info.get('group_id')}"
chat_name = group_info.get("group_name", f"{group_info.get('group_id')}")
elif user_info:
# 若没有群聊信息,则尝试获取用户信息
chat_id = f"u{user_info['user_id']}"
chat_name = user_info["user_nickname"]
chat_id = None
chat_name = None
# Logic based on Peewee model structure, aiming to replicate original intent
if message.chat_info_group_id:
chat_id = f"g{message.chat_info_group_id}"
chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
# This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message.user_id}" # SENDER's user_id
chat_name = message.user_nickname # SENDER's nickname
else:
continue # 如果没有群组信息也没有用户信息,则跳过
# If neither group_id nor sender_id is available for chat identification
logger.warning(f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats.")
continue
if not chat_id: # Should not happen if above logic is correct
continue
# Update name_mapping
if chat_id in self.name_mapping:
if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]:
# 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称
self.name_mapping[chat_id] = (chat_name, message_time)
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
self.name_mapping[chat_id] = (chat_name, message_time_ts)
else:
self.name_mapping[chat_id] = (chat_name, message_time)
self.name_mapping[chat_id] = (chat_name, message_time_ts)
for idx, (_, period_start) in enumerate(collect_period):
if message_time >= period_start.timestamp():
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
break
return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:

View File

@@ -240,3 +240,63 @@ class ThinkingLog(BaseModel):
class Meta:
table_name = 'thinking_logs'
def create_tables():
"""
创建所有在模型中定义的数据库表。
"""
with db:
db.create_tables([
ChatStreams,
LLMUsage,
Emoji,
Messages,
Images,
ImageDescriptions,
OnlineTime,
PersonInfo,
Knowledges,
ThinkingLog
])
def initialize_database():
"""
检查所有定义的表是否存在,如果不存在则创建它们。
"""
models = [
ChatStreams,
LLMUsage,
Emoji,
Messages,
Images,
ImageDescriptions,
OnlineTime,
PersonInfo,
Knowledges,
ThinkingLog
]
needs_creation = False
try:
with db: # 管理 table_exists 检查的连接
for model in models:
if not db.table_exists(model):
print(f"'{model._meta.table_name}' 未找到。")
needs_creation = True
break # 一个表丢失,无需进一步检查。
except Exception as e:
print(f"检查表是否存在时出错: {e}")
# 如果检查失败(例如数据库不可用),则退出
return
if needs_creation:
print("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...")
try:
create_tables() # 此函数有其自己的 'with db:' 上下文管理。
print("数据库表创建过程完成。")
except Exception as e:
print(f"创建表期间出错: {e}")
else:
print("所有数据库表均已存在。")
# 模块加载时调用初始化函数
initialize_database()