refactor(chat): 抽象化跨群聊上下文构建逻辑
将 `build_cross_context` 方法的实现委托给 `cross_context_api`。 这简化了 `prompt_utils` 中的代码,将复杂的上下文构建逻辑(包括获取其他群聊、根据模式获取和格式化消息)封装到专用的API中,提高了代码的模块化和可维护性。
This commit is contained in:
committed by
Windpicker-owo
parent
db9bc2f701
commit
b8d31207cb
@@ -9,13 +9,9 @@ from typing import Dict, Any, Optional, Tuple
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.utils.chat_message_builder import (
|
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
|
||||||
build_readable_messages_with_id,
|
|
||||||
)
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.person_info.person_info import get_person_info_manager
|
from src.person_info.person_info import get_person_info_manager
|
||||||
|
from src.plugin_system.apis import cross_context_api
|
||||||
logger = get_logger("prompt_utils")
|
logger = get_logger("prompt_utils")
|
||||||
|
|
||||||
|
|
||||||
@@ -84,113 +80,29 @@ class PromptUtils:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def build_cross_context(
|
async def build_cross_context(
|
||||||
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
|
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能
|
构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能
|
||||||
|
"""
|
||||||
|
if not global_config.cross_context.enable:
|
||||||
|
return ""
|
||||||
|
|
||||||
Args:
|
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
|
||||||
chat_id: 当前聊天ID
|
if not other_chat_raw_ids:
|
||||||
target_user_info: 目标用户信息
|
return ""
|
||||||
current_prompt_mode: 当前提示模式
|
|
||||||
|
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||||
|
if not chat_stream:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if current_prompt_mode == "normal":
|
||||||
|
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
|
||||||
|
elif current_prompt_mode == "s4u":
|
||||||
|
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 跨群上下文块
|
|
||||||
"""
|
|
||||||
if not global_config.cross_context.enable:
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# 找到当前群聊所在的共享组
|
|
||||||
target_group = None
|
|
||||||
current_stream = get_chat_manager().get_stream(chat_id)
|
|
||||||
if not current_stream or not current_stream.group_info:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
try:
|
|
||||||
current_chat_raw_id = current_stream.group_info.group_id
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取群聊ID失败: {e}")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
for group in global_config.cross_context.groups:
|
|
||||||
if str(current_chat_raw_id) in group.chat_ids:
|
|
||||||
target_group = group
|
|
||||||
break
|
|
||||||
|
|
||||||
if not target_group:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# 根据prompt_mode选择策略
|
|
||||||
other_chat_raw_ids = [chat_id for chat_id in target_group.chat_ids if chat_id != str(current_chat_raw_id)]
|
|
||||||
|
|
||||||
cross_context_messages = []
|
|
||||||
|
|
||||||
if current_prompt_mode == "normal":
|
|
||||||
# normal模式:获取其他群聊的最近N条消息
|
|
||||||
for chat_raw_id in other_chat_raw_ids:
|
|
||||||
stream_id = get_chat_manager().get_stream_id(current_stream.platform, chat_raw_id, is_group=True)
|
|
||||||
if not stream_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
messages = get_raw_msg_before_timestamp_with_chat(
|
|
||||||
chat_id=stream_id,
|
|
||||||
timestamp=time.time(),
|
|
||||||
limit=5, # 可配置
|
|
||||||
)
|
|
||||||
if messages:
|
|
||||||
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
|
|
||||||
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
|
|
||||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取群聊{chat_raw_id}的消息失败: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif current_prompt_mode == "s4u":
|
|
||||||
# s4u模式:获取当前发言用户在其他群聊的消息
|
|
||||||
if target_user_info:
|
|
||||||
user_id = target_user_info.get("user_id")
|
|
||||||
|
|
||||||
if user_id:
|
|
||||||
for chat_raw_id in other_chat_raw_ids:
|
|
||||||
stream_id = get_chat_manager().get_stream_id(
|
|
||||||
current_stream.platform, chat_raw_id, is_group=True
|
|
||||||
)
|
|
||||||
if not stream_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
messages = get_raw_msg_before_timestamp_with_chat(
|
|
||||||
chat_id=stream_id,
|
|
||||||
timestamp=time.time(),
|
|
||||||
limit=20, # 获取更多消息以供筛选
|
|
||||||
)
|
|
||||||
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][
|
|
||||||
-5:
|
|
||||||
] # 筛选并取最近5条
|
|
||||||
|
|
||||||
if user_messages:
|
|
||||||
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
|
|
||||||
user_name = (
|
|
||||||
target_user_info.get("person_name")
|
|
||||||
or target_user_info.get("user_nickname")
|
|
||||||
or user_id
|
|
||||||
)
|
|
||||||
formatted_messages, _ = build_readable_messages_with_id(
|
|
||||||
user_messages, timestamp_mode="relative"
|
|
||||||
)
|
|
||||||
cross_context_messages.append(
|
|
||||||
f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}'
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取用户{user_id}在群聊{chat_raw_id}的消息失败: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not cross_context_messages:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
return "# 跨群上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_reply_target_id(reply_to: str) -> str:
|
def parse_reply_target_id(reply_to: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
118
src/plugin_system/apis/cross_context_api.py
Normal file
118
src/plugin_system/apis/cross_context_api.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""
|
||||||
|
跨群聊上下文API
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.chat.utils.chat_message_builder import (
|
||||||
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
|
build_readable_messages_with_id,
|
||||||
|
)
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||||
|
|
||||||
|
logger = get_logger("cross_context_api")
|
||||||
|
|
||||||
|
|
||||||
|
def get_context_groups(chat_id: str) -> Optional[List[str]]:
|
||||||
|
"""
|
||||||
|
获取当前群聊所在的共享组的其他群聊ID
|
||||||
|
"""
|
||||||
|
current_stream = get_chat_manager().get_stream(chat_id)
|
||||||
|
if not current_stream or not current_stream.group_info:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
current_chat_raw_id = current_stream.group_info.group_id
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取群聊ID失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
for group in global_config.cross_context.groups:
|
||||||
|
if str(current_chat_raw_id) in group.chat_ids:
|
||||||
|
return [chat_id for chat_id in group.chat_ids if chat_id != str(current_chat_raw_id)]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def build_cross_context_normal(chat_stream: ChatStream, other_chat_raw_ids: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
构建跨群聊上下文 (Normal模式)
|
||||||
|
"""
|
||||||
|
cross_context_messages = []
|
||||||
|
for chat_raw_id in other_chat_raw_ids:
|
||||||
|
stream_id = get_chat_manager().get_stream_id(chat_stream.platform, chat_raw_id, is_group=True)
|
||||||
|
if not stream_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
messages = get_raw_msg_before_timestamp_with_chat(
|
||||||
|
chat_id=stream_id,
|
||||||
|
timestamp=time.time(),
|
||||||
|
limit=5, # 可配置
|
||||||
|
)
|
||||||
|
if messages:
|
||||||
|
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
|
||||||
|
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||||
|
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取群聊{chat_raw_id}的消息失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not cross_context_messages:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return "# 跨群上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
async def build_cross_context_s4u(
|
||||||
|
chat_stream: ChatStream, other_chat_raw_ids: List[str], target_user_info: Optional[Dict[str, Any]]
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
构建跨群聊上下文 (S4U模式)
|
||||||
|
"""
|
||||||
|
cross_context_messages = []
|
||||||
|
if target_user_info:
|
||||||
|
user_id = target_user_info.get("user_id")
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
for chat_raw_id in other_chat_raw_ids:
|
||||||
|
stream_id = get_chat_manager().get_stream_id(
|
||||||
|
chat_stream.platform, chat_raw_id, is_group=True
|
||||||
|
)
|
||||||
|
if not stream_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
messages = get_raw_msg_before_timestamp_with_chat(
|
||||||
|
chat_id=stream_id,
|
||||||
|
timestamp=time.time(),
|
||||||
|
limit=20, # 获取更多消息以供筛选
|
||||||
|
)
|
||||||
|
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][
|
||||||
|
-5:
|
||||||
|
] # 筛选并取最近5条
|
||||||
|
|
||||||
|
if user_messages:
|
||||||
|
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
|
||||||
|
user_name = (
|
||||||
|
target_user_info.get("person_name")
|
||||||
|
or target_user_info.get("user_nickname")
|
||||||
|
or user_id
|
||||||
|
)
|
||||||
|
formatted_messages, _ = build_readable_messages_with_id(
|
||||||
|
user_messages, timestamp_mode="relative"
|
||||||
|
)
|
||||||
|
cross_context_messages.append(
|
||||||
|
f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}'
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取用户{user_id}在群聊{chat_raw_id}的消息失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not cross_context_messages:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return "# 跨群上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||||
Reference in New Issue
Block a user