初始化
This commit is contained in:
@@ -6,13 +6,52 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session, db_query, db_save, db_get
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("maibot_statistic")
|
||||
|
||||
# 同步包装器函数,用于在非异步环境中调用异步数据库API
|
||||
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
    """Run the asynchronous ``db_get`` query to completion from synchronous code.

    The arguments are forwarded unchanged to ``db_get``.

    Three situations are handled:

    * No event loop can be obtained in the current thread: the query runs
      under ``asyncio.run``.
    * An event loop exists but is not running: the query runs on that loop.
    * The event loop is already running: the query runs on a fresh loop in a
      helper thread, and the calling thread blocks until it finishes.

    Args:
        model_class: Model class to query.
        filters: Filter specification passed through to ``db_get``.
        order_by: Ordering specification passed through to ``db_get``.
        limit: Maximum number of records, passed through to ``db_get``.
        single_result: Passed through to ``db_get``.

    Returns:
        Whatever ``db_get`` returns for the given arguments.

    Raises:
        Exception: Any exception raised by ``db_get`` is propagated as-is.
    """
    import asyncio
    import threading

    def make_query():
        # Keyword arguments keep this call independent of db_get's positional
        # parameter order.
        return db_get(
            model_class=model_class,
            filters=filters,
            limit=limit,
            order_by=order_by,
            single_result=single_result,
        )

    # Guard only the loop lookup. A RuntimeError raised by the query itself
    # must not be mistaken for "no event loop" and trigger a second execution.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No event loop is available in this thread; use a temporary one.
        return asyncio.run(make_query())

    if not loop.is_running():
        return loop.run_until_complete(make_query())

    # The current loop is already running, so it cannot be re-entered from
    # here. Run the query on a separate loop in a helper thread instead.
    outcome = {}

    def run_in_thread():
        new_loop = None
        try:
            new_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(new_loop)
            outcome["result"] = new_loop.run_until_complete(make_query())
        except Exception as e:
            outcome["error"] = e
        finally:
            # Always release the temporary loop, including on failure.
            if new_loop is not None:
                new_loop.close()

    thread = threading.Thread(target=run_in_thread)
    thread.start()
    thread.join()

    if "error" in outcome:
        raise outcome["error"]
    return outcome.get("result")
