Merge remote-tracking branch 'origin/master' into dev

This commit is contained in:
minecraft1024a
2025-09-05 19:34:47 +08:00
6 changed files with 221 additions and 165 deletions

View File

@@ -64,7 +64,7 @@ class ResponseHandler:
- 构建并返回完整的循环信息 - 构建并返回完整的循环信息
- 用于上级方法的状态跟踪 - 用于上级方法的状态跟踪
""" """
reply_text = await self.send_response(response_set, reply_to_str, loop_start_time, action_message) reply_text = await self.send_response(response_set, loop_start_time, action_message)
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
@@ -166,8 +166,8 @@ class ResponseHandler:
await send_api.text_to_stream( await send_api.text_to_stream(
text=data, text=data,
stream_id=self.context.stream_id, stream_id=self.context.stream_id,
reply_to_message = message_data, reply_to_message=None,
set_reply=need_reply, set_reply=False,
typing=True, typing=True,
) )
@@ -209,7 +209,7 @@ class ResponseHandler:
) )
# 根据反注入结果处理消息数据 # 根据反注入结果处理消息数据
await anti_injector.handle_message_storage(result, modified_content, reason, message_data) await anti_injector.handle_message_storage(result, modified_content, reason or "", message_data)
if result == ProcessResult.BLOCKED_BAN: if result == ProcessResult.BLOCKED_BAN:
# 用户被封禁 - 直接阻止回复生成 # 用户被封禁 - 直接阻止回复生成

View File

@@ -9,13 +9,9 @@ from typing import Dict, Any, Optional, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
build_readable_messages_with_id,
)
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager from src.person_info.person_info import get_person_info_manager
from src.plugin_system.apis import cross_context_api
logger = get_logger("prompt_utils") logger = get_logger("prompt_utils")
@@ -84,113 +80,29 @@ class PromptUtils:
@staticmethod @staticmethod
async def build_cross_context( async def build_cross_context(
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
) -> str: ) -> str:
""" """
构建跨群聊上下文 - 统一实现完全继承DefaultReplyer功能 构建跨群聊上下文 - 统一实现完全继承DefaultReplyer功能
"""
if not global_config.cross_context.enable:
return ""
Args: other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
chat_id: 当前聊天ID if not other_chat_raw_ids:
target_user_info: 目标用户信息 return ""
current_prompt_mode: 当前提示模式
chat_stream = get_chat_manager().get_stream(chat_id)
if not chat_stream:
return ""
if current_prompt_mode == "normal":
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
elif current_prompt_mode == "s4u":
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
Returns:
str: 跨群上下文块
"""
if not global_config.cross_context.enable:
return "" return ""
# 找到当前群聊所在的共享组
target_group = None
current_stream = get_chat_manager().get_stream(chat_id)
if not current_stream or not current_stream.group_info:
return ""
try:
current_chat_raw_id = current_stream.group_info.group_id
except Exception as e:
logger.error(f"获取群聊ID失败: {e}")
return ""
for group in global_config.cross_context.groups:
if str(current_chat_raw_id) in group.chat_ids:
target_group = group
break
if not target_group:
return ""
# 根据prompt_mode选择策略
other_chat_raw_ids = [chat_id for chat_id in target_group.chat_ids if chat_id != str(current_chat_raw_id)]
cross_context_messages = []
if current_prompt_mode == "normal":
# normal模式获取其他群聊的最近N条消息
for chat_raw_id in other_chat_raw_ids:
stream_id = get_chat_manager().get_stream_id(current_stream.platform, chat_raw_id, is_group=True)
if not stream_id:
continue
try:
messages = get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=5, # 可配置
)
if messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
except Exception as e:
logger.error(f"获取群聊{chat_raw_id}的消息失败: {e}")
continue
elif current_prompt_mode == "s4u":
# s4u模式获取当前发言用户在其他群聊的消息
if target_user_info:
user_id = target_user_info.get("user_id")
if user_id:
for chat_raw_id in other_chat_raw_ids:
stream_id = get_chat_manager().get_stream_id(
current_stream.platform, chat_raw_id, is_group=True
)
if not stream_id:
continue
try:
messages = get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=20, # 获取更多消息以供筛选
)
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][
-5:
] # 筛选并取最近5条
if user_messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
user_name = (
target_user_info.get("person_name")
or target_user_info.get("user_nickname")
or user_id
)
formatted_messages, _ = build_readable_messages_with_id(
user_messages, timestamp_mode="relative"
)
cross_context_messages.append(
f'[以下是"{user_name}""{chat_name}"的近期发言]\n{formatted_messages}'
)
except Exception as e:
logger.error(f"获取用户{user_id}在群聊{chat_raw_id}的消息失败: {e}")
continue
if not cross_context_messages:
return ""
return "# 跨群上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
@staticmethod @staticmethod
def parse_reply_target_id(reply_to: str) -> str: def parse_reply_target_id(reply_to: str) -> str:
""" """

View File

