From 2051b011b12090ae9765b865a0f30131be608134 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 14 May 2025 23:04:22 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93=E8=A1=A8=E5=88=9B=E5=BB=BA=E5=92=8C=E5=88=9D=E5=A7=8B?= =?UTF-8?q?=E5=8C=96=E5=8A=9F=E8=83=BD=EF=BC=8C=E7=A1=AE=E4=BF=9D=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E8=A1=A8=E5=AD=98=E5=9C=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/utils/statistic.py | 145 ++++++++++++-------------- src/common/database/database_model.py | 60 +++++++++++ 2 files changed, 128 insertions(+), 77 deletions(-) diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 9a0131f74..88329c3f4 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -6,7 +6,7 @@ from src.common.logger import get_module_logger from src.manager.async_task_manager import AsyncTask from ...common.database.database import db # This db is the Peewee database instance -from ...common.database.database_model import OnlineTime # Import the Peewee model +from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model from src.manager.local_store_manager import local_storage logger = get_module_logger("maibot_statistic") @@ -195,35 +195,28 @@ class StatisticOutputTask(AsyncTask): :param collect_period: 统计时间段 """ - if len(collect_period) <= 0: + if not collect_period: return {} - else: - # 排序-按照时间段开始时间降序排列(最晚的时间段在前) - collect_period.sort(key=lambda x: x[1], reverse=True) + + # 排序-按照时间段开始时间降序排列(最晚的时间段在前) + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { period_key: { - # 总LLM请求数 TOTAL_REQ_CNT: 0, - # 请求次数统计 REQ_CNT_BY_TYPE: defaultdict(int), REQ_CNT_BY_USER: defaultdict(int), REQ_CNT_BY_MODEL: defaultdict(int), - # 输入Token数 IN_TOK_BY_TYPE: defaultdict(int), IN_TOK_BY_USER: defaultdict(int), IN_TOK_BY_MODEL: defaultdict(int), - # 输出Token数 OUT_TOK_BY_TYPE: defaultdict(int), OUT_TOK_BY_USER: defaultdict(int), OUT_TOK_BY_MODEL: defaultdict(int), - # 总Token数 TOTAL_TOK_BY_TYPE: defaultdict(int), TOTAL_TOK_BY_USER: defaultdict(int), TOTAL_TOK_BY_MODEL: defaultdict(int), - # 总开销 TOTAL_COST: 0.0, - # 请求开销统计 COST_BY_TYPE: defaultdict(float), COST_BY_USER: defaultdict(float), COST_BY_MODEL: defaultdict(float), @@ -232,26 +225,26 @@ class StatisticOutputTask(AsyncTask): } # 以最早的时间戳为起始时间获取记录 - for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}): - record_timestamp = record.get("timestamp") + # Assuming LLMUsage.timestamp is a DateTimeField + query_start_time = collect_period[-1][1] + for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): + record_timestamp = record.timestamp # This is already a datetime object for idx, (_, period_start) in enumerate(collect_period): if record_timestamp >= period_start: - # 如果记录时间在当前时间段内,则它一定在更早的时间段内 - # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据 for period_key, _ in collect_period[idx:]: stats[period_key][TOTAL_REQ_CNT] += 1 - request_type = record.get("request_type", "unknown") # 请求类型 - user_id = str(record.get("user_id", "unknown")) # 用户ID - model_name = record.get("model_name", "unknown") # 模型名称 + request_type = record.request_type or "unknown" + user_id = record.user_id or "unknown" # user_id is TextField, already string + model_name = record.model_name or "unknown" stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1 stats[period_key][REQ_CNT_BY_USER][user_id] += 1 stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1 - prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数 - completion_tokens = record.get("completion_tokens", 0) # 输出Token数 - total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数 + prompt_tokens = record.prompt_tokens or 0 + completion_tokens = record.completion_tokens or 0 + total_tokens = prompt_tokens + completion_tokens stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens @@ -265,13 +258,12 @@ class StatisticOutputTask(AsyncTask): stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens - cost = record.get("cost", 0.0) + cost = record.cost or 0.0 stats[period_key][TOTAL_COST] += cost stats[period_key][COST_BY_TYPE][request_type] += cost stats[period_key][COST_BY_USER][user_id] += cost stats[period_key][COST_BY_MODEL][model_name] += cost - break # 取消更早时间段的判断 - + break return stats @staticmethod @@ -281,39 +273,38 @@ class StatisticOutputTask(AsyncTask): :param collect_period: 统计时间段 """ - if len(collect_period) <= 0: + if not collect_period: return {} - else: - # 排序-按照时间段开始时间降序排列(最晚的时间段在前) - collect_period.sort(key=lambda x: x[1], reverse=True) + + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { period_key: { - # 在线时间统计 ONLINE_TIME: 0.0, } for period_key, _ in collect_period } - # 统计在线时间 - for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}): - end_timestamp: datetime = record.get("end_timestamp") - for idx, (_, period_start) in enumerate(collect_period): - if end_timestamp >= period_start: - # 由于end_timestamp会超前标记时间,所以我们需要判断是否晚于当前时间,如果是,则使用当前时间作为结束时间 - end_timestamp = min(end_timestamp, now) - # 如果记录时间在当前时间段内,则它一定在更早的时间段内 - # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据 - for period_key, _period_start in collect_period[idx:]: - start_timestamp: datetime = record.get("start_timestamp") - if start_timestamp < _period_start: - # 如果开始时间在查询边界之前,则使用开始时间 - stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds() - else: - # 否则,使用开始时间 - stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds() - break # 取消更早时间段的判断 + query_start_time = collect_period[-1][1] + # Assuming OnlineTime.end_timestamp is a DateTimeField + for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): + # record.end_timestamp and record.start_timestamp are datetime objects + record_end_timestamp = record.end_timestamp + record_start_timestamp = record.start_timestamp + for idx, (_, period_boundary_start) in enumerate(collect_period): + if record_end_timestamp >= period_boundary_start: + # Calculate effective end time for this record in relation to 'now' + effective_end_time = min(record_end_timestamp, now) + + for period_key, current_period_start_time in collect_period[idx:]: + # Determine the portion of the record that falls within this specific statistical period + overlap_start = max(record_start_timestamp, current_period_start_time) + overlap_end = effective_end_time # Already capped by 'now' and record's own end + + if overlap_end > overlap_start: + stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds() + break return stats def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: @@ -322,55 +313,55 @@ class StatisticOutputTask(AsyncTask): :param collect_period: 统计时间段 """ - if len(collect_period) <= 0: + if not collect_period: return {} - else: - # 排序-按照时间段开始时间降序排列(最晚的时间段在前) - collect_period.sort(key=lambda x: x[1], reverse=True) + + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { period_key: { - # 消息统计 TOTAL_MSG_CNT: 0, MSG_CNT_BY_CHAT: defaultdict(int), } for period_key, _ in collect_period } - # 统计消息量 - for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}): - chat_info = message.get("chat_info", None) # 聊天信息 - user_info = message.get("user_info", None) # 用户信息(消息发送人) - message_time = message.get("time", 0) # 消息时间 + query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) + for message in Messages.select().where(Messages.time >= query_start_timestamp): + message_time_ts = message.time # This is a float timestamp - group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息 - if group_info is not None: - # 若有群聊信息 - chat_id = f"g{group_info.get('group_id')}" - chat_name = group_info.get("group_name", f"群{group_info.get('group_id')}") - elif user_info: - # 若没有群聊信息,则尝试获取用户信息 - chat_id = f"u{user_info['user_id']}" - chat_name = user_info["user_nickname"] + chat_id = None + chat_name = None + + # Logic based on Peewee model structure, aiming to replicate original intent + if message.chat_info_group_id: + chat_id = f"g{message.chat_info_group_id}" + chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}" + elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat + # This uses the message SENDER's ID as per original logic's fallback + chat_id = f"u{message.user_id}" # SENDER's user_id + chat_name = message.user_nickname # SENDER's nickname else: - continue # 如果没有群组信息也没有用户信息,则跳过 + # If neither group_id nor sender_id is available for chat identification + logger.warning(f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats.") + continue + + if not chat_id: # Should not happen if above logic is correct + continue + # Update name_mapping if chat_id in self.name_mapping: - if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]: - # 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称 - self.name_mapping[chat_id] = (chat_name, message_time) + if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]: + self.name_mapping[chat_id] = (chat_name, message_time_ts) else: - self.name_mapping[chat_id] = (chat_name, message_time) + self.name_mapping[chat_id] = (chat_name, message_time_ts) - for idx, (_, period_start) in enumerate(collect_period): - if message_time >= period_start.timestamp(): - # 如果记录时间在当前时间段内,则它一定在更早的时间段内 - # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据 + for idx, (_, period_start_dt) in enumerate(collect_period): + if message_time_ts >= period_start_dt.timestamp(): for period_key, _ in collect_period[idx:]: stats[period_key][TOTAL_MSG_CNT] += 1 stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1 break - return stats def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index b46cace9f..89e047414 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -240,3 +240,63 @@ class ThinkingLog(BaseModel): class Meta: table_name = 'thinking_logs' +def create_tables(): + """ + 创建所有在模型中定义的数据库表。 + """ + with db: + db.create_tables([ + ChatStreams, + LLMUsage, + Emoji, + Messages, + Images, + ImageDescriptions, + OnlineTime, + PersonInfo, + Knowledges, + ThinkingLog + ]) + +def initialize_database(): + """ + 检查所有定义的表是否存在,如果不存在则创建它们。 + """ + models = [ + ChatStreams, + LLMUsage, + Emoji, + Messages, + Images, + ImageDescriptions, + OnlineTime, + PersonInfo, + Knowledges, + ThinkingLog + ] + + needs_creation = False + try: + with db: # 管理 table_exists 检查的连接 + for model in models: + if not db.table_exists(model): + print(f"表 '{model._meta.table_name}' 未找到。") + needs_creation = True + break # 一个表丢失,无需进一步检查。 + except Exception as e: + print(f"检查表是否存在时出错: {e}") + # 如果检查失败(例如数据库不可用),则退出 + return + + if needs_creation: + print("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...") + try: + create_tables() # 此函数有其自己的 'with db:' 上下文管理。 + print("数据库表创建过程完成。") + except Exception as e: + print(f"创建表期间出错: {e}") + else: + print("所有数据库表均已存在。") + +# 模块加载时调用初始化函数 +initialize_database()