diff --git a/src/api/message_router.py b/src/api/message_router.py index 513d3d2df..a8551ba04 100644 --- a/src/api/message_router.py +++ b/src/api/message_router.py @@ -58,115 +58,108 @@ async def get_message_stats( @router.get("/messages/stats_by_chat") async def get_message_stats_by_chat( days: int = Query(1, ge=1, description="指定查询过去多少天的数据"), - group_by_user: bool = Query(False, description="是否按用户进行分组统计"), + source: Literal["user", "bot"] = Query("user", description="筛选消息来源: 'user' (用户发送的), 'bot' (BOT发送的)"), + group_by_user: bool = Query(False, description="是否按用户进行分组统计 (仅当 source='user' 时有效)"), format: bool = Query(False, description="是否格式化输出,包含群聊和用户信息"), ): """ - 获取BOT在指定天数内按聊天流或按用户统计的消息数据。 + 获取在指定天数内,按聊天会话统计的消息数据。 + 可根据消息来源 (用户或BOT) 进行筛选。 """ try: + # --- 1. 数据准备 --- + # 计算查询的时间范围 end_time = time.time() start_time = end_time - (days * 24 * 3600) + # 从数据库获取指定时间范围内的所有消息 messages = await message_api.get_messages_by_time(start_time, end_time) bot_qq = str(global_config.bot.qq_account) - messages = [msg for msg in messages if msg.get("user_id") != bot_qq] + # --- 2. 消息筛选 --- + # 根据 source 参数筛选消息来源 + if source == "user": + # 筛选出用户发送的消息(即非机器人发送的消息) + messages = [msg for msg in messages if msg.get("user_id") != bot_qq] + else: # source == "bot" + # 筛选出机器人发送的消息 + messages = [msg for msg in messages if msg.get("user_id") == bot_qq] + # --- 3. 数据统计 --- stats = {} + # 如果统计来源是用户 + if source == "user": + # 遍历用户消息进行统计 + for msg in messages: + chat_id = msg.get("chat_id", "unknown") + user_id = msg.get("user_id") + # 初始化聊天会话的统计结构 + if chat_id not in stats: + stats[chat_id] = {"total_stats": {"total": 0}, "user_stats": {}} + # 累加总消息数 + stats[chat_id]["total_stats"]["total"] += 1 + # 如果需要按用户分组,则进一步统计每个用户的消息数 + if group_by_user: + if user_id not in stats[chat_id]["user_stats"]: + stats[chat_id]["user_stats"][user_id] = 0 + stats[chat_id]["user_stats"][user_id] += 1 + # 如果不按用户分组,则简化统计结果,只保留总数 + if not group_by_user: + stats = {chat_id: data["total_stats"] for chat_id, data in stats.items()} + # 如果统计来源是机器人 + else: + # 遍历机器人消息进行统计 + for msg in messages: + chat_id = msg.get("chat_id", "unknown") + # 初始化聊天会话的统计结构 + if chat_id not in stats: + stats[chat_id] = 0 + # 累加机器人发送的消息数 + stats[chat_id] += 1 - for msg in messages: - chat_id = msg.get("chat_id", "unknown") - user_id = msg.get("user_id") + # --- 4. 格式化输出 --- + # 如果 format 参数为 False,直接返回原始统计数据 + if not format: + return stats - if chat_id not in stats: - stats[chat_id] = {"total_stats": {"total": 0}, "user_stats": {}} + # 获取聊天管理器以查询会话信息 + chat_manager = get_chat_manager() + formatted_stats = {} + # 遍历统计结果进行格式化 + for chat_id, data in stats.items(): + stream = chat_manager.streams.get(chat_id) + chat_name = f"未知会话 ({chat_id})" + # 尝试获取更友好的会话名称(群名或用户名) + if stream: + if stream.group_info and stream.group_info.group_name: + chat_name = stream.group_info.group_name + elif stream.user_info and stream.user_info.user_nickname: + chat_name = stream.user_info.user_nickname - stats[chat_id]["total_stats"]["total"] += 1 + # 如果是机器人消息统计,直接格式化 + if source == "bot": + formatted_stats[chat_id] = {"chat_name": chat_name, "count": data} + continue - if group_by_user: - if user_id not in stats[chat_id]["user_stats"]: - stats[chat_id]["user_stats"][user_id] = 0 + # 如果是用户消息统计,进行更复杂的格式化 + formatted_data = { + "chat_name": chat_name, + "total_stats": data if not group_by_user else data["total_stats"], + } + # 如果按用户分组,则添加用户信息 + if group_by_user and "user_stats" in data: + formatted_data["user_stats"] = {} + for user_id, count in data["user_stats"].items(): + person_id = person_api.get_person_id("qq", user_id) + person_info = await person_api.get_person_info(person_id) + nickname = person_info.get("nickname", "未知用户") + formatted_data["user_stats"][user_id] = {"nickname": nickname, "count": count} + formatted_stats[chat_id] = formatted_data - stats[chat_id]["user_stats"][user_id] += 1 - - if not group_by_user: - stats = {chat_id: data["total_stats"] for chat_id, data in stats.items()} - - if format: - chat_manager = get_chat_manager() - formatted_stats = {} - for chat_id, data in stats.items(): - stream = chat_manager.streams.get(chat_id) - chat_name = "未知会话" - if stream: - if stream.group_info and stream.group_info.group_name: - chat_name = stream.group_info.group_name - elif stream.user_info and stream.user_info.user_nickname: - chat_name = stream.user_info.user_nickname - else: - chat_name = f"未知会话 ({chat_id})" - - formatted_data = { - "chat_name": chat_name, - "total_stats": data if not group_by_user else data["total_stats"], - } - - if group_by_user and "user_stats" in data: - formatted_data["user_stats"] = {} - for user_id, count in data["user_stats"].items(): - person_id = person_api.get_person_id("qq", user_id) - nickname = await person_api.get_person_value(person_id, "nickname", "未知用户") - formatted_data["user_stats"][user_id] = {"nickname": nickname, "count": count} - - formatted_stats[chat_id] = formatted_data - return formatted_stats - - return stats + return formatted_stats except Exception as e: + # 统一异常处理 + logger.error(f"获取消息统计时发生错误: {e}") raise HTTPException(status_code=500, detail=str(e)) -@router.get("/messages/bot_stats_by_chat") -async def get_bot_message_stats_by_chat( - days: int = Query(1, ge=1, description="指定查询过去多少天的数据"), - format: bool = Query(False, description="是否格式化输出,包含群聊和用户信息"), -): - """ - 获取BOT在指定天数内按聊天流统计的已发送消息数据。 - """ - try: - end_time = time.time() - start_time = end_time - (days * 24 * 3600) - messages = await message_api.get_messages_by_time(start_time, end_time) - bot_qq = str(global_config.bot.qq_account) - - # 筛选出机器人发送的消息 - bot_messages = [msg for msg in messages if msg.get("user_id") == bot_qq] - - stats = {} - for msg in bot_messages: - chat_id = msg.get("chat_id", "unknown") - if chat_id not in stats: - stats[chat_id] = 0 - stats[chat_id] += 1 - - if format: - chat_manager = get_chat_manager() - formatted_stats = {} - for chat_id, count in stats.items(): - stream = chat_manager.streams.get(chat_id) - chat_name = f"未知会话 ({chat_id})" - if stream: - if stream.group_info and stream.group_info.group_name: - chat_name = stream.group_info.group_name - elif stream.user_info and stream.user_info.user_nickname: - chat_name = stream.user_info.user_nickname - - formatted_stats[chat_id] = {"chat_name": chat_name, "count": count} - return formatted_stats - - return stats - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e))