perf(cross_context): 优化S4U上下文生成性能与逻辑
通过引入批量数据库查询,显著提升了跨群聊上下文(S4U)功能的性能和效率。旧的实现方式会对每个群聊进行一次数据库查询,导致在群聊数量多时性能低下。 主要变更: - 在 `message_repository` 中新增 `get_user_messages_from_streams` 函数,使用 CTE 和 `row_number()` 在单个请求中高效检索所有目标聊天流中的用户消息。 - 重构 `build_cross_context_s4u` 以使用新的批量查询方法,大幅减少了数据库I/O和应用层循环。 - 增强了私聊场景下的逻辑,会同时获取机器人的消息以提供更完整的对话历史。 - 改进了私聊上下文的标题,使其对用户更加友好。 - 为S4U流程添加了更详细的日志,便于问题排查。
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import traceback
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlalchemy import func, not_, select
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
@@ -212,3 +213,69 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 await session.add() 和 await session.commit()。
|
||||
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
|
||||
|
||||
|
||||
async def get_user_messages_from_streams(
|
||||
user_ids: list[str],
|
||||
stream_ids: list[str],
|
||||
timestamp_after: float,
|
||||
limit_per_stream: int,
|
||||
) -> dict[str, list[dict[str, Any]]]:
|
||||
"""
|
||||
一次性从多个聊天流中获取特定用户的近期消息。
|
||||
|
||||
Args:
|
||||
user_ids: 目标用户的ID列表。
|
||||
stream_ids: 要查询的聊天流ID列表。
|
||||
timestamp_after: 只获取此时间戳之后的消息。
|
||||
limit_per_stream: 每个聊天流中获取该用户的消息数量上限。
|
||||
|
||||
Returns:
|
||||
一个字典,键为 stream_id,值为该聊天流中的消息列表。
|
||||
"""
|
||||
if not stream_ids or not user_ids:
|
||||
return {}
|
||||
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
# 使用 CTE 和 row_number() 来为每个聊天流中的用户消息进行排序和编号
|
||||
ranked_messages_cte = (
|
||||
select(
|
||||
Messages,
|
||||
func.row_number().over(partition_by=Messages.chat_id, order_by=Messages.time.desc()).label("row_num"),
|
||||
)
|
||||
.where(
|
||||
Messages.user_id.in_(user_ids),
|
||||
Messages.chat_id.in_(stream_ids),
|
||||
Messages.time > timestamp_after,
|
||||
)
|
||||
.cte("ranked_messages")
|
||||
)
|
||||
|
||||
# 从 CTE 中选择每个聊天流最新的 `limit_per_stream` 条消息
|
||||
query = select(ranked_messages_cte).where(ranked_messages_cte.c.row_num <= limit_per_stream)
|
||||
|
||||
result = await session.execute(query)
|
||||
messages = result.all()
|
||||
|
||||
# 按 stream_id 分组
|
||||
messages_by_stream = defaultdict(list)
|
||||
for row in messages:
|
||||
# Since the row is a Row object from a CTE, we need to manually construct the model instance
|
||||
msg_instance = Messages(**{c.name: getattr(row, c.name) for c in Messages.__table__.columns})
|
||||
msg_dict = _model_to_dict(msg_instance)
|
||||
messages_by_stream[msg_dict["chat_id"]].append(msg_dict)
|
||||
|
||||
# 对每个流内的消息按时间升序排序
|
||||
for stream_id in messages_by_stream:
|
||||
messages_by_stream[stream_id].sort(key=lambda m: m["time"])
|
||||
|
||||
return dict(messages_by_stream)
|
||||
|
||||
except Exception as e:
|
||||
log_message = (
|
||||
f"使用 SQLAlchemy 批量查找用户消息失败 (user_ids={user_ids}, streams={len(stream_ids)}): {e}\n"
|
||||
+ traceback.format_exc()
|
||||
)
|
||||
logger.error(log_message)
|
||||
return {}
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import get_user_messages_from_streams
|
||||
from src.config.config import global_config
|
||||
from src.config.official_configs import ContextGroup
|
||||
|
||||
@@ -127,9 +128,12 @@ async def build_cross_context_s4u(
|
||||
|
||||
基于全局S4U配置,检索目标用户在其他聊天中的发言。
|
||||
"""
|
||||
logger.info("[S4U] Starting S4U context build.")
|
||||
cross_context_config = global_config.cross_context
|
||||
if not target_user_info or not (user_id := target_user_info.get("user_id")):
|
||||
logger.warning(f"[S4U] Failed: target_user_info ({target_user_info}) or user_id is missing.")
|
||||
return ""
|
||||
logger.info(f"[S4U] Target user ID: {user_id}")
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
all_user_messages = []
|
||||
@@ -161,25 +165,37 @@ async def build_cross_context_s4u(
|
||||
for stream_id in chat_manager.streams:
|
||||
if stream_id != chat_stream.stream_id and stream_id not in blacklisted_streams:
|
||||
streams_to_scan.append(stream_id)
|
||||
|
||||
logger.info(f"[S4U] Scan mode: {cross_context_config.s4u_mode}.")
|
||||
logger.info(f"[S4U] Whitelist: {cross_context_config.s4u_whitelist_chats}, Blacklist: {cross_context_config.s4u_blacklist_chats}.")
|
||||
logger.info(f"[S4U] Found {len(streams_to_scan)} streams to scan: {streams_to_scan}")
|
||||
|
||||
# 2. 从筛选出的聊天流中获取目标用户的消息
|
||||
# 2. 从筛选出的聊天流中获取目标用户的消息 (优化后)
|
||||
limit = cross_context_config.s4u_limit
|
||||
for stream_id in streams_to_scan:
|
||||
try:
|
||||
# 获取稍多一些消息以确保能筛选出用户消息
|
||||
messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream_id, timestamp=time.time(), limit=limit * 5
|
||||
)
|
||||
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-limit:]
|
||||
# 约 3 天内的消息
|
||||
timestamp_after = time.time() - (3 * 24 * 60 * 60)
|
||||
|
||||
# 如果是私聊,则同时获取bot自身的消息
|
||||
user_ids_to_fetch = [str(user_id)]
|
||||
if chat_stream.group_info is None:
|
||||
user_ids_to_fetch.append(str(global_config.bot.qq_account))
|
||||
|
||||
if user_messages:
|
||||
# 记录消息来源的 stream_id 和最新消息的时间戳
|
||||
latest_timestamp = max(msg.get("time", 0) for msg in user_messages)
|
||||
all_user_messages.append(
|
||||
{"stream_id": stream_id, "messages": user_messages, "latest_timestamp": latest_timestamp}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"S4U模式下获取聊天 {stream_id} 的消息失败: {e}")
|
||||
# 一次性批量查询
|
||||
messages_by_stream = await get_user_messages_from_streams(
|
||||
user_ids=user_ids_to_fetch,
|
||||
stream_ids=streams_to_scan,
|
||||
timestamp_after=timestamp_after,
|
||||
limit_per_stream=limit,
|
||||
)
|
||||
|
||||
for stream_id, user_messages in messages_by_stream.items():
|
||||
if user_messages:
|
||||
logger.info(f"[S4U] Found {len(user_messages)} messages for user {user_id} in stream {stream_id}.")
|
||||
latest_timestamp = max(msg.get("time", 0) for msg in user_messages)
|
||||
all_user_messages.append(
|
||||
{"stream_id": stream_id, "messages": user_messages, "latest_timestamp": latest_timestamp}
|
||||
)
|
||||
|
||||
# 3. 按最新消息时间排序,并根据 stream_limit 截断
|
||||
all_user_messages.sort(key=lambda x: x["latest_timestamp"], reverse=True)
|
||||
@@ -193,15 +209,24 @@ async def build_cross_context_s4u(
|
||||
try:
|
||||
chat_name = await chat_manager.get_stream_name(stream_id) or "未知聊天"
|
||||
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
|
||||
|
||||
# 如果是私聊,标题不显示用户名
|
||||
stream = await chat_manager.get_stream(stream_id)
|
||||
is_private = stream.group_info is None if stream else False
|
||||
title = f'[以下是您在"{chat_name}"的近期发言]\n' if is_private else f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n'
|
||||
|
||||
formatted, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||
cross_context_blocks.append(f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted}')
|
||||
cross_context_blocks.append(f"{title}{formatted}")
|
||||
except Exception as e:
|
||||
logger.error(f"S4U模式下格式化消息失败 (stream: {stream_id}): {e}")
|
||||
|
||||
if not cross_context_blocks:
|
||||
logger.info("[S4U] No context blocks were generated. Returning empty string.")
|
||||
return ""
|
||||
|
||||
return "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_blocks) + "\n"
|
||||
final_context = "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_blocks) + "\n"
|
||||
logger.info(f"[S4U] Successfully generated {len(cross_context_blocks)} context blocks. Total length: {len(final_context)}.")
|
||||
return final_context
|
||||
|
||||
|
||||
async def get_intercom_group_context(group_name: str, limit_per_chat: int = 20, total_limit: int = 100) -> str | None:
|
||||
|
||||
Reference in New Issue
Block a user