@@ -652,7 +652,9 @@ class ContextGroup(ValidatedConfigBase):
"""上下文共享组配置""" """上下文共享组配置"""
name: str = Field(..., description="共享组的名称") name: str = Field(..., description="共享组的名称")
chat_ids: List[str] = Field(..., description="属于该组的聊天ID列表") chat_ids: List[List[str]] = Field(
..., description='属于该组的聊天ID列表格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]'
)
class CrossContextConfig(ValidatedConfigBase): class CrossContextConfig(ValidatedConfigBase):

View File

@@ -28,12 +28,36 @@ from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
# 导入消息API和traceback模块 # 导入消息API和traceback模块
from src.common.message import get_global_api from src.common.message import get_global_api
# 条件导入记忆系统 from src.chat.memory_system.Hippocampus import hippocampus_manager
if global_config.memory.enable_memory: if not global_config.memory.enable_memory:
from src.chat.memory_system.Hippocampus import hippocampus_manager import src.chat.memory_system.Hippocampus as hippocampus_module
# 插件系统现在使用统一的插件加载器 class MockHippocampusManager:
def initialize(self):
pass
def get_hippocampus(self):
return None
async def build_memory(self):
pass
async def forget_memory(self, percentage: float = 0.005):
pass
async def consolidate_memory(self):
pass
async def get_memory_from_text(self, text: str, max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3, fast_retrieval: bool = False) -> list:
return []
async def get_memory_from_topic(self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3) -> list:
return []
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
return 0.0, []
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
return []
def get_all_node_names(self) -> list:
return []
hippocampus_module.hippocampus_manager = MockHippocampusManager()
# 插件系统现在使用统一的插件加载器
install(extra_lines=3) install(extra_lines=3)
@@ -42,12 +66,8 @@ logger = get_logger("main")
class MainSystem: class MainSystem:
def __init__(self): def __init__(self):
# 根据配置条件性地初始化记忆系统 self.hippocampus_manager = hippocampus_manager
if global_config.memory.enable_memory:
self.hippocampus_manager = hippocampus_manager
else:
self.hippocampus_manager = None
self.individuality: Individuality = get_individuality() self.individuality: Individuality = get_individuality()
# 使用消息API替代直接的FastAPI实例 # 使用消息API替代直接的FastAPI实例
@@ -103,8 +123,6 @@ class MainSystem:
else: else:
loop.run_until_complete(async_memory_manager.shutdown()) loop.run_until_complete(async_memory_manager.shutdown())
logger.info("🛑 记忆管理器已停止") logger.info("🛑 记忆管理器已停止")
except ImportError:
pass # 异步记忆优化器不存在
except Exception as e: except Exception as e:
logger.error(f"停止记忆管理器时出错: {e}") logger.error(f"停止记忆管理器时出错: {e}")
@@ -201,22 +219,18 @@ MoFox_Bot(第三方修改版)
logger.info("聊天管理器初始化成功") logger.info("聊天管理器初始化成功")
# 根据配置条件性地初始化记忆系统 # 初始化记忆系统
if global_config.memory.enable_memory: self.hippocampus_manager.initialize()
if self.hippocampus_manager: logger.info("记忆系统初始化成功")
self.hippocampus_manager.initialize()
logger.info("记忆系统初始化成功") # 初始化异步记忆管理器
try:
# 初始化异步记忆管理器 from src.chat.memory_system.async_memory_optimizer import async_memory_manager
try:
from src.chat.memory_system.async_memory_optimizer import async_memory_manager await async_memory_manager.initialize()
logger.info("记忆管理器初始化成功")
await async_memory_manager.initialize() except Exception as e:
logger.info("记忆管理器初始化成功") logger.error(f"记忆管理器初始化失败: {e}")
except Exception as e:
logger.error(f"记忆管理器初始化失败: {e}")
else:
logger.info("记忆系统已禁用,跳过初始化")
# await asyncio.sleep(0.5) #防止logger输出飞了 # await asyncio.sleep(0.5) #防止logger输出飞了
@@ -265,15 +279,14 @@ MoFox_Bot(第三方修改版)
self.server.run(), self.server.run(),
] ]
# 根据配置条件性地添加记忆系统相关任务 # 添加记忆系统相关任务
if global_config.memory.enable_memory and self.hippocampus_manager: tasks.extend(
tasks.extend( [
[ self.build_memory_task(),
self.build_memory_task(), self.forget_memory_task(),
self.forget_memory_task(), self.consolidate_memory_task(),
self.consolidate_memory_task(), ]
] )
)
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
@@ -305,10 +318,6 @@ MoFox_Bot(第三方修改版)
def sync_build_memory(): def sync_build_memory():
"""在线程池中执行同步记忆构建""" """在线程池中执行同步记忆构建"""
if not self.hippocampus_manager:
logger.error("尝试在禁用记忆系统时构建记忆,操作已取消。")
return
try: try:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)

View File

