feat: add an async task manager and a local storage manager, and refactor the statistics module

Oct-autumn
2025-05-07 18:20:26 +08:00
parent b2b43c140f
commit 46d15b1fe7
5 changed files with 725 additions and 332 deletions

bot.py

@@ -1,7 +1,6 @@
 import asyncio
 import hashlib
 import os
-import shutil
 import sys
 from pathlib import Path
 import time
@@ -15,6 +14,8 @@ from src.common.crash_logger import install_crash_handler
 from src.main import MainSystem
 from rich.traceback import install
+from src.manager.async_task_manager import async_task_manager
+
 install(extra_lines=3)
 
 # Set the working directory to the directory this script lives in
@@ -64,38 +65,6 @@ def easter_egg():
     print(rainbow_text)
 
 
-def init_config():
-    # First-run check
-    if not os.path.exists("config/bot_config.toml"):
-        logger.warning("检测到bot_config.toml不存在正在从模板复制")
-
-        # Make sure the config directory exists
-        if not os.path.exists("config"):
-            os.makedirs("config")
-            logger.info("创建config目录")
-
-        shutil.copy("template/bot_config_template.toml", "config/bot_config.toml")
-        logger.info("复制完成请修改config/bot_config.toml和.env中的配置后重新启动")
-
-    if not os.path.exists("config/lpmm_config.toml"):
-        logger.warning("检测到lpmm_config.toml不存在正在从模板复制")
-
-        # Make sure the config directory exists
-        if not os.path.exists("config"):
-            os.makedirs("config")
-            logger.info("创建config目录")
-
-        shutil.copy("template/lpmm_config_template.toml", "config/lpmm_config.toml")
-        logger.info("复制完成请修改config/lpmm_config.toml和.env中的配置后重新启动")
-
-
-def init_env():
-    # Check whether the .env file exists
-    if not os.path.exists(".env"):
-        logger.error("检测到.env文件不存在")
-        shutil.copy("template/template.env", "./.env")
-        logger.info("已从template/template.env复制创建.env请修改配置后重新启动")
-
-
 def load_env():
     # Load the production environment variable configuration directly
     if os.path.exists(".env"):
@@ -140,6 +109,10 @@ def scan_provider(env_config: dict):
 async def graceful_shutdown():
     try:
         logger.info("正在优雅关闭麦麦...")
+
+        # Stop all async tasks first
+        await async_task_manager.stop_and_wait_all_tasks()
+
         tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
         for task in tasks:
             task.cancel()
@@ -235,9 +208,9 @@ def raw_main():
     check_eula()
     print("检查EULA和隐私条款完成")
     easter_egg()
-    init_config()
-    init_env()
 
     load_env()
 
     env_config = {key: os.getenv(key) for key in os.environ}

src/main.py

@@ -1,6 +1,8 @@
 import asyncio
 import time
 
-from .plugins.utils.statistic import LLMStatistics
+from .manager.async_task_manager import async_task_manager
+from .plugins.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
 from .plugins.moods.moods import MoodManager
 from .plugins.schedule.schedule_generator import bot_schedule
 from .plugins.emoji_system.emoji_manager import emoji_manager
@@ -26,11 +28,13 @@ logger = get_logger("main")
 class MainSystem:
+    mood_manager: MoodManager
+    hippocampus_manager: HippocampusManager
+    individuality: Individuality
+
     def __init__(self):
-        self.llm_stats = LLMStatistics("llm_statistics.txt")
 
         self.mood_manager = MoodManager.get_instance()
         self.hippocampus_manager = HippocampusManager.get_instance()
-        self._message_manager_started = False
 
         self.individuality = Individuality.get_instance()
 
         # Use the message API instead of a direct FastAPI instance
@@ -51,9 +55,12 @@ class MainSystem:
     async def _init_components(self):
         """Initialize the remaining components"""
         init_start_time = time.time()
 
-        # Start LLM statistics
-        self.llm_stats.start()
-        logger.success("LLM统计功能启动成功")
+        # Add the online-time recording task
+        await async_task_manager.add_task(OnlineTimeRecordTask())
+
+        # Add the statistics output task
+        await async_task_manager.add_task(StatisticOutputTask())
 
         # Start the API server
         start_api_server()

src/manager/async_task_manager.py

@@ -0,0 +1,150 @@
from abc import abstractmethod
import asyncio
from asyncio import Task, Event, Lock
from typing import Callable, Dict

from src.common.logger_manager import get_logger

logger = get_logger("async_task_manager")


class AsyncTask:
    """Base class for async tasks"""

    def __init__(self, task_name: str | None = None, wait_before_start: int = 0, run_interval: int = 0):
        self.task_name: str = task_name or self.__class__.__name__
        """Task name"""

        self.wait_before_start: int = wait_before_start
        """Delay (in seconds) before the task starts running; 0 means start immediately"""

        self.run_interval: int = run_interval
        """Interval (in seconds) between repeated runs; 0 means run only once"""

    @abstractmethod
    async def run(self):
        """The task's actual work"""
        pass

    async def start_task(self, abort_flag: asyncio.Event):
        if self.wait_before_start > 0:
            # Wait for the configured delay before starting the task
            await asyncio.sleep(self.wait_before_start)

        while not abort_flag.is_set():
            await self.run()
            if self.run_interval > 0:
                await asyncio.sleep(self.run_interval)
            else:
                break


class AsyncTaskManager:
    """Async task manager"""

    def __init__(self):
        self.tasks: Dict[str, Task] = {}
        """Task registry"""

        self.abort_flag: Event = Event()
        """Abort flag shared by all tasks"""

        self._lock: Lock = Lock()
        """Async lock; required wherever an await may occur while the registry is being mutated"""

    def _remove_task_call_back(self, task: Task):
        """call_back: remove the task from the registry once it has finished"""
        task_name = task.get_name()
        if task_name in self.tasks:
            # Remove the task after it completes
            del self.tasks[task_name]
            logger.debug(f"已移除任务 '{task_name}'")
        else:
            logger.warning(f"尝试移除不存在的任务 '{task_name}'")

    @staticmethod
    def _default_finish_call_back(task: Task):
        """call_back: default completion callback"""
        try:
            task.result()
            logger.debug(f"任务 '{task.get_name()}' 完成")
        except asyncio.CancelledError:
            logger.debug(f"任务 '{task.get_name()}' 被取消")
        except Exception as e:
            logger.error(f"任务 '{task.get_name()}' 执行时发生异常: {e}", exc_info=True)

    async def add_task(self, task: AsyncTask, call_back: Callable[[asyncio.Task], None] | None = None):
        """Add a task"""
        if not issubclass(task.__class__, AsyncTask):
            raise TypeError(f"task '{task.__class__.__name__}' 必须是继承 AsyncTask 的子类")

        async with self._lock:  # An await may occur while waiting for an old task to finish, so take the async lock
            if task.task_name in self.tasks:
                logger.warning(f"已存在名称为 '{task.task_name}' 的任务,正在尝试取消并替换")
                self.tasks[task.task_name].cancel()  # Cancel the existing task
                await self.tasks[task.task_name]  # Wait for it to finish
                logger.info(f"成功结束任务 '{task.task_name}'")

            # Create the new task
            task_inst = asyncio.create_task(task.start_task(self.abort_flag))
            task_inst.set_name(task.task_name)
            task_inst.add_done_callback(self._remove_task_call_back)  # Completion callback: auto-remove the task when it finishes
            task_inst.add_done_callback(
                call_back or self._default_finish_call_back
            )  # Completion callback: user-supplied, or the default fallback

            self.tasks[task.task_name] = task_inst  # Register the task
            logger.info(f"已启动任务 '{task.task_name}'")

    def get_tasks_status(self) -> Dict[str, Dict[str, str]]:
        """Get the status of all tasks"""
        tasks_status = {}
        for task_name, task in self.tasks.items():
            tasks_status[task_name] = {
                "status": "running" if not task.done() else "done",
            }
        return tasks_status

    async def stop_and_wait_all_tasks(self):
        """Stop all tasks and wait for them to finish (this blocks any concurrent add_task() calls)"""
        async with self._lock:  # An await may occur while waiting for tasks to finish, so take the async lock
            # Set the abort flag
            self.abort_flag.set()

            # Cancel all tasks
            for name, inst in self.tasks.items():
                try:
                    inst.cancel()
                except asyncio.CancelledError:
                    logger.info(f"已取消任务 '{name}'")

            # Wait for all tasks to finish
            for task_name, task_inst in self.tasks.items():
                if not task_inst.done():
                    try:
                        await task_inst
                    except asyncio.CancelledError:  # Catch the cancellation again so it is not raised later by stop_all_tasks()
                        logger.info(f"任务 {task_name} 已取消")
                    except Exception as e:
                        logger.error(f"任务 {task_name} 执行时发生异常: {e}", exc_info=True)

            # Clear the task registry
            self.tasks.clear()
            self.abort_flag.clear()

            logger.info("所有异步任务已停止")


async_task_manager = AsyncTaskManager()
"""Global async task manager instance"""

src/manager/local_store_manager.py

@@ -0,0 +1,67 @@
import json
import os

from src.common.logger_manager import get_logger

LOCAL_STORE_FILE_PATH = "data/local_store.json"

logger = get_logger("local_storage")


class LocalStoreManager:
    file_path: str
    """Path of the local store file"""

    store: dict[str, str | list | dict | int | float | bool]
    """Local store data"""

    def __init__(self, local_store_path: str | None = None):
        self.file_path = local_store_path or LOCAL_STORE_FILE_PATH
        self.store = {}

        self.load_local_store()

    def __getitem__(self, item: str) -> str | list | dict | int | float | bool | None:
        """Get a value from the local store"""
        return self.store.get(item, None)

    def __setitem__(self, key: str, value: str | list | dict | int | float | bool):
        """Set a value in the local store"""
        self.store[key] = value
        self.save_local_store()

    def __contains__(self, item: str) -> bool:
        """Check whether a key exists in the local store"""
        return item in self.store

    def load_local_store(self):
        """Load the local store data"""
        if os.path.exists(self.file_path):
            # The local store file exists, load it
            logger.info("正在阅读记事本......我在看,我真的在看!")
            logger.debug(f"加载本地存储数据: {self.file_path}")
            try:
                with open(self.file_path, "r", encoding="utf-8") as f:
                    self.store = json.load(f)
                logger.success("全都记起来了!")
            except json.JSONDecodeError:
                logger.warning("啊咧?记事本被弄脏了,正在重建记事本......")
                self.store = {}
                with open(self.file_path, "w", encoding="utf-8") as f:
                    json.dump({}, f, ensure_ascii=False, indent=4)
                logger.success("记事本重建成功!")
        else:
            # No local store file yet, create the directory and the file
            logger.warning("啊咧?记事本不存在,正在创建新的记事本......")
            os.makedirs(os.path.dirname(self.file_path), exist_ok=True)
            with open(self.file_path, "w", encoding="utf-8") as f:
                json.dump({}, f, ensure_ascii=False, indent=4)
            logger.success("记事本创建成功!")

    def save_local_store(self):
        """Save the local store data"""
        logger.debug(f"保存本地存储数据: {self.file_path}")
        with open(self.file_path, "w", encoding="utf-8") as f:
            json.dump(self.store, f, ensure_ascii=False, indent=4)


local_storage = LocalStoreManager("data/local_store.json")  # Global singleton
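
A small usage sketch of the store follows; the boot_count key is purely illustrative, while the statistics task added in this commit uses the same dict-style access with its last_full_statistics keys.

from src.manager.local_store_manager import local_storage

# Dict-style access; every assignment is written straight back to data/local_store.json.
if "boot_count" in local_storage:
    local_storage["boot_count"] = local_storage["boot_count"] + 1
else:
    local_storage["boot_count"] = 1

print(local_storage["boot_count"])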

src/plugins/utils/statistic.py

@@ -1,354 +1,550 @@
(Full rewrite: the previous thread-based LLMStatistics class, including its _stats_loop and _console_output_loop daemon threads, the 5-minute polling via _record_online_time, and the _collect_statistics_for_period, _format_stats_section, _format_stats_section_lite and _save_statistics helpers, is removed and replaced by the two AsyncTask subclasses below. New file content:)

from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List

from src.common.logger import get_module_logger
from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage
from ...common.database import db

logger = get_module_logger("maibot_statistic")

# Keys used in the statistics dictionaries
TOTAL_REQ_CNT = "total_requests"
TOTAL_COST = "total_cost"
REQ_CNT_BY_TYPE = "requests_by_type"
REQ_CNT_BY_USER = "requests_by_user"
REQ_CNT_BY_MODEL = "requests_by_model"
IN_TOK_BY_TYPE = "in_tokens_by_type"
IN_TOK_BY_USER = "in_tokens_by_user"
IN_TOK_BY_MODEL = "in_tokens_by_model"
OUT_TOK_BY_TYPE = "out_tokens_by_type"
OUT_TOK_BY_USER = "out_tokens_by_user"
OUT_TOK_BY_MODEL = "out_tokens_by_model"
TOTAL_TOK_BY_TYPE = "tokens_by_type"
TOTAL_TOK_BY_USER = "tokens_by_user"
TOTAL_TOK_BY_MODEL = "tokens_by_model"
COST_BY_TYPE = "costs_by_type"
COST_BY_USER = "costs_by_user"
COST_BY_MODEL = "costs_by_model"
ONLINE_TIME = "online_time"
TOTAL_MSG_CNT = "total_messages"
MSG_CNT_BY_CHAT = "messages_by_chat"


class OnlineTimeRecordTask(AsyncTask):
    """Online-time recording task"""

    def __init__(self):
        super().__init__(task_name="Online Time Record Task", run_interval=60)

        self.record_id: str | None = None
        """ID of the current online-time record"""

        self._init_database()  # Initialize the database

    @staticmethod
    def _init_database():
        """Initialize the database"""
        if "online_time" not in db.list_collection_names():
            # Create the online-time collection
            db.create_collection("online_time")

        # Create the index
        if ("end_timestamp", 1) not in db.online_time.list_indexes():
            db.online_time.create_index([("end_timestamp", 1)])

    async def run(self):
        try:
            if self.record_id:
                # A record already exists: extend its end time
                db.online_time.update_one(
                    {"_id": self.record_id},
                    {
                        "$set": {
                            "end_timestamp": datetime.now() + timedelta(minutes=1),
                        }
                    },
                )
            else:
                # No record yet: check whether one already exists within the last minute
                current_time = datetime.now()
                recent_record = db.online_time.find_one(
                    {"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}}
                )
                if not recent_record:
                    # No recent record: insert a new online-time record
                    self.record_id = db.online_time.insert_one(
                        {
                            "start_timestamp": current_time,
                            "end_timestamp": current_time + timedelta(minutes=1),
                        }
                    ).inserted_id
                else:
                    # A recent record exists: extend its end time
                    self.record_id = recent_record["_id"]
                    db.online_time.update_one(
                        {"_id": self.record_id},
                        {
                            "$set": {
                                "end_timestamp": current_time + timedelta(minutes=1),
                            }
                        },
                    )
        except Exception:
            logger.exception("在线时间记录失败")


class StatisticOutputTask(AsyncTask):
    """Statistics output task"""

    SEP_LINE = "-" * 84

    def __init__(self, record_file_path: str = "llm_statistics.txt"):
        # Start after a 300-second delay, then run every 300 seconds
        super().__init__(task_name="Statistics Data Output Task", wait_before_start=300, run_interval=300)

        self.name_mapping: Dict[str, Tuple[str, float]] = {}
        """
        Contact/group name mapping {chat ID: (contact/group name, record timestamp)}
        Note: the timestamp is kept so that names can be refreshed and stay up to date
        """

        self.record_file_path: str = record_file_path
        """Path of the statistics report file"""

        now = datetime.now()
        self.stat_period: List[Tuple[str, datetime, str]] = [
            ("all_time", datetime(2000, 1, 1), "自部署以来的"),
            ("last_7_days", now - timedelta(days=7), "最近7天的"),
            ("last_24_hours", now - timedelta(days=1), "最近24小时的"),
            ("last_hour", now - timedelta(hours=1), "最近1小时的"),
        ]
        """Statistics periods"""

    def _statistic_console_output(self, stats: Dict[str, Any]):
        """Write the statistics to the console"""
        # Print the statistics for the last hour
        output = [
            self.SEP_LINE,
            f" 最近1小时的统计数据 (详细信息见文件:{self.record_file_path})",
            self.SEP_LINE,
            self._format_total_stat(stats["last_hour"]),
            "",
            self._format_model_classified_stat(stats["last_hour"]),
            "",
            self._format_chat_stat(stats["last_hour"]),
            self.SEP_LINE,
            "",
        ]
        logger.info("\n" + "\n".join(output))

    def _statistic_file_output(self, stats: Dict[str, Any]):
        """Write the statistics to the report file"""
        output = [f"MaiBot运行统计报告 (生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')})", ""]

        def _format_stat_data(title: str, stats_: Dict[str, Any]) -> str:
            """Format one period's statistics"""
            return "\n".join(
                [
                    self.SEP_LINE,
                    f" {title}",
                    self.SEP_LINE,
                    self._format_total_stat(stats_),
                    "",
                    self._format_model_classified_stat(stats_),
                    "",
                    self._format_req_type_classified_stat(stats_),
                    "",
                    self._format_user_classified_stat(stats_),
                    "",
                    self._format_chat_stat(stats_),
                    "",
                ]
            )

        for period_key, period_start_time, period_desc in self.stat_period:
            if period_key in stats:
                # Statistics exist for this period
                output.append(
                    _format_stat_data(
                        f"{period_desc}统计数据 (自{period_start_time.strftime('%Y-%m-%d %H:%M:%S')}开始)",
                        stats[period_key],
                    )
                )

        with open(self.record_file_path, "w", encoding="utf-8") as f:
            f.write("\n\n".join(output))

    async def run(self):
        try:
            # Collect the statistics
            stats = self._collect_all_statistics()
            # Write them to the console
            self._statistic_console_output(stats)
            # Write them to the report file
            self._statistic_file_output(stats)
        except Exception as e:
            logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}")

    # -- Statistics collection methods --

    @staticmethod
    def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime, str]]) -> Dict[str, Any]:
        """
        Collect LLM request statistics for the given periods
        :param collect_period: statistics periods
        """
        if len(collect_period) <= 0:
            return {}
        else:
            # Sort the periods by start time in descending order (latest period first)
            collect_period.sort(key=lambda x: x[1], reverse=True)

        stats = {
            period_key: {
                # Total LLM requests
                TOTAL_REQ_CNT: 0,
                # Request counts
                REQ_CNT_BY_TYPE: defaultdict(int),
                REQ_CNT_BY_USER: defaultdict(int),
                REQ_CNT_BY_MODEL: defaultdict(int),
                # Input tokens
                IN_TOK_BY_TYPE: defaultdict(int),
                IN_TOK_BY_USER: defaultdict(int),
                IN_TOK_BY_MODEL: defaultdict(int),
                # Output tokens
                OUT_TOK_BY_TYPE: defaultdict(int),
                OUT_TOK_BY_USER: defaultdict(int),
                OUT_TOK_BY_MODEL: defaultdict(int),
                # Total tokens
                TOTAL_TOK_BY_TYPE: defaultdict(int),
                TOTAL_TOK_BY_USER: defaultdict(int),
                TOTAL_TOK_BY_MODEL: defaultdict(int),
                # Total cost
                TOTAL_COST: 0.0,
                # Request costs
                COST_BY_TYPE: defaultdict(float),
                COST_BY_USER: defaultdict(float),
                COST_BY_MODEL: defaultdict(float),
            }
            for period_key, _, _ in collect_period
        }

        # Fetch records starting from the earliest period boundary
        for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}):
            record_timestamp = record.get("timestamp")
            for idx, (_, period_start, _) in enumerate(collect_period):
                if record_timestamp >= period_start:
                    # If the record falls in this period, it necessarily falls in every longer period as well,
                    # so skip the remaining checks and update this period and all longer ones directly
                    for period_key, _, _ in collect_period[idx:]:
                        stats[period_key][TOTAL_REQ_CNT] += 1

                        request_type = record.get("request_type", "unknown")  # request type
                        user_id = str(record.get("user_id", "unknown"))  # user ID
                        model_name = record.get("model_name", "unknown")  # model name

                        stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
                        stats[period_key][REQ_CNT_BY_USER][user_id] += 1
                        stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1

                        prompt_tokens = record.get("prompt_tokens", 0)  # input tokens
                        completion_tokens = record.get("completion_tokens", 0)  # output tokens
                        total_tokens = prompt_tokens + completion_tokens  # total tokens = input + output

                        stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
                        stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
                        stats[period_key][IN_TOK_BY_MODEL][model_name] += prompt_tokens

                        stats[period_key][OUT_TOK_BY_TYPE][request_type] += completion_tokens
                        stats[period_key][OUT_TOK_BY_USER][user_id] += completion_tokens
                        stats[period_key][OUT_TOK_BY_MODEL][model_name] += completion_tokens

                        stats[period_key][TOTAL_TOK_BY_TYPE][request_type] += total_tokens
                        stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
                        stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens

                        cost = record.get("cost", 0.0)
                        stats[period_key][TOTAL_COST] += cost
                        stats[period_key][COST_BY_TYPE][request_type] += cost
                        stats[period_key][COST_BY_USER][user_id] += cost
                        stats[period_key][COST_BY_MODEL][model_name] += cost
                    break  # skip the checks for the longer periods

        return stats

    @staticmethod
    def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime, str]]) -> Dict[str, Any]:
        """
        Collect online-time statistics for the given periods
        :param collect_period: statistics periods
        """
        if len(collect_period) <= 0:
            return {}
        else:
            # Sort the periods by start time in descending order (latest period first)
            collect_period.sort(key=lambda x: x[1], reverse=True)

        stats = {
            period_key: {
                # Online time
                ONLINE_TIME: 0.0,
            }
            for period_key, _, _ in collect_period
        }

        # Accumulate online time
        for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}):
            end_timestamp: datetime = record.get("end_timestamp")
            for idx, (_, period_start, _) in enumerate(collect_period):
                if end_timestamp >= period_start:
                    # If the record falls in this period, it necessarily falls in every longer period as well,
                    # so skip the remaining checks and update this period and all longer ones directly
                    for period_key, _period_start, _ in collect_period[idx:]:
                        start_timestamp: datetime = record.get("start_timestamp")
                        if start_timestamp < _period_start:
                            # If the record starts before the period boundary, count from the boundary
                            stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds() / 60
                        else:
                            # Otherwise count from the record's start time
                            stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds() / 60
                    break  # skip the checks for the longer periods

        return stats

    def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime, str]]) -> Dict[str, Any]:
        """
        Collect message statistics for the given periods
        :param collect_period: statistics periods
        """
        if len(collect_period) <= 0:
            return {}
        else:
            # Sort the periods by start time in descending order (latest period first)
            collect_period.sort(key=lambda x: x[1], reverse=True)

        stats = {
            period_key: {
                # Message statistics
                TOTAL_MSG_CNT: 0,
                MSG_CNT_BY_CHAT: defaultdict(int),
            }
            for period_key, _, _ in collect_period
        }

        # Count messages
        for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}):
            chat_info = message.get("chat_info", None)  # chat info
            user_info = message.get("user_info", None)  # user info (message sender)
            message_time = message.get("time", 0)  # message time

            group_info = chat_info.get("group_info") if chat_info else None  # try to get the group info
            if group_info is not None:
                # Group chat
                chat_id = f"g{group_info.get('group_id')}"
                chat_name = group_info.get("group_name", f"{group_info.get('group_id')}")
            elif user_info:
                # No group info, fall back to the user info
                chat_id = f"u{user_info['user_id']}"
                chat_name = user_info["user_nickname"]
            else:
                continue  # Skip messages that carry neither group nor user info

            if chat_id in self.name_mapping:
                if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]:
                    # The name changed and this message is newer than the recorded one: update the name
                    self.name_mapping[chat_id] = (chat_name, message_time)
            else:
                self.name_mapping[chat_id] = (chat_name, message_time)

            for idx, (_, period_start, _) in enumerate(collect_period):
                if message_time >= period_start.timestamp():
                    # If the record falls in this period, it necessarily falls in every longer period as well,
                    # so skip the remaining checks and update this period and all longer ones directly
                    for period_key, _, _ in collect_period[idx:]:
                        stats[period_key][TOTAL_MSG_CNT] += 1
                        stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
                    break

        return stats

    def _collect_all_statistics(self) -> Dict[str, Dict[str, Any]]:
        """Collect the statistics for every period"""
        now = datetime.now()

        last_all_time_stat = None

        stat = {period[0]: {} for period in self.stat_period}

        if "last_full_statistics_timestamp" in local_storage and "last_full_statistics" in local_storage:
            # If the timestamp of the last full statistics run is stored, use it as the start of the
            # "all_time" period so that only the increment has to be collected
            last_full_stat_ts: float = local_storage["last_full_statistics_timestamp"]
            last_all_time_stat = local_storage["last_full_statistics"]
            self.stat_period = [item for item in self.stat_period if item[0] != "all_time"]  # drop the "all_time" period
            self.stat_period.append(("all_time", datetime.fromtimestamp(last_full_stat_ts), "自部署以来的"))

        model_req_stat = self._collect_model_request_for_period(self.stat_period)
        online_time_stat = self._collect_online_time_for_period(self.stat_period)
        message_count_stat = self._collect_message_count_for_period(self.stat_period)

        # Merge the three kinds of statistics
        for period_key, _, _ in self.stat_period:
            stat[period_key].update(model_req_stat[period_key])
            stat[period_key].update(online_time_stat[period_key])
            stat[period_key].update(message_count_stat[period_key])

        if last_all_time_stat:
            # If a previous full statistics snapshot exists, merge it into the current one
            for key, val in last_all_time_stat.items():
                if isinstance(val, dict):
                    # Dict-valued entries are merged key by key
                    for sub_key, sub_val in val.items():
                        stat["all_time"][key][sub_key] += sub_val
                else:
                    # Scalar entries are added directly
                    stat["all_time"][key] += val

        # Update the timestamp of the last full statistics run
        local_storage["last_full_statistics_timestamp"] = now.timestamp()
        # Update the last full statistics snapshot
        local_storage["last_full_statistics"] = stat["all_time"]

        return stat

    # -- Statistics formatting methods --

    @staticmethod
    def _format_total_stat(stats: Dict[str, Any]) -> str:
        """Format the overall statistics"""
        output = [
            f"总在线时间: {stats[ONLINE_TIME]:.1f}分钟",
            f"总消息数: {stats[TOTAL_MSG_CNT]}",
            f"总请求数: {stats[TOTAL_REQ_CNT]}",
            f"总花费: {stats[TOTAL_COST]:.4f}¥",
            "",
        ]
        return "\n".join(output)

    @staticmethod
    def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
        """Format the statistics grouped by model"""
        if stats[TOTAL_REQ_CNT] > 0:
            data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.4f}¥"
            output = [
                "按模型分类统计:",
                " 模型名称 调用次数 输入Token 输出Token Token总量 累计花费",
            ]
            for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()):
                name = model_name[:29] + "..." if len(model_name) > 32 else model_name
                in_tokens = stats[IN_TOK_BY_MODEL][model_name]
                out_tokens = stats[OUT_TOK_BY_MODEL][model_name]
                tokens = stats[TOTAL_TOK_BY_MODEL][model_name]
                cost = stats[COST_BY_MODEL][model_name]
                output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost))
            output.append("")
            return "\n".join(output)
        else:
            return ""

    @staticmethod
    def _format_req_type_classified_stat(stats: Dict[str, Any]) -> str:
        """Format the statistics grouped by request type"""
        if stats[TOTAL_REQ_CNT] > 0:
            data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.4f}¥"
            output = [
                "按请求类型分类统计:",
                " 请求类型 调用次数 输入Token 输出Token Token总量 累计花费",
            ]
            for req_type, count in sorted(stats[REQ_CNT_BY_TYPE].items()):
                name = req_type[:29] + "..." if len(req_type) > 32 else req_type
                in_tokens = stats[IN_TOK_BY_TYPE][req_type]
                out_tokens = stats[OUT_TOK_BY_TYPE][req_type]
                tokens = stats[TOTAL_TOK_BY_TYPE][req_type]
                cost = stats[COST_BY_TYPE][req_type]
                output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost))
            output.append("")
            return "\n".join(output)
        else:
            return ""

    @staticmethod
    def _format_user_classified_stat(stats: Dict[str, Any]) -> str:
        """Format the statistics grouped by user"""
        if stats[TOTAL_REQ_CNT] > 0:
            data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.4f}¥"
            output = [
                "按用户分类统计:",
                " 用户名称 调用次数 输入Token 输出Token Token总量 累计花费",
            ]
            for user_id, count in sorted(stats[REQ_CNT_BY_USER].items()):
                in_tokens = stats[IN_TOK_BY_USER][user_id]
                out_tokens = stats[OUT_TOK_BY_USER][user_id]
                tokens = stats[TOTAL_TOK_BY_USER][user_id]
                cost = stats[COST_BY_USER][user_id]
                output.append(
                    data_fmt.format(
                        user_id[:22],  # keep the raw ID, no ellipsis added
                        count,
                        in_tokens,
                        out_tokens,
                        tokens,
                        cost,
                    )
                )
            output.append("")
            return "\n".join(output)
        else:
            return ""

    def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
        """Format the per-chat message statistics"""
        if stats[TOTAL_MSG_CNT] > 0:
            output = ["聊天消息统计:", " 联系人/群组名称 消息数量"]
            for chat_id, count in sorted(stats[MSG_CNT_BY_CHAT].items()):
                output.append(f"{self.name_mapping[chat_id][0][:32]:<32} {count:>10}")
            output.append("")
            return "\n".join(output)
        else:
            return ""
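
The three collectors above share one bucketing idea: periods are sorted newest-first, each record is matched against the newest period it falls into, and it is then credited to that period and every longer one. A self-contained toy version, with made-up periods and record times rather than code from this commit, is sketched below.

from datetime import datetime, timedelta

# Periods sorted by start time, newest first, as the collectors expect.
periods = [
    ("last_hour", datetime.now() - timedelta(hours=1)),
    ("last_24_hours", datetime.now() - timedelta(days=1)),
    ("all_time", datetime(2000, 1, 1)),
]

counts = {key: 0 for key, _ in periods}
records = [datetime.now() - timedelta(minutes=30), datetime.now() - timedelta(hours=5)]

for record_time in records:
    for idx, (_, period_start) in enumerate(periods):
        if record_time >= period_start:
            for key, _ in periods[idx:]:  # this period and every longer one
                counts[key] += 1
            break                         # no need to test the longer periods individually

print(counts)  # {'last_hour': 1, 'last_24_hours': 2, 'all_time': 2}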