perf(cross_context): 优化S4U上下文生成性能与逻辑

通过引入批量数据库查询,显著提升了跨群聊上下文(S4U)功能的性能和效率。旧的实现方式会对每个群聊进行一次数据库查询,导致在群聊数量多时性能低下。

主要变更:
- 在 `message_repository` 中新增 `get_user_messages_from_streams` 函数,使用 CTE 和 `row_number()` 在单个请求中高效检索所有目标聊天流中的用户消息。
- 重构 `build_cross_context_s4u` 以使用新的批量查询方法,大幅减少了数据库I/O和应用层循环。
- 增强了私聊场景下的逻辑,会同时获取机器人的消息以提供更完整的对话历史。
- 改进了私聊上下文的标题,使其对用户更加友好。
- 为S4U流程添加了更详细的日志,便于问题排查。
This commit is contained in:
tt-P607
2025-10-20 09:25:27 +08:00
parent 59534f5dfc
commit 30f2f1345c
2 changed files with 109 additions and 17 deletions

View File

@@ -1,5 +1,6 @@
import traceback
from typing import Any
from collections import defaultdict
from sqlalchemy import func, not_, select
from sqlalchemy.orm import DeclarativeBase
@@ -212,3 +213,69 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 SQLAlchemy插入操作通常是使用 await session.add() 和 await session.commit()。
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
async def get_user_messages_from_streams(
user_ids: list[str],
stream_ids: list[str],
timestamp_after: float,
limit_per_stream: int,
) -> dict[str, list[dict[str, Any]]]:
"""
一次性从多个聊天流中获取特定用户的近期消息。
Args:
user_ids: 目标用户的ID列表。
stream_ids: 要查询的聊天流ID列表。
timestamp_after: 只获取此时间戳之后的消息。
limit_per_stream: 每个聊天流中获取该用户的消息数量上限。
Returns:
一个字典,键为 stream_id值为该聊天流中的消息列表。
"""
if not stream_ids or not user_ids:
return {}
try:
async with get_db_session() as session:
# 使用 CTE 和 row_number() 来为每个聊天流中的用户消息进行排序和编号
ranked_messages_cte = (
select(
Messages,
func.row_number().over(partition_by=Messages.chat_id, order_by=Messages.time.desc()).label("row_num"),
)
.where(
Messages.user_id.in_(user_ids),
Messages.chat_id.in_(stream_ids),
Messages.time > timestamp_after,
)
.cte("ranked_messages")
)
# 从 CTE 中选择每个聊天流最新的 `limit_per_stream` 条消息
query = select(ranked_messages_cte).where(ranked_messages_cte.c.row_num <= limit_per_stream)
result = await session.execute(query)
messages = result.all()
# 按 stream_id 分组
messages_by_stream = defaultdict(list)
for row in messages:
# Since the row is a Row object from a CTE, we need to manually construct the model instance
msg_instance = Messages(**{c.name: getattr(row, c.name) for c in Messages.__table__.columns})
msg_dict = _model_to_dict(msg_instance)
messages_by_stream[msg_dict["chat_id"]].append(msg_dict)
# 对每个流内的消息按时间升序排序
for stream_id in messages_by_stream:
messages_by_stream[stream_id].sort(key=lambda m: m["time"])
return dict(messages_by_stream)
except Exception as e:
log_message = (
f"使用 SQLAlchemy 批量查找用户消息失败 (user_ids={user_ids}, streams={len(stream_ids)}): {e}\n"
+ traceback.format_exc()
)
logger.error(log_message)
return {}