修复代码格式和文件名大小写问题
This commit is contained in:
@@ -13,45 +13,45 @@ from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("maibot_statistic")
|
||||
|
||||
|
||||
# 同步包装器函数,用于在非异步环境中调用异步数据库API
|
||||
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
|
||||
"""同步版本的db_get,用于在线程池中调用"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# 如果事件循环正在运行,创建新的事件循环
|
||||
import threading
|
||||
|
||||
result = None
|
||||
exception = None
|
||||
|
||||
|
||||
def run_in_thread():
|
||||
nonlocal result, exception
|
||||
try:
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
result = new_loop.run_until_complete(
|
||||
db_get(model_class, filters, limit, order_by, single_result)
|
||||
)
|
||||
result = new_loop.run_until_complete(db_get(model_class, filters, limit, order_by, single_result))
|
||||
new_loop.close()
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
|
||||
thread = threading.Thread(target=run_in_thread)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
|
||||
if exception:
|
||||
raise exception
|
||||
return result
|
||||
else:
|
||||
return loop.run_until_complete(
|
||||
db_get(model_class, filters, limit, order_by, single_result)
|
||||
)
|
||||
return loop.run_until_complete(db_get(model_class, filters, limit, order_by, single_result))
|
||||
except RuntimeError:
|
||||
# 没有事件循环,创建一个新的
|
||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
||||
|
||||
|
||||
# 统计数据的键
|
||||
TOTAL_REQ_CNT = "total_requests"
|
||||
TOTAL_COST = "total_cost"
|
||||
@@ -112,7 +112,7 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
model_class=OnlineTime,
|
||||
query_type="update",
|
||||
filters={"id": self.record_id},
|
||||
data={"end_timestamp": extended_end_time}
|
||||
data={"end_timestamp": extended_end_time},
|
||||
)
|
||||
if updated_rows == 0:
|
||||
# Record might have been deleted or ID is stale, try to find/create
|
||||
@@ -126,17 +126,17 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
filters={"end_timestamp": {"$gte": recent_threshold}},
|
||||
order_by="-end_timestamp",
|
||||
limit=1,
|
||||
single_result=True
|
||||
single_result=True,
|
||||
)
|
||||
|
||||
|
||||
if recent_records:
|
||||
# 找到近期记录,更新它
|
||||
self.record_id = recent_records['id']
|
||||
self.record_id = recent_records["id"]
|
||||
await db_query(
|
||||
model_class=OnlineTime,
|
||||
query_type="update",
|
||||
filters={"id": self.record_id},
|
||||
data={"end_timestamp": extended_end_time}
|
||||
data={"end_timestamp": extended_end_time},
|
||||
)
|
||||
else:
|
||||
# 创建新记录
|
||||
@@ -147,10 +147,10 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
"duration": 5, # 初始时长为5分钟
|
||||
"start_timestamp": current_time,
|
||||
"end_timestamp": extended_end_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
if new_record:
|
||||
self.record_id = new_record['id']
|
||||
self.record_id = new_record["id"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"在线时间记录失败,错误信息:{e}")
|
||||
@@ -368,20 +368,19 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 以最早的时间戳为起始时间获取记录
|
||||
query_start_time = collect_period[-1][1]
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp"
|
||||
) or []
|
||||
|
||||
records = (
|
||||
_sync_db_get(model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp")
|
||||
or []
|
||||
)
|
||||
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
|
||||
record_timestamp = record.get('timestamp')
|
||||
|
||||
record_timestamp = record.get("timestamp")
|
||||
if isinstance(record_timestamp, str):
|
||||
record_timestamp = datetime.fromisoformat(record_timestamp)
|
||||
|
||||
|
||||
if not record_timestamp:
|
||||
continue
|
||||
|
||||
@@ -390,9 +389,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_REQ_CNT] += 1
|
||||
|
||||
request_type = record.get('request_type') or "unknown"
|
||||
user_id = record.get('user_id') or "unknown"
|
||||
model_name = record.get('model_name') or "unknown"
|
||||
request_type = record.get("request_type") or "unknown"
|
||||
user_id = record.get("user_id") or "unknown"
|
||||
model_name = record.get("model_name") or "unknown"
|
||||
|
||||
# 提取模块名:如果请求类型包含".",取第一个"."之前的部分
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
@@ -402,8 +401,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
|
||||
stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
|
||||
|
||||
prompt_tokens = record.get('prompt_tokens') or 0
|
||||
completion_tokens = record.get('completion_tokens') or 0
|
||||
prompt_tokens = record.get("prompt_tokens") or 0
|
||||
completion_tokens = record.get("completion_tokens") or 0
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
|
||||
@@ -421,40 +420,40 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
|
||||
stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
|
||||
|
||||
cost = record.get('cost') or 0.0
|
||||
cost = record.get("cost") or 0.0
|
||||
stats[period_key][TOTAL_COST] += cost
|
||||
stats[period_key][COST_BY_TYPE][request_type] += cost
|
||||
stats[period_key][COST_BY_USER][user_id] += cost
|
||||
stats[period_key][COST_BY_MODEL][model_name] += cost
|
||||
stats[period_key][COST_BY_MODULE][module_name] += cost
|
||||
|
||||
|
||||
# 收集time_cost数据
|
||||
time_cost = record.get('time_cost') or 0.0
|
||||
time_cost = record.get("time_cost") or 0.0
|
||||
if time_cost > 0: # 只记录有效的time_cost
|
||||
stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost)
|
||||
stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost)
|
||||
stats[period_key][TIME_COST_BY_MODEL][model_name].append(time_cost)
|
||||
stats[period_key][TIME_COST_BY_MODULE][module_name].append(time_cost)
|
||||
break
|
||||
|
||||
# 计算平均耗时和标准差
|
||||
|
||||
# 计算平均耗时和标准差
|
||||
for period_key in stats:
|
||||
for category in [REQ_CNT_BY_TYPE, REQ_CNT_BY_USER, REQ_CNT_BY_MODEL, REQ_CNT_BY_MODULE]:
|
||||
time_cost_key = f"time_costs_by_{category.split('_')[-1]}"
|
||||
avg_key = f"avg_time_costs_by_{category.split('_')[-1]}"
|
||||
std_key = f"std_time_costs_by_{category.split('_')[-1]}"
|
||||
|
||||
|
||||
for item_name in stats[period_key][category]:
|
||||
time_costs = stats[period_key][time_cost_key].get(item_name, [])
|
||||
if time_costs:
|
||||
# 计算平均耗时
|
||||
avg_time_cost = sum(time_costs) / len(time_costs)
|
||||
stats[period_key][avg_key][item_name] = round(avg_time_cost, 3)
|
||||
|
||||
|
||||
# 计算标准差
|
||||
if len(time_costs) > 1:
|
||||
variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs)
|
||||
std_time_cost = variance ** 0.5
|
||||
std_time_cost = variance**0.5
|
||||
stats[period_key][std_key][item_name] = round(std_time_cost, 3)
|
||||
else:
|
||||
stats[period_key][std_key][item_name] = 0.0
|
||||
@@ -483,21 +482,22 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_time = collect_period[-1][1]
|
||||
records = _sync_db_get(
|
||||
model_class=OnlineTime,
|
||||
filters={"end_timestamp": {"$gte": query_start_time}},
|
||||
order_by="-end_timestamp"
|
||||
) or []
|
||||
|
||||
records = (
|
||||
_sync_db_get(
|
||||
model_class=OnlineTime, filters={"end_timestamp": {"$gte": query_start_time}}, order_by="-end_timestamp"
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
|
||||
record_end_timestamp = record.get('end_timestamp')
|
||||
record_end_timestamp = record.get("end_timestamp")
|
||||
if isinstance(record_end_timestamp, str):
|
||||
record_end_timestamp = datetime.fromisoformat(record_end_timestamp)
|
||||
|
||||
record_start_timestamp = record.get('start_timestamp')
|
||||
record_start_timestamp = record.get("start_timestamp")
|
||||
if isinstance(record_start_timestamp, str):
|
||||
record_start_timestamp = datetime.fromisoformat(record_start_timestamp)
|
||||
|
||||
@@ -539,16 +539,15 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
records = _sync_db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time"
|
||||
) or []
|
||||
|
||||
records = (
|
||||
_sync_db_get(model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time")
|
||||
or []
|
||||
)
|
||||
|
||||
for message in records:
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
message_time_ts = message.get('time') # This is a float timestamp
|
||||
message_time_ts = message.get("time") # This is a float timestamp
|
||||
|
||||
if not message_time_ts:
|
||||
continue
|
||||
@@ -557,18 +556,16 @@ class StatisticOutputTask(AsyncTask):
|
||||
chat_name = None
|
||||
|
||||
# Logic based on SQLAlchemy model structure, aiming to replicate original intent
|
||||
if message.get('chat_info_group_id'):
|
||||
if message.get("chat_info_group_id"):
|
||||
chat_id = f"g{message['chat_info_group_id']}"
|
||||
chat_name = message.get('chat_info_group_name') or f"群{message['chat_info_group_id']}"
|
||||
elif message.get('user_id'): # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}"
|
||||
elif message.get("user_id"): # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
# This uses the message SENDER's ID as per original logic's fallback
|
||||
chat_id = f"u{message['user_id']}" # SENDER's user_id
|
||||
chat_name = message.get('user_nickname') # SENDER's nickname
|
||||
chat_name = message.get("user_nickname") # SENDER's nickname
|
||||
else:
|
||||
# If neither group_id nor sender_id is available for chat identification
|
||||
logger.warning(
|
||||
f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats."
|
||||
)
|
||||
logger.warning(f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats.")
|
||||
continue
|
||||
|
||||
if not chat_id: # Should not happen if above logic is correct
|
||||
@@ -589,8 +586,6 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
|
||||
|
||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
收集各时间段的统计数据
|
||||
@@ -721,7 +716,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODEL][model_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
|
||||
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost))
|
||||
output.append(
|
||||
data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)
|
||||
)
|
||||
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
@@ -1109,13 +1106,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 查询LLM使用记录
|
||||
query_start_time = start_time
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp"
|
||||
model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp"
|
||||
)
|
||||
|
||||
|
||||
for record in records:
|
||||
record_time = record['timestamp']
|
||||
record_time = record["timestamp"]
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
time_diff = (record_time - start_time).total_seconds()
|
||||
@@ -1123,17 +1118,17 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 累加总花费数据
|
||||
cost = record.get('cost') or 0.0
|
||||
cost = record.get("cost") or 0.0
|
||||
total_cost_data[interval_index] += cost # type: ignore
|
||||
|
||||
# 累加按模型分类的花费
|
||||
model_name = record.get('model_name') or "unknown"
|
||||
model_name = record.get("model_name") or "unknown"
|
||||
if model_name not in cost_by_model:
|
||||
cost_by_model[model_name] = [0] * len(time_points)
|
||||
cost_by_model[model_name][interval_index] += cost
|
||||
|
||||
# 累加按模块分类的花费
|
||||
request_type = record.get('request_type') or "unknown"
|
||||
request_type = record.get("request_type") or "unknown"
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
if module_name not in cost_by_module:
|
||||
cost_by_module[module_name] = [0] * len(time_points)
|
||||
@@ -1142,13 +1137,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
records = _sync_db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time"
|
||||
model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time"
|
||||
)
|
||||
|
||||
|
||||
for message in records:
|
||||
message_time_ts = message['time']
|
||||
message_time_ts = message["time"]
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
time_diff = message_time_ts - query_start_timestamp
|
||||
@@ -1157,10 +1150,10 @@ class StatisticOutputTask(AsyncTask):
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 确定聊天流名称
|
||||
chat_name = None
|
||||
if message.get('chat_info_group_id'):
|
||||
chat_name = message.get('chat_info_group_name') or f"群{message['chat_info_group_id']}"
|
||||
elif message.get('user_id'):
|
||||
chat_name = message.get('user_nickname') or f"用户{message['user_id']}"
|
||||
if message.get("chat_info_group_id"):
|
||||
chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}"
|
||||
elif message.get("user_id"):
|
||||
chat_name = message.get("user_nickname") or f"用户{message['user_id']}"
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user