@@ -0,0 +1,132 @@
"""
跨群聊上下文API
"""
import time
from typing import Dict, Any, Optional, List
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
build_readable_messages_with_id,
)
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
logger = get_logger("cross_context_api")
def get_context_groups(chat_id: str) -> Optional[List[List[str]]]:
"""
获取当前聊天所在的共享组的其他聊天ID
"""
current_stream = get_chat_manager().get_stream(chat_id)
if not current_stream:
return None
is_group = current_stream.group_info is not None
current_chat_raw_id = (
current_stream.group_info.group_id if is_group else current_stream.user_info.user_id
)
current_type = "group" if is_group else "private"
for group in global_config.cross_context.groups:
# 检查当前聊天的ID和类型是否在组的chat_ids中
if [current_type, str(current_chat_raw_id)] in group.chat_ids:
# 返回组内其他聊天的 [type, id] 列表
return [
chat_info
for chat_info in group.chat_ids
if chat_info != [current_type, str(current_chat_raw_id)]
]
return None
async def build_cross_context_normal(
chat_stream: ChatStream, other_chat_infos: List[List[str]]
) -> str:
"""
构建跨群聊/私聊上下文 (Normal模式)
"""
cross_context_messages = []
for chat_type, chat_raw_id in other_chat_infos:
is_group = chat_type == "group"
stream_id = get_chat_manager().get_stream_id(
chat_stream.platform, chat_raw_id, is_group=is_group
)
if not stream_id:
continue
try:
messages = get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=5, # 可配置
)
if messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
formatted_messages, _ = build_readable_messages_with_id(
messages, timestamp_mode="relative"
)
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
except Exception as e:
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
continue
if not cross_context_messages:
return ""
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
async def build_cross_context_s4u(
chat_stream: ChatStream,
other_chat_infos: List[List[str]],
target_user_info: Optional[Dict[str, Any]],
) -> str:
"""
构建跨群聊/私聊上下文 (S4U模式)
"""
cross_context_messages = []
if target_user_info:
user_id = target_user_info.get("user_id")
if user_id:
for chat_type, chat_raw_id in other_chat_infos:
is_group = chat_type == "group"
stream_id = get_chat_manager().get_stream_id(
chat_stream.platform, chat_raw_id, is_group=is_group
)
if not stream_id:
continue
try:
messages = get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=20, # 获取更多消息以供筛选
)
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-5:]
if user_messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
user_name = (
target_user_info.get("person_name")
or target_user_info.get("user_nickname")
or user_id
)
formatted_messages, _ = build_readable_messages_with_id(
user_messages, timestamp_mode="relative"
)
cross_context_messages.append(
f'[以下是"{user_name}""{chat_name}"的近期发言]\n{formatted_messages}'
)
except Exception as e:
logger.error(f"获取用户 {user_id} 在聊天 {chat_raw_id} 的消息失败: {e}")
continue
if not cross_context_messages:
return ""
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"

View File

@@ -476,19 +476,20 @@ pre_sleep_notification_groups = []
# 用于生成睡前消息的提示。AI会根据这个提示生成一句晚安问候。 # 用于生成睡前消息的提示。AI会根据这个提示生成一句晚安问候。
pre_sleep_prompt = "我准备睡觉了,请生成一句简短自然的晚安问候。" pre_sleep_prompt = "我准备睡觉了,请生成一句简短自然的晚安问候。"
[cross_context] # 跨群聊上下文共享配置 [cross_context] # 跨群聊/私聊上下文共享配置
# 这是总开关,用于一键启用或禁用此功能 # 这是总开关,用于一键启用或禁用此功能
enable = false enable = true
# 在这里定义您的“共享组” # 在这里定义您的“共享组”
# 只有在同一个组内的聊才会共享上下文 # 只有在同一个组内的聊才会共享上下文
# 注意:这里的chat_ids需要填写群号 # 格式:chat_ids = [["type", "id"], ["type", "id"], ...]
# type 可选 "group" 或 "private"
[[cross_context.groups]] [[cross_context.groups]]
name = "项目A技术讨论组" name = "项目A技术讨论组"
chat_ids = [ chat_ids = [
"111111", # 假设这是“开发群”的ID ["group", "169850076"], # 假设这是“开发群”的ID
"222222" # 假设这是“产品群”的ID ["group", "1025509724"], # 假设这是“产品群”的ID
["private", "123456789"] # 假设这是某个用户的私聊
] ]
[maizone_intercom] [maizone_intercom]
# QQ空间互通组配置 # QQ空间互通组配置
# 启用后,发布说说时会读取指定互通组的上下文 # 启用后,发布说说时会读取指定互通组的上下文
@@ -498,6 +499,6 @@ enable = false
[[maizone_intercom.groups]] [[maizone_intercom.groups]]
name = "Maizone默认互通组" name = "Maizone默认互通组"
chat_ids = [ chat_ids = [
"111111", # 示例群聊1 ["group", "111111"], # 示例群聊1
"222222" # 示例聊2 ["private", "222222"] # 示例聊2
] ]