|
||||
|
||||
# 统计数据的键
|
||||
TOTAL_REQ_CNT = "total_requests"
|
||||
TOTAL_COST = "total_cost"
|
||||
@@ -59,17 +98,9 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
def __init__(self):
|
||||
super().__init__(task_name="Online Time Record Task", run_interval=60)
|
||||
|
||||
self.record_id: int | None = None # Changed to int for Peewee's default ID
|
||||
self.record_id: int | None = None
|
||||
"""记录ID"""
|
||||
|
||||
self._init_database() # 初始化数据库
|
||||
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
"""初始化数据库"""
|
||||
with db.atomic(): # Use atomic operations for schema changes
|
||||
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
|
||||
|
||||
async def run(self): # sourcery skip: use-named-expression
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
@@ -77,36 +108,50 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
|
||||
if self.record_id:
|
||||
# 如果有记录,则更新结束时间
|
||||
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore
|
||||
updated_rows = query.execute()
|
||||
updated_rows = await db_query(
|
||||
model_class=OnlineTime,
|
||||
query_type="update",
|
||||
filters={"id": self.record_id},
|
||||
data={"end_timestamp": extended_end_time}
|
||||
)
|
||||
if updated_rows == 0:
|
||||
# Record might have been deleted or ID is stale, try to find/create
|
||||
self.record_id = None # Reset record_id to trigger find/create logic below
|
||||
self.record_id = None
|
||||
|
||||
if not self.record_id: # Check again if record_id was reset or initially None
|
||||
# 如果没有记录,检查一分钟以内是否已有记录
|
||||
# Look for a record whose end_timestamp is recent enough to be considered ongoing
|
||||
recent_record = (
|
||||
OnlineTime.select()
|
||||
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore
|
||||
.order_by(OnlineTime.end_timestamp.desc())
|
||||
.first()
|
||||
if not self.record_id:
|
||||
# 查找最近一分钟内的记录
|
||||
recent_threshold = current_time - timedelta(minutes=1)
|
||||
recent_records = await db_get(
|
||||
model_class=OnlineTime,
|
||||
filters={"end_timestamp": {"$gte": recent_threshold}},
|
||||
order_by="-end_timestamp",
|
||||
limit=1,
|
||||
single_result=True
|
||||
)
|
||||
|
||||
if recent_record:
|
||||
# 如果有记录,则更新结束时间
|
||||
self.record_id = recent_record.id
|
||||
recent_record.end_timestamp = extended_end_time
|
||||
recent_record.save()
|
||||
else:
|
||||
# 若没有记录,则插入新的在线时间记录
|
||||
new_record = OnlineTime.create(
|
||||
timestamp=current_time.timestamp(), # 添加此行
|
||||
start_timestamp=current_time,
|
||||
end_timestamp=extended_end_time,
|
||||
duration=5, # 初始时长为5分钟
|
||||
|
||||
if recent_records:
|
||||
# 找到近期记录,更新它
|
||||
self.record_id = recent_records['id']
|
||||
await db_query(
|
||||
model_class=OnlineTime,
|
||||
query_type="update",
|
||||
filters={"id": self.record_id},
|
||||
data={"end_timestamp": extended_end_time}
|
||||
)
|
||||
self.record_id = new_record.id
|
||||
else:
|
||||
# 创建新记录
|
||||
new_record = await db_save(
|
||||
model_class=OnlineTime,
|
||||
data={
|
||||
"timestamp": str(current_time),
|
||||
"duration": 5, # 初始时长为5分钟
|
||||
"start_timestamp": current_time,
|
||||
"end_timestamp": extended_end_time,
|
||||
}
|
||||
)
|
||||
if new_record:
|
||||
self.record_id = new_record['id']
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"在线时间记录失败,错误信息:{e}")
|
||||
|
||||
@@ -322,18 +367,23 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
# 以最早的时间戳为起始时间获取记录
|
||||
# Assuming LLMUsage.timestamp is a DateTimeField
|
||||
query_start_time = collect_period[-1][1]
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||
record_timestamp = record.timestamp # This is already a datetime object
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record_timestamp = record['timestamp'] # 从字典中获取
|
||||
for idx, (_, period_start) in enumerate(collect_period):
|
||||
if record_timestamp >= period_start:
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_REQ_CNT] += 1
|
||||
|
||||
request_type = record.request_type or "unknown"
|
||||
user_id = record.user_id or "unknown" # user_id is TextField, already string
|
||||
model_name = record.model_name or "unknown"
|
||||
request_type = record.get('request_type') or "unknown"
|
||||
user_id = record.get('user_id') or "unknown"
|
||||
model_name = record.get('model_name') or "unknown"
|
||||
|
||||
# 提取模块名:如果请求类型包含".",取第一个"."之前的部分
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
@@ -343,8 +393,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
|
||||
stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
|
||||
|
||||
prompt_tokens = record.prompt_tokens or 0
|
||||
completion_tokens = record.completion_tokens or 0
|
||||
prompt_tokens = record.get('prompt_tokens') or 0
|
||||
completion_tokens = record.get('completion_tokens') or 0
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
|
||||
@@ -362,7 +412,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
|
||||
stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
|
||||
|
||||
cost = record.cost or 0.0
|
||||
cost = record.get('cost') or 0.0
|
||||
stats[period_key][TOTAL_COST] += cost
|
||||
stats[period_key][COST_BY_TYPE][request_type] += cost
|
||||
stats[period_key][COST_BY_USER][user_id] += cost
|
||||
@@ -425,11 +475,15 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_time = collect_period[-1][1]
|
||||
# Assuming OnlineTime.end_timestamp is a DateTimeField
|
||||
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore
|
||||
# record.end_timestamp and record.start_timestamp are datetime objects
|
||||
record_end_timestamp = record.end_timestamp
|
||||
record_start_timestamp = record.start_timestamp
|
||||
records = _sync_db_get(
|
||||
model_class=OnlineTime,
|
||||
filters={"end_timestamp": {"$gte": query_start_time}},
|
||||
order_by="-end_timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record_end_timestamp = record['end_timestamp']
|
||||
record_start_timestamp = record['start_timestamp']
|
||||
|
||||
for idx, (_, period_boundary_start) in enumerate(collect_period):
|
||||
if record_end_timestamp >= period_boundary_start:
|
||||
@@ -466,24 +520,30 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time # This is a float timestamp
|
||||
records = _sync_db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time"
|
||||
)
|
||||
|
||||
for message in records:
|
||||
message_time_ts = message['time'] # This is a float timestamp
|
||||
|
||||
chat_id = None
|
||||
chat_name = None
|
||||
|
||||
# Logic based on Peewee model structure, aiming to replicate original intent
|
||||
if message.chat_info_group_id:
|
||||
chat_id = f"g{message.chat_info_group_id}"
|
||||
chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}"
|
||||
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
# Logic based on SQLAlchemy model structure, aiming to replicate original intent
|
||||
if message.get('chat_info_group_id'):
|
||||
chat_id = f"g{message['chat_info_group_id']}"
|
||||
chat_name = message.get('chat_info_group_name') or f"群{message['chat_info_group_id']}"
|
||||
elif message.get('user_id'): # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
# This uses the message SENDER's ID as per original logic's fallback
|
||||
chat_id = f"u{message.user_id}" # SENDER's user_id
|
||||
chat_name = message.user_nickname # SENDER's nickname
|
||||
chat_id = f"u{message['user_id']}" # SENDER's user_id
|
||||
chat_name = message.get('user_nickname') # SENDER's nickname
|
||||
else:
|
||||
# If neither group_id nor sender_id is available for chat identification
|
||||
logger.warning(
|
||||
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
|
||||
f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats."
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -1025,8 +1085,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询LLM使用记录
|
||||
query_start_time = start_time
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||
record_time = record.timestamp
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record_time = record['timestamp']
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
time_diff = (record_time - start_time).total_seconds()
|
||||
@@ -1034,17 +1100,17 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 累加总花费数据
|
||||
cost = record.cost or 0.0
|
||||
cost = record.get('cost') or 0.0
|
||||
total_cost_data[interval_index] += cost # type: ignore
|
||||
|
||||
# 累加按模型分类的花费
|
||||
model_name = record.model_name or "unknown"
|
||||
model_name = record.get('model_name') or "unknown"
|
||||
if model_name not in cost_by_model:
|
||||
cost_by_model[model_name] = [0] * len(time_points)
|
||||
cost_by_model[model_name][interval_index] += cost
|
||||
|
||||
# 累加按模块分类的花费
|
||||
request_type = record.request_type or "unknown"
|
||||
request_type = record.get('request_type') or "unknown"
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
if module_name not in cost_by_module:
|
||||
cost_by_module[module_name] = [0] * len(time_points)
|
||||
@@ -1052,8 +1118,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time
|
||||
records = _sync_db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time"
|
||||
)
|
||||
|
||||
for message in records:
|
||||
message_time_ts = message['time']
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
time_diff = message_time_ts - query_start_timestamp
|
||||
@@ -1062,10 +1134,10 @@ class StatisticOutputTask(AsyncTask):
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 确定聊天流名称
|
||||
chat_name = None
|
||||
if message.chat_info_group_id:
|
||||
chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}"
|
||||
elif message.user_id:
|
||||
chat_name = message.user_nickname or f"用户{message.user_id}"
|
||||
if message.get('chat_info_group_id'):
|
||||
chat_name = message.get('chat_info_group_name') or f"群{message['chat_info_group_id']}"
|
||||
elif message.get('user_id'):
|
||||
chat_name = message.get('user_nickname') or f"用户{message['user_id']}"
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user