refactor(cross_context): 提取互通组上下文获取逻辑为通用API

将原本在 `maizone` 插件中用于获取互通组聊天上下文的逻辑,提取并重构为一个更通用的 `cross_context_api.get_intercom_group_context_by_name` 函数。

这次重构提高了代码的模块化和复用性,使得其他需要跨聊天上下文功能的插件也能方便地调用此API,而无需重复实现相似的逻辑。`maizone` 插件现在直接调用这个新的API来获取上下文,简化了其内部实现。
This commit is contained in:
minecraft1024a
2025-10-05 21:44:14 +08:00
committed by Windpicker-owo
parent 3c2a90bad4
commit af4e8fe34a
2 changed files with 85 additions and 60 deletions

View File

@@ -9,9 +9,11 @@ from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat,
)
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis import config_api
logger = get_logger("cross_context_api")
@@ -178,3 +180,82 @@ async def get_chat_history_by_group_name(group_name: str) -> str:
return f"无法从互通组 {group_name} 中获取任何聊天记录。"
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
async def get_intercom_group_context_by_name(
group_name: str, days: int = 3, limit_per_chat: int = 20, total_limit: int = 100
) -> str | None:
"""
根据互通组的名称,构建该组的聊天上下文。
Args:
group_name: 互通组的名称。
days: 获取过去多少天的消息。
limit_per_chat: 每个聊天最多获取的消息条数。
total_limit: 返回的总消息条数上限。
Returns:
如果找到匹配的组,则返回一个包含聊天记录的字符串;否则返回 None。
"""
cross_context_config = global_config.cross_context
if not (cross_context_config and cross_context_config.enable):
return None
target_group = None
for group in cross_context_config.groups:
if group.name == group_name:
target_group = group
break
if not target_group:
logger.error(f"在 cross_context 配置中未找到名为 '{group_name}' 的组。")
return None
chat_manager = get_chat_manager()
all_messages = []
end_time = time.time()
start_time = end_time - (days * 24 * 60 * 60)
for chat_type, chat_raw_id in target_group.chat_ids:
is_group = chat_type == "group"
found_stream = None
# 采用与 get_chat_history_by_group_name 相同的健壮的 stream 查找方式
for stream in chat_manager.streams.values():
if is_group:
if stream.group_info and stream.group_info.group_id == chat_raw_id:
found_stream = stream
break
else: # private
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
found_stream = stream
break
if not found_stream:
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
continue
stream_id = found_stream.stream_id
messages = await get_raw_msg_by_timestamp_with_chat(
chat_id=stream_id,
timestamp_start=start_time,
timestamp_end=end_time,
limit=limit_per_chat,
limit_mode="latest",
)
all_messages.extend(messages)
if not all_messages:
return None
# 按时间戳对所有消息进行排序
all_messages.sort(key=lambda x: x.get("time", 0))
# 限制总消息数
if len(all_messages) > total_limit:
all_messages = all_messages[-total_limit:]
# build_readable_messages_with_id 返回一个元组 (formatted_string, message_id_list)
formatted_string, _ = await build_readable_messages_with_id(all_messages)
return formatted_string

View File

@@ -17,13 +17,8 @@ import bs4
import json5
import orjson
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_by_timestamp_with_chat,
)
from src.common.logger import get_logger
from src.plugin_system.apis import config_api, person_api
from src.plugin_system.apis import config_api, cross_context_api, person_api
from .content_service import ContentService
from .cookie_service import CookieService
@@ -192,61 +187,10 @@ class QZoneService:
async def _get_intercom_context(self, stream_id: str) -> str | None:
"""
根据 stream_id 查找其所属的互通组,并构建该组的聊天上下文。
Args:
stream_id: 需要查找的当前聊天流ID。
Returns:
如果找到匹配的组,则返回一个包含聊天记录的字符串;否则返回 None。
获取互通组的聊天上下文。
"""
intercom_config = config_api.get_global_config("maizone_intercom")
if not (intercom_config and intercom_config.enable):
return None
chat_manager = get_chat_manager()
bot_platform = config_api.get_global_config("bot.platform")
for group in intercom_config.groups:
# 使用集合以优化查找效率
group_stream_ids = {chat_manager.get_stream_id(bot_platform, chat_id, True) for chat_id in group.chat_ids}
if stream_id in group_stream_ids:
logger.debug(
f"Stream ID '{stream_id}' 在互通组 '{getattr(group, 'name', 'Unknown')}' 中找到,正在构建上下文。"
)
all_messages = []
end_time = time.time()
start_time = end_time - (3 * 24 * 60 * 60) # 获取过去3天的消息
for chat_id in group.chat_ids:
# 使用正确的函数获取历史消息
messages = await get_raw_msg_by_timestamp_with_chat(
chat_id=chat_id,
timestamp_start=start_time,
timestamp_end=end_time,
limit=20, # 每个聊天最多获取20条
limit_mode="latest",
)
all_messages.extend(messages)
if not all_messages:
return None
# 按时间戳对所有消息进行排序
all_messages.sort(key=lambda x: x.get("time", 0))
# 限制总消息数例如最多100条
if len(all_messages) > 100:
all_messages = all_messages[-100:]
# build_readable_messages_with_id 返回一个元组 (formatted_string, message_id_list)
formatted_string, _ = await build_readable_messages_with_id(all_messages)
return formatted_string
logger.debug(f"Stream ID '{stream_id}' 未在任何互通组中找到。")
return None
# 实际的逻辑已迁移到 cross_context_api
return await cross_context_api.get_intercom_group_context_by_name("maizone_context_group")
async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict):
"""处理对自己说说的评论并进行回复"""