diff --git a/bot.py b/bot.py index 472ee5f08..a4a1030c5 100644 --- a/bot.py +++ b/bot.py @@ -35,16 +35,18 @@ script_dir = os.path.dirname(os.path.abspath(__file__)) os.chdir(script_dir) logger.info(f"已设置工作目录为: {script_dir}") + # 检查并创建.env文件 def ensure_env_file(): """确保.env文件存在,如果不存在则从模板创建""" env_file = Path(".env") template_env = Path("template/template.env") - + if not env_file.exists(): if template_env.exists(): logger.info("未找到.env文件,正在从模板创建...") import shutil + shutil.copy(template_env, env_file) logger.info("已从template/template.env创建.env文件") logger.warning("请编辑.env文件,将EULA_CONFIRMED设置为true并配置其他必要参数") @@ -52,6 +54,7 @@ def ensure_env_file(): logger.error("未找到.env文件和template.env模板文件") sys.exit(1) + # 确保环境文件存在 ensure_env_file() @@ -131,32 +134,32 @@ async def graceful_shutdown(): def check_eula(): """检查EULA和隐私条款确认状态 - 环境变量版(类似Minecraft)""" # 检查环境变量中的EULA确认 - eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower() - - if eula_confirmed == 'true': + eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower() + + if eula_confirmed == "true": logger.info("EULA已通过环境变量确认") return - + # 如果没有确认,提示用户 confirm_logger.critical("您需要同意EULA和隐私条款才能使用MoFox_Bot") confirm_logger.critical("请阅读以下文件:") confirm_logger.critical(" - EULA.md (用户许可协议)") confirm_logger.critical(" - PRIVACY.md (隐私条款)") confirm_logger.critical("然后编辑 .env 文件,将 'EULA_CONFIRMED=false' 改为 'EULA_CONFIRMED=true'") - + # 等待用户确认 while True: try: load_dotenv(override=True) # 重新加载.env文件 - - eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower() - if eula_confirmed == 'true': + + eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower() + if eula_confirmed == "true": confirm_logger.info("EULA确认成功,感谢您的同意") return - + confirm_logger.critical("请修改 .env 文件中的 EULA_CONFIRMED=true 后重新启动程序") input("按Enter键检查.env文件状态...") - + except KeyboardInterrupt: confirm_logger.info("用户取消,程序退出") sys.exit(0) diff --git a/plugins/set_emoji_like/_manifest.json b/plugins/set_emoji_like/_manifest.json index 906fe81c7..2e322b64f 100644 --- a/plugins/set_emoji_like/_manifest.json +++ b/plugins/set_emoji_like/_manifest.json @@ -25,7 +25,7 @@ { "type": "action", "name": "set_emoji_like", - "description": "为消息设置表情回应" + "description": "为某条已经存在的消息添加‘贴表情’回应(类似点赞),而不是发送新消息。当用户明确要求‘贴表情’时使用。" } ], "features": [ diff --git a/plugins/set_emoji_like/plugin.py b/plugins/set_emoji_like/plugin.py index 810f0639e..5bc1a3ae8 100644 --- a/plugins/set_emoji_like/plugin.py +++ b/plugins/set_emoji_like/plugin.py @@ -45,7 +45,7 @@ class SetEmojiLikeAction(BaseAction): # === 基本信息(必须填写)=== action_name = "set_emoji_like" - action_description = "为一个已存在的消息添加点赞或表情回应(也叫‘贴表情’)" + action_description = "为某条已经存在的消息添加‘贴表情’回应(类似点赞),而不是发送新消息。可以在觉得某条消息非常有趣、值得赞同或者需要特殊情感回应时主动使用。" activation_type = ActionActivationType.ALWAYS # 消息接收时激活(?) chat_type_allow = ChatType.GROUP parallel_action = True diff --git a/scripts/update_prompt_imports.py b/scripts/update_prompt_imports.py index 289d7f327..227491ec2 100644 --- a/scripts/update_prompt_imports.py +++ b/scripts/update_prompt_imports.py @@ -20,25 +20,26 @@ files_to_update = [ "src/mais4u/mais4u_chat/s4u_mood_manager.py", "src/plugin_system/core/tool_use.py", "src/chat/memory_system/memory_activator.py", - "src/chat/utils/smart_prompt.py" + "src/chat/utils/smart_prompt.py", ] + def update_prompt_imports(file_path): """更新文件中的Prompt导入""" if not os.path.exists(file_path): print(f"文件不存在: {file_path}") return False - - with open(file_path, 'r', encoding='utf-8') as f: + + with open(file_path, "r", encoding="utf-8") as f: content = f.read() - + # 替换导入语句 old_import = "from src.chat.utils.prompt_builder import Prompt, global_prompt_manager" new_import = "from src.chat.utils.prompt import Prompt, global_prompt_manager" - + if old_import in content: new_content = content.replace(old_import, new_import) - with open(file_path, 'w', encoding='utf-8') as f: + with open(file_path, "w", encoding="utf-8") as f: f.write(new_content) print(f"已更新: {file_path}") return True @@ -46,14 +47,16 @@ def update_prompt_imports(file_path): print(f"无需更新: {file_path}") return False + def main(): """主函数""" updated_count = 0 for file_path in files_to_update: if update_prompt_imports(file_path): updated_count += 1 - + print(f"\n更新完成!共更新了 {updated_count} 个文件") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/chat/chat_loop/cycle_processor.py b/src/chat/chat_loop/cycle_processor.py deleted file mode 100644 index 79d4eca9d..000000000 --- a/src/chat/chat_loop/cycle_processor.py +++ /dev/null @@ -1,460 +0,0 @@ -import asyncio -import time -import traceback -import math -import random -from typing import Dict, Any, Tuple - -from src.chat.utils.timer_calculator import Timer -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.planner_actions.planner import ActionPlanner -from src.chat.planner_actions.action_modifier import ActionModifier -from src.person_info.person_info import get_person_info_manager -from src.plugin_system.apis import database_api, generator_api -from src.plugin_system.base.component_types import ChatMode -from src.mais4u.constant_s4u import ENABLE_S4U -from src.chat.chat_loop.hfc_utils import send_typing, stop_typing -from .hfc_context import HfcContext -from .response_handler import ResponseHandler -from .cycle_tracker import CycleTracker - -# 日志记录器 -logger = get_logger("hfc.processor") - - -class CycleProcessor: - """ - 循环处理器类,负责处理单次思考循环的逻辑。 - """ - def __init__(self, context: HfcContext, response_handler: ResponseHandler, cycle_tracker: CycleTracker): - """ - 初始化循环处理器 - - Args: - context: HFC聊天上下文对象,包含聊天流、能量值等信息 - response_handler: 响应处理器,负责生成和发送回复 - cycle_tracker: 循环跟踪器,负责记录和管理每次思考循环的信息 - """ - self.context = context - self.response_handler = response_handler - self.cycle_tracker = cycle_tracker - self.action_planner = ActionPlanner(chat_id=self.context.stream_id, action_manager=self.context.action_manager) - self.action_modifier = ActionModifier( - action_manager=self.context.action_manager, chat_id=self.context.stream_id - ) - - self.log_prefix = self.context.log_prefix - - async def _send_and_store_reply( - self, - response_set, - loop_start_time, - action_message, - cycle_timers: Dict[str, float], - thinking_id, - actions, - ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: - """ - 发送并存储回复信息 - - Args: - response_set: 回复内容集合 - loop_start_time: 循环开始时间 - action_message: 动作消息 - cycle_timers: 循环计时器 - thinking_id: 思考ID - actions: 动作列表 - - Returns: - Tuple[Dict[str, Any], str, Dict[str, float]]: 循环信息, 回复文本, 循环计时器 - """ - # 发送回复 - with Timer("回复发送", cycle_timers): - reply_text = await self.response_handler.send_response(response_set, loop_start_time, action_message) - - # 存储reply action信息 - person_info_manager = get_person_info_manager() - - # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 - platform = action_message.get("chat_info_platform") - if platform is None: - platform = getattr(self.context.chat_stream, "platform", "unknown") - - # 获取用户信息并生成回复提示 - person_id = person_info_manager.get_person_id( - platform, - action_message.get("chat_info_user_id", ""), - ) - person_info = await person_info_manager.get_values(person_id, ["person_name"]) - person_name = person_info.get("person_name") - action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" - - # 存储动作信息到数据库 - await database_api.store_action_info( - chat_stream=self.context.chat_stream, - action_build_into_prompt=False, - action_prompt_display=action_prompt_display, - action_done=True, - thinking_id=thinking_id, - action_data={"reply_text": reply_text}, - action_name="reply", - ) - - # 构建循环信息 - loop_info: Dict[str, Any] = { - "loop_plan_info": { - "action_result": actions, - }, - "loop_action_info": { - "action_taken": True, - "reply_text": reply_text, - "command": "", - "taken_time": time.time(), - }, - } - - return loop_info, reply_text, cycle_timers - - async def observe(self, interest_value: float = 0.0) -> str: - """ - 观察和处理单次思考循环的核心方法 - - Args: - interest_value: 兴趣值 - - Returns: - str: 动作类型 - - 功能说明: - - 开始新的思考循环并记录计时 - - 修改可用动作并获取动作列表 - - 根据聊天模式和提及情况决定是否跳过规划器 - - 执行动作规划或直接回复 - - 根据动作类型分发到相应的处理方法 - """ - action_type = "no_action" - reply_text = "" # 初始化reply_text变量,避免UnboundLocalError - - # 使用sigmoid函数将interest_value转换为概率 - # 当interest_value为0时,概率接近0(使用Focus模式) - # 当interest_value很高时,概率接近1(使用Normal模式) - def calculate_normal_mode_probability(interest_val: float) -> float: - """ - 计算普通模式的概率 - - Args: - interest_val: 兴趣值 - - Returns: - float: 概率 - """ - # 使用sigmoid函数,调整参数使概率分布更合理 - # 当interest_value = 0时,概率约为0.1 - # 当interest_value = 1时,概率约为0.5 - # 当interest_value = 2时,概率约为0.8 - # 当interest_value = 3时,概率约为0.95 - k = 2.0 # 控制曲线陡峭程度 - x0 = 1.0 # 控制曲线中心点 - return 1.0 / (1.0 + math.exp(-k * (interest_val - x0))) - - # 计算普通模式概率 - normal_mode_probability = ( - calculate_normal_mode_probability(interest_value) - * 0.5 - / global_config.chat.get_current_talk_frequency(self.context.stream_id) - ) - - # 根据概率决定使用哪种模式 - if random.random() < normal_mode_probability: - mode = ChatMode.NORMAL - logger.info( - f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Normal planner模式" - ) - else: - mode = ChatMode.FOCUS - logger.info( - f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Focus planner模式" - ) - - # 开始新的思考循环 - cycle_timers, thinking_id = self.cycle_tracker.start_cycle() - logger.info(f"{self.log_prefix} 开始第{self.context.cycle_counter}次思考") - - if ENABLE_S4U and self.context.chat_stream and self.context.chat_stream.user_info: - await send_typing(self.context.chat_stream.user_info.user_id) - - loop_start_time = time.time() - - # 第一步:动作修改 - with Timer("动作修改", cycle_timers): - try: - await self.action_modifier.modify_actions() - available_actions = self.context.action_manager.get_using_actions() - except Exception as e: - logger.error(f"{self.log_prefix} 动作修改失败: {e}") - available_actions = {} - - # 规划动作 - from src.plugin_system.core.event_manager import event_manager - from src.plugin_system import EventType - - result = await event_manager.trigger_event( - EventType.ON_PLAN, permission_group="SYSTEM", stream_id=self.context.chat_stream - ) - if result and not result.all_continue_process(): - raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于规划前中断了内容生成") - with Timer("规划器", cycle_timers): - actions, _ = await self.action_planner.plan(mode=mode) - - async def execute_action(action_info): - """执行单个动作的通用函数""" - try: - if action_info["action_type"] == "no_action": - return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""} - if action_info["action_type"] == "no_reply": - # 直接处理no_reply逻辑,不再通过动作系统 - reason = action_info.get("reasoning", "选择不回复") - logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}") - - # 存储no_reply信息到数据库 - await database_api.store_action_info( - chat_stream=self.context.chat_stream, - action_build_into_prompt=False, - action_prompt_display=reason, - action_done=True, - thinking_id=thinking_id, - action_data={"reason": reason}, - action_name="no_reply", - ) - - return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""} - elif action_info["action_type"] != "reply" and action_info["action_type"] != "no_action": - # 记录并执行普通动作 - reason = action_info.get("reasoning", f"执行动作 {action_info['action_type']}") - logger.info(f"{self.log_prefix} 决定执行动作 '{action_info['action_type']}',内心思考: {reason}") - with Timer("动作执行", cycle_timers): - success, reply_text, command = await self._handle_action( - action_info["action_type"], - reason, # 使用已获取的reason - action_info["action_data"], - cycle_timers, - thinking_id, - action_info["action_message"], - ) - return { - "action_type": action_info["action_type"], - "success": success, - "reply_text": reply_text, - "command": command, - } - else: - # 生成回复 - try: - reason = action_info.get("reasoning", "决定进行回复") - logger.info(f"{self.log_prefix} 决定进行回复,内心思考: {reason}") - success, response_set, _ = await generator_api.generate_reply( - chat_stream=self.context.chat_stream, - reply_message=action_info["action_message"], - available_actions=available_actions, - enable_tool=global_config.tool.enable_tool, - request_type="chat.replyer", - from_plugin=False, - read_mark=action_info.get("action_message", {}).get("time", 0.0), - ) - if not success or not response_set: - logger.info( - f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败" - ) - return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} - except asyncio.CancelledError: - logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消") - return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} - - # 发送并存储回复 - loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply( - response_set, - loop_start_time, - action_info["action_message"], - cycle_timers, - thinking_id, - actions, - ) - return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info} - except Exception as e: - logger.error(f"{self.log_prefix} 执行动作时出错: {e}") - logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}") - return { - "action_type": action_info["action_type"], - "success": False, - "reply_text": "", - "loop_info": None, - "error": str(e), - } - - # 分离 reply 动作和其他动作 - reply_actions = [a for a in actions if a.get("action_type") == "reply"] - other_actions = [a for a in actions if a.get("action_type") != "reply"] - - reply_loop_info = None - reply_text_from_reply = "" - other_actions_results = [] - - # 1. 首先串行执行所有 reply 动作(通常只有一个) - if reply_actions: - logger.info(f"{self.log_prefix} 正在执行文本回复...") - for action in reply_actions: - action_message = action.get("action_message") - if not action_message: - logger.warning(f"{self.log_prefix} reply 动作缺少 action_message,跳过") - continue - - # 检查是否是空的DatabaseMessages对象 - if hasattr(action_message, 'chat_info') and hasattr(action_message.chat_info, 'user_info'): - target_user_id = action_message.chat_info.user_info.user_id - else: - # 如果是字典格式,使用原来的方式 - target_user_id = action_message.get("chat_info_user_id", "") - - if not target_user_id: - logger.warning(f"{self.log_prefix} reply 动作的 action_message 缺少用户ID,跳过") - continue - - if target_user_id == global_config.bot.qq_account and not global_config.chat.allow_reply_self: - logger.warning("选取的reply的目标为bot自己,跳过reply action") - continue - result = await execute_action(action) - if isinstance(result, Exception): - logger.error(f"{self.log_prefix} 回复动作执行异常: {result}") - continue - if result.get("success"): - reply_loop_info = result.get("loop_info") - reply_text_from_reply = result.get("reply_text", "") - else: - logger.warning(f"{self.log_prefix} 回复动作执行失败") - - # 2. 然后并行执行所有其他动作 - if other_actions: - logger.info(f"{self.log_prefix} 正在执行附加动作: {[a.get('action_type') for a in other_actions]}") - other_action_tasks = [asyncio.create_task(execute_action(action)) for action in other_actions] - results = await asyncio.gather(*other_action_tasks, return_exceptions=True) - for i, result in enumerate(results): - if isinstance(result, BaseException): - logger.error(f"{self.log_prefix} 附加动作执行异常: {result}") - continue - other_actions_results.append(result) - - # 构建最终的循环信息 - if reply_loop_info: - loop_info = reply_loop_info - # 将其他动作的结果合并到loop_info中 - if "other_actions" not in loop_info["loop_action_info"]: - loop_info["loop_action_info"]["other_actions"] = [] - loop_info["loop_action_info"]["other_actions"].extend(other_actions_results) - reply_text = reply_text_from_reply - else: - # 没有回复信息,构建纯动作的loop_info - # 即使没有回复,也要正确处理其他动作 - final_action_taken = any(res.get("success", False) for res in other_actions_results) - final_reply_text = " ".join(res.get("reply_text", "") for res in other_actions_results if res.get("reply_text")) - final_command = " ".join(res.get("command", "") for res in other_actions_results if res.get("command")) - - loop_info = { - "loop_plan_info": { - "action_result": actions, - }, - "loop_action_info": { - "action_taken": final_action_taken, - "reply_text": final_reply_text, - "command": final_command, - "taken_time": time.time(), - "other_actions": other_actions_results, - }, - } - reply_text = final_reply_text - - # 停止正在输入状态 - if ENABLE_S4U: - await stop_typing() - - # 结束循环 - self.context.chat_instance.cycle_tracker.end_cycle(loop_info, cycle_timers) - self.context.chat_instance.cycle_tracker.print_cycle_info(cycle_timers) - - action_type = actions[0]["action_type"] if actions else "no_action" - return action_type - - async def _handle_action( - self, action, reasoning, action_data, cycle_timers, thinking_id, action_message - ) -> tuple[bool, str, str]: - """ - 处理具体的动作执行 - - Args: - action: 动作名称 - reasoning: 执行理由 - action_data: 动作数据 - cycle_timers: 循环计时器 - thinking_id: 思考ID - action_message: 动作消息 - - Returns: - tuple: (执行是否成功, 回复文本, 命令文本) - - 功能说明: - - 创建对应的动作处理器 - - 执行动作并捕获异常 - - 返回执行结果供上级方法整合 - """ - if not self.context.chat_stream: - return False, "", "" - try: - # 创建动作处理器 - action_handler = self.context.action_manager.create_action( - action_name=action, - action_data=action_data, - reasoning=reasoning, - cycle_timers=cycle_timers, - thinking_id=thinking_id, - chat_stream=self.context.chat_stream, - log_prefix=self.context.log_prefix, - action_message=action_message, - ) - if not action_handler: - # 动作处理器创建失败,尝试回退机制 - logger.warning(f"{self.context.log_prefix} 创建动作处理器失败: {action},尝试回退方案") - - # 获取当前可用的动作 - available_actions = self.context.action_manager.get_using_actions() - fallback_action = None - - # 回退优先级:reply > 第一个可用动作 - if "reply" in available_actions: - fallback_action = "reply" - elif available_actions: - fallback_action = list(available_actions.keys())[0] - - if fallback_action and fallback_action != action: - logger.info(f"{self.context.log_prefix} 使用回退动作: {fallback_action}") - action_handler = self.context.action_manager.create_action( - action_name=fallback_action, - action_data=action_data, - reasoning=f"原动作'{action}'不可用,自动回退。{reasoning}", - cycle_timers=cycle_timers, - thinking_id=thinking_id, - chat_stream=self.context.chat_stream, - log_prefix=self.context.log_prefix, - action_message=action_message, - ) - - if not action_handler: - logger.error(f"{self.context.log_prefix} 回退方案也失败,无法创建任何动作处理器") - return False, "", "" - - # 执行动作 - success, reply_text = await action_handler.handle_action() - return success, reply_text, "" - except Exception as e: - logger.error(f"{self.context.log_prefix} 处理{action}时出错: {e}") - traceback.print_exc() - return False, "", "" diff --git a/src/chat/chat_loop/cycle_tracker.py b/src/chat/chat_loop/cycle_tracker.py deleted file mode 100644 index 1f45c4caf..000000000 --- a/src/chat/chat_loop/cycle_tracker.py +++ /dev/null @@ -1,114 +0,0 @@ -import time -from typing import Dict, Any, Tuple - -from src.common.logger import get_logger -from src.chat.chat_loop.hfc_utils import CycleDetail -from .hfc_context import HfcContext - -logger = get_logger("hfc") - - -class CycleTracker: - def __init__(self, context: HfcContext): - """ - 初始化循环跟踪器 - - Args: - context: HFC聊天上下文对象 - - 功能说明: - - 负责跟踪和记录每次思考循环的详细信息 - - 管理循环的开始、结束和信息存储 - """ - self.context = context - - def start_cycle(self, is_proactive: bool = False) -> Tuple[Dict[str, float], str]: - """ - 开始新的思考循环 - - Args: - is_proactive: 标记这个循环是否由主动思考发起 - - Returns: - tuple: (循环计时器字典, 思考ID字符串) - - 功能说明: - - 增加循环计数器 - - 创建新的循环详情对象 - - 生成唯一的思考ID - - 初始化循环计时器 - """ - if not is_proactive: - self.context.cycle_counter += 1 - - cycle_id = self.context.cycle_counter if not is_proactive else f"{self.context.cycle_counter}.p" - self.context.current_cycle_detail = CycleDetail(cycle_id) - self.context.current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}" - cycle_timers = {} - return cycle_timers, self.context.current_cycle_detail.thinking_id - - def end_cycle(self, loop_info: Dict[str, Any], cycle_timers: Dict[str, float]): - """ - 结束当前思考循环 - - Args: - loop_info: 循环信息,包含规划和动作信息 - cycle_timers: 循环计时器,记录各阶段耗时 - - 功能说明: - - 设置循环详情的完整信息 - - 将当前循环加入历史记录 - - 记录计时器和结束时间 - - 打印循环统计信息 - """ - if self.context.current_cycle_detail: - self.context.current_cycle_detail.set_loop_info(loop_info) - self.context.history_loop.append(self.context.current_cycle_detail) - self.context.current_cycle_detail.timers = cycle_timers - self.context.current_cycle_detail.end_time = time.time() - self.print_cycle_info(cycle_timers) - - def print_cycle_info(self, cycle_timers: Dict[str, float]): - """ - 打印循环统计信息 - - Args: - cycle_timers: 循环计时器字典 - - 功能说明: - - 格式化各阶段的耗时信息 - - 计算总体循环持续时间 - - 输出详细的性能统计日志 - - 显示选择的动作类型 - """ - if not self.context.current_cycle_detail: - return - - timer_strings = [] - for name, elapsed in cycle_timers.items(): - formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒" - timer_strings.append(f"{name}: {formatted_time}") - - # 获取动作类型,兼容新旧格式 - # 获取动作类型 - action_type = "未知动作" - if self.context.current_cycle_detail: - loop_plan_info = self.context.current_cycle_detail.loop_plan_info - actions = loop_plan_info.get("action_result") - - if isinstance(actions, list) and actions: - # 从actions列表中提取所有action_type - action_types = [a.get("action_type", "未知") for a in actions] - action_type = ", ".join(action_types) - elif isinstance(actions, dict): - # 兼容旧格式 - action_type = actions.get("action_type", "未知动作") - - - if self.context.current_cycle_detail.end_time and self.context.current_cycle_detail.start_time: - duration = self.context.current_cycle_detail.end_time - self.context.current_cycle_detail.start_time - logger.info( - f"{self.context.log_prefix} 第{self.context.current_cycle_detail.cycle_id}次思考," - f"耗时: {duration:.1f}秒, " - f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "") - ) diff --git a/src/chat/chat_loop/energy_manager.py b/src/chat/chat_loop/energy_manager.py deleted file mode 100644 index 2eb7e7265..000000000 --- a/src/chat/chat_loop/energy_manager.py +++ /dev/null @@ -1,162 +0,0 @@ -import asyncio -import time -from typing import Optional -from src.common.logger import get_logger -from src.config.config import global_config -from .hfc_context import HfcContext -from src.chat.chat_loop.sleep_manager import sleep_manager -logger = get_logger("hfc") - - -class EnergyManager: - def __init__(self, context: HfcContext): - """ - 初始化能量管理器 - - Args: - context: HFC聊天上下文对象 - - 功能说明: - - 管理聊天机器人的能量值系统 - - 根据聊天模式自动调整能量消耗 - - 控制能量值的衰减和记录 - """ - self.context = context - self._energy_task: Optional[asyncio.Task] = None - self.last_energy_log_time = 0 - self.energy_log_interval = 90 - - async def start(self): - """ - 启动能量管理器 - - 功能说明: - - 检查运行状态,避免重复启动 - - 创建能量循环异步任务 - - 设置任务完成回调 - - 记录启动日志 - """ - if self.context.running and not self._energy_task: - self._energy_task = asyncio.create_task(self._energy_loop()) - self._energy_task.add_done_callback(self._handle_energy_completion) - logger.info(f"{self.context.log_prefix} 能量管理器已启动") - - async def stop(self): - """ - 停止能量管理器 - - 功能说明: - - 取消正在运行的能量循环任务 - - 等待任务完全停止 - - 记录停止日志 - """ - if self._energy_task and not self._energy_task.done(): - self._energy_task.cancel() - await asyncio.sleep(0) - logger.info(f"{self.context.log_prefix} 能量管理器已停止") - - async def _energy_loop(self): - """ - 能量与睡眠压力管理的主循环 - - 功能说明: - - 每10秒执行一次能量更新 - - 根据群聊配置设置固定的聊天模式和能量值 - - 在自动模式下根据聊天模式进行能量衰减 - - NORMAL模式每次衰减0.3,FOCUS模式每次衰减0.6 - - 确保能量值不低于0.3的最小值 - """ - while self.context.running: - await asyncio.sleep(10) - - if not self.context.chat_stream: - continue - - # 判断当前是否为睡眠时间 - is_sleeping = sleep_manager.SleepManager().is_sleeping() - - if is_sleeping: - # 睡眠中:减少睡眠压力 - decay_per_10s = global_config.sleep_system.sleep_pressure_decay_rate / 6 - self.context.sleep_pressure -= decay_per_10s - self.context.sleep_pressure = max(self.context.sleep_pressure, 0) - self._log_sleep_pressure_change("睡眠压力释放") - self.context.save_context_state() - else: - # 清醒时:处理能量衰减 - is_group_chat = self.context.chat_stream.group_info is not None - if is_group_chat: - self.context.energy_value = 25 - - await asyncio.sleep(12) - self.context.energy_value -= 0.5 - self.context.energy_value = max(self.context.energy_value, 0.3) - - self._log_energy_change("能量值衰减") - self.context.save_context_state() - - def _should_log_energy(self) -> bool: - """ - 判断是否应该记录能量变化日志 - - Returns: - bool: 如果距离上次记录超过间隔时间则返回True - - 功能说明: - - 控制能量日志的记录频率,避免日志过于频繁 - - 默认间隔90秒记录一次详细日志 - - 其他时间使用调试级别日志 - """ - current_time = time.time() - if current_time - self.last_energy_log_time >= self.energy_log_interval: - self.last_energy_log_time = current_time - return True - return False - - def increase_sleep_pressure(self): - """ - 在执行动作后增加睡眠压力 - """ - increment = global_config.sleep_system.sleep_pressure_increment - self.context.sleep_pressure += increment - self.context.sleep_pressure = min(self.context.sleep_pressure, 100.0) # 设置一个100的上限 - self._log_sleep_pressure_change("执行动作,睡眠压力累积") - self.context.save_context_state() - - def _log_energy_change(self, action: str, reason: str = ""): - """ - 记录能量变化日志 - - Args: - action: 能量变化的动作描述 - reason: 可选的变化原因 - - 功能说明: - - 根据时间间隔决定使用info还是debug级别的日志 - - 格式化能量值显示(保留一位小数) - - 可选择性地包含变化原因 - """ - if self._should_log_energy(): - log_message = f"{self.context.log_prefix} {action},当前能量值:{self.context.energy_value:.1f}" - if reason: - log_message = ( - f"{self.context.log_prefix} {action},{reason},当前能量值:{self.context.energy_value:.1f}" - ) - logger.info(log_message) - else: - log_message = f"{self.context.log_prefix} {action},当前能量值:{self.context.energy_value:.1f}" - if reason: - log_message = ( - f"{self.context.log_prefix} {action},{reason},当前能量值:{self.context.energy_value:.1f}" - ) - logger.debug(log_message) - - def _log_sleep_pressure_change(self, action: str): - """ - 记录睡眠压力变化日志 - """ - # 使用与能量日志相同的频率控制 - if self._should_log_energy(): - logger.info(f"{self.context.log_prefix} {action},当前睡眠压力:{self.context.sleep_pressure:.1f}") - else: - logger.debug(f"{self.context.log_prefix} {action},当前睡眠压力:{self.context.sleep_pressure:.1f}") diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py deleted file mode 100644 index adc868117..000000000 --- a/src/chat/chat_loop/heartFC_chat.py +++ /dev/null @@ -1,574 +0,0 @@ -import asyncio -import time -import traceback -import random -from typing import Optional, List, Dict, Any -from collections import deque - -from src.common.logger import get_logger -from src.config.config import global_config -from src.person_info.relationship_builder_manager import relationship_builder_manager -from src.chat.express.expression_learner import expression_learner_manager -from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager, SleepState - -from .hfc_context import HfcContext -from .energy_manager import EnergyManager -from .proactive.proactive_thinker import ProactiveThinker -from .cycle_processor import CycleProcessor -from .response_handler import ResponseHandler -from .cycle_tracker import CycleTracker -from .sleep_manager.wakeup_manager import WakeUpManager -from .proactive.events import ProactiveTriggerEvent - -logger = get_logger("hfc") - - -class HeartFChatting: - def __init__(self, chat_id: str): - """ - 初始化心跳聊天管理器 - - Args: - chat_id: 聊天ID标识符 - - 功能说明: - - 创建聊天上下文和所有子管理器 - - 初始化循环跟踪器、响应处理器、循环处理器等核心组件 - - 设置能量管理器、主动思考器和普通模式处理器 - - 初始化聊天模式并记录初始化完成日志 - """ - self.context = HfcContext(chat_id) - self.context.new_message_queue = asyncio.Queue() - self._processing_lock = asyncio.Lock() - - self.cycle_tracker = CycleTracker(self.context) - self.response_handler = ResponseHandler(self.context) - self.cycle_processor = CycleProcessor(self.context, self.response_handler, self.cycle_tracker) - self.energy_manager = EnergyManager(self.context) - self.proactive_thinker = ProactiveThinker(self.context, self.cycle_processor) - self.wakeup_manager = WakeUpManager(self.context) - self.sleep_manager = SleepManager() - - # 将唤醒度管理器设置到上下文中 - self.context.wakeup_manager = self.wakeup_manager - self.context.energy_manager = self.energy_manager - self.context.sleep_manager = self.sleep_manager - # 将HeartFChatting实例设置到上下文中,以便其他组件可以调用其方法 - self.context.chat_instance = self - - self._loop_task: Optional[asyncio.Task] = None - self._proactive_monitor_task: Optional[asyncio.Task] = None - - # 记录最近3次的兴趣度 - self.recent_interest_records: deque = deque(maxlen=3) - self._initialize_chat_mode() - logger.info(f"{self.context.log_prefix} HeartFChatting 初始化完成") - - def _initialize_chat_mode(self): - """ - 初始化聊天模式 - - 功能说明: - - 检测是否为群聊环境 - - 根据全局配置设置强制聊天模式 - - 在focus模式下设置能量值为35 - - 在normal模式下设置能量值为15 - - 如果是auto模式则保持默认设置 - """ - is_group_chat = self.context.chat_stream.group_info is not None if self.context.chat_stream else False - if is_group_chat and global_config.chat.group_chat_mode != "auto": - self.context.energy_value = 25 - - async def start(self): - """ - 启动心跳聊天系统 - - 功能说明: - - 检查是否已经在运行,避免重复启动 - - 初始化关系构建器和表达学习器 - - 启动能量管理器和主动思考器 - - 创建主聊天循环任务并设置完成回调 - - 记录启动完成日志 - """ - if self.context.running: - return - self.context.running = True - - self.context.relationship_builder = relationship_builder_manager.get_or_create_builder(self.context.stream_id) - self.context.expression_learner = await expression_learner_manager.get_expression_learner(self.context.stream_id) - - # 启动主动思考监视器 - if global_config.chat.enable_proactive_thinking: - self._proactive_monitor_task = asyncio.create_task(self._proactive_monitor_loop()) - self._proactive_monitor_task.add_done_callback(self._handle_proactive_monitor_completion) - logger.info(f"{self.context.log_prefix} 主动思考监视器已启动") - - await self.wakeup_manager.start() - - self._loop_task = asyncio.create_task(self._main_chat_loop()) - self._loop_task.add_done_callback(self._handle_loop_completion) - logger.info(f"{self.context.log_prefix} HeartFChatting 启动完成") - - async def add_message(self, message: Dict[str, Any]): - """从外部接收新消息并放入队列""" - await self.context.new_message_queue.put(message) - - async def stop(self): - """ - 停止心跳聊天系统 - - 功能说明: - - 检查是否正在运行,避免重复停止 - - 设置运行状态为False - - 停止能量管理器和主动思考器 - - 取消主聊天循环任务 - - 记录停止完成日志 - """ - if not self.context.running: - return - self.context.running = False - - # 停止主动思考监视器 - if self._proactive_monitor_task and not self._proactive_monitor_task.done(): - self._proactive_monitor_task.cancel() - await asyncio.sleep(0) - logger.info(f"{self.context.log_prefix} 主动思考监视器已停止") - - await self.wakeup_manager.stop() - - if self._loop_task and not self._loop_task.done(): - self._loop_task.cancel() - await asyncio.sleep(0) - logger.info(f"{self.context.log_prefix} HeartFChatting 已停止") - - def _handle_loop_completion(self, task: asyncio.Task): - """ - 处理主循环任务完成 - - Args: - task: 完成的异步任务对象 - - 功能说明: - - 处理任务异常完成的情况 - - 区分正常停止和异常终止 - - 记录相应的日志信息 - - 处理取消任务的情况 - """ - try: - if exception := task.exception(): - logger.error(f"{self.context.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}") - logger.error(traceback.format_exc()) - else: - logger.info(f"{self.context.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)") - except asyncio.CancelledError: - logger.info(f"{self.context.log_prefix} HeartFChatting: 结束了聊天") - - def _handle_proactive_monitor_completion(self, task: asyncio.Task): - """ - 处理主动思考监视器任务完成 - - Args: - task: 完成的异步任务对象 - - 功能说明: - - 处理任务异常完成的情况 - - 记录任务正常结束或被取消的日志 - """ - try: - if exception := task.exception(): - logger.error(f"{self.context.log_prefix} 主动思考监视器异常: {exception}") - else: - logger.info(f"{self.context.log_prefix} 主动思考监视器正常结束") - except asyncio.CancelledError: - logger.info(f"{self.context.log_prefix} 主动思考监视器被取消") - - async def _proactive_monitor_loop(self): - """ - 主动思考监视器循环 - - 功能说明: - - 定期检查是否需要进行主动思考 - - 计算聊天沉默时间,并与动态思考间隔比较 - - 当沉默时间超过阈值时,触发主动思考 - - 处理思考过程中的异常 - """ - while self.context.running: - await asyncio.sleep(15) - - if not self._should_enable_proactive_thinking(): - continue - - current_time = time.time() - silence_duration = current_time - self.context.last_message_time - target_interval = self._get_dynamic_thinking_interval() - - if silence_duration >= target_interval: - try: - formatted_time = self._format_duration(silence_duration) - event = ProactiveTriggerEvent( - source="silence_monitor", - reason=f"聊天已沉默 {formatted_time}", - metadata={"silence_duration": silence_duration}, - ) - await self.proactive_thinker.think(event) - self.context.last_message_time = current_time - except Exception as e: - logger.error(f"{self.context.log_prefix} 主动思考触发执行出错: {e}") - logger.error(traceback.format_exc()) - - def _should_enable_proactive_thinking(self) -> bool: - """ - 判断是否应启用主动思考 - - Returns: - bool: 如果应启用主动思考则返回True,否则返回False - - 功能说明: - - 检查全局配置和特定聊天设置 - - 支持按群聊和私聊分别配置 - - 支持白名单模式,只在特定聊天中启用 - """ - if not self.context.chat_stream: - return False - - is_group_chat = self.context.chat_stream.group_info is not None - - if is_group_chat and not global_config.chat.proactive_thinking_in_group: - return False - if not is_group_chat and not global_config.chat.proactive_thinking_in_private: - return False - - stream_parts = self.context.stream_id.split(":") - current_chat_identifier = f"{stream_parts}:{stream_parts}" if len(stream_parts) >= 2 else self.context.stream_id - - enable_list = getattr( - global_config.chat, - "proactive_thinking_enable_in_groups" if is_group_chat else "proactive_thinking_enable_in_private", - [], - ) - return not enable_list or current_chat_identifier in enable_list - - def _get_dynamic_thinking_interval(self) -> float: - """ - 获取动态思考间隔时间 - - Returns: - float: 思考间隔秒数 - - 功能说明: - - 尝试从timing_utils导入正态分布间隔函数 - - 根据配置计算动态间隔,增加随机性 - - 在无法导入或计算出错时,回退到固定的间隔 - """ - try: - from src.utils.timing_utils import get_normal_distributed_interval - - base_interval = global_config.chat.proactive_thinking_interval - delta_sigma = getattr(global_config.chat, "delta_sigma", 120) - - if base_interval <= 0: - base_interval = abs(base_interval) - if delta_sigma < 0: - delta_sigma = abs(delta_sigma) - - if base_interval == 0 and delta_sigma == 0: - return 300 - if delta_sigma == 0: - return base_interval - - sigma_percentage = delta_sigma / base_interval if base_interval > 0 else delta_sigma / 1000 - return get_normal_distributed_interval(base_interval, sigma_percentage, 1, 86400, use_3sigma_rule=True) - - except ImportError: - logger.warning(f"{self.context.log_prefix} timing_utils不可用,使用固定间隔") - return max(300, abs(global_config.chat.proactive_thinking_interval)) - except Exception as e: - logger.error(f"{self.context.log_prefix} 动态间隔计算出错: {e},使用固定间隔") - return max(300, abs(global_config.chat.proactive_thinking_interval)) - - @staticmethod - def _format_duration(seconds: float) -> str: - """ - 格式化时长为可读字符串 - - Args: - seconds: 时长秒数 - - Returns: - str: 格式化后的字符串 (例如 "1小时2分3秒") - """ - hours = int(seconds // 3600) - minutes = int((seconds % 3600) // 60) - secs = int(seconds % 60) - parts = [] - if hours > 0: - parts.append(f"{hours}小时") - if minutes > 0: - parts.append(f"{minutes}分") - if secs > 0 or not parts: - parts.append(f"{secs}秒") - return "".join(parts) - - async def _main_chat_loop(self): - """ - 主聊天循环 - - 功能说明: - - 持续运行聊天处理循环 - - 只有在有新消息时才进行思考循环 - - 无新消息时等待新消息到达(由主动思考系统单独处理主动发言) - - 处理取消和异常情况 - - 在异常时尝试重新启动循环 - """ - try: - while self.context.running: - has_new_messages = await self._loop_body() - - if has_new_messages: - # 有新消息时,继续快速检查是否还有更多消息 - await asyncio.sleep(1) - else: - # 无新消息时,等待较长时间再检查 - # 这里只是为了定期检查系统状态,不进行思考循环 - # 真正的新消息响应依赖于消息到达时的通知 - await asyncio.sleep(1.0) - - except asyncio.CancelledError: - logger.info(f"{self.context.log_prefix} 麦麦已关闭聊天") - except Exception: - logger.error(f"{self.context.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动") - print(traceback.format_exc()) - await asyncio.sleep(3) - self._loop_task = asyncio.create_task(self._main_chat_loop()) - logger.error(f"{self.context.log_prefix} 结束了当前聊天循环") - - async def _loop_body(self) -> bool: - """ - 单次循环体处理 - - Returns: - bool: 是否处理了新消息 - - 功能说明: - - 检查是否处于睡眠模式,如果是则处理唤醒度逻辑 - - 获取最近的新消息(过滤机器人自己的消息和命令) - - 只有在有新消息时才进行思考循环处理 - - 更新最后消息时间和读取时间 - - 根据当前聊天模式执行不同的处理逻辑 - - FOCUS模式:直接处理所有消息并检查退出条件 - - NORMAL模式:检查进入FOCUS模式的条件,并通过normal_mode_handler处理消息 - """ - async with self._processing_lock: - # --- 核心状态更新 --- - await self.sleep_manager.update_sleep_state(self.wakeup_manager) - current_sleep_state = self.sleep_manager.get_current_sleep_state() - is_sleeping = current_sleep_state == SleepState.SLEEPING - is_in_insomnia = current_sleep_state == SleepState.INSOMNIA - - # 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收 - filter_command_flag = not (is_sleeping or is_in_insomnia) - - # 从队列中获取所有待处理的新消息 - recent_messages = [] - while not self.context.new_message_queue.empty(): - recent_messages.append(await self.context.new_message_queue.get()) - - has_new_messages = bool(recent_messages) - new_message_count = len(recent_messages) - - # 只有在有新消息时才进行思考循环处理 - if has_new_messages: - self.context.last_message_time = time.time() - self.context.last_read_time = time.time() - - # --- 专注模式安静群组检查 --- - quiet_groups = global_config.chat.focus_mode_quiet_groups - if quiet_groups and self.context.chat_stream: - is_group_chat = self.context.chat_stream.group_info is not None - if is_group_chat: - try: - platform = self.context.chat_stream.platform - group_id = self.context.chat_stream.group_info.group_id - - # 兼容不同QQ适配器的平台名称 - is_qq_platform = platform in ["qq", "napcat"] - - current_chat_identifier = f"{platform}:{group_id}" - config_identifier_for_qq = f"qq:{group_id}" - - is_in_quiet_list = (current_chat_identifier in quiet_groups or - (is_qq_platform and config_identifier_for_qq in quiet_groups)) - - if is_in_quiet_list: - is_mentioned_in_batch = False - for msg in recent_messages: - if msg.get("is_mentioned"): - is_mentioned_in_batch = True - break - - if not is_mentioned_in_batch: - logger.info(f"{self.context.log_prefix} 在专注安静模式下,因未被提及而忽略了消息。") - return True # 消耗消息但不做回复 - except Exception as e: - logger.error(f"{self.context.log_prefix} 检查专注安静群组时出错: {e}") - - # 处理唤醒度逻辑 - if current_sleep_state in [SleepState.SLEEPING, SleepState.PREPARING_SLEEP, SleepState.INSOMNIA]: - self._handle_wakeup_messages(recent_messages) - - # 再次获取最新状态,因为 handle_wakeup 可能导致状态变为 WOKEN_UP - current_sleep_state = self.sleep_manager.get_current_sleep_state() - - if current_sleep_state == SleepState.SLEEPING: - # 只有在纯粹的 SLEEPING 状态下才跳过消息处理 - return True - - if current_sleep_state == SleepState.WOKEN_UP: - logger.info(f"{self.context.log_prefix} 从睡眠中被唤醒,将处理积压的消息。") - - # 根据聊天模式处理新消息 - should_process, interest_value = await self._should_process_messages(recent_messages) - if not should_process: - # 消息数量不足或兴趣不够,等待 - await asyncio.sleep(0.5) - return True # Skip rest of the logic for this iteration - - # Messages should be processed - action_type = await self.cycle_processor.observe(interest_value=interest_value) - - # 尝试触发表达学习 - if self.context.expression_learner: - try: - await self.context.expression_learner.trigger_learning_for_chat() - except Exception as e: - logger.error(f"{self.context.log_prefix} 表达学习触发失败: {e}") - - # 管理no_reply计数器 - if action_type != "no_reply": - self.recent_interest_records.clear() - self.context.no_reply_consecutive = 0 - logger.debug(f"{self.context.log_prefix} 执行了{action_type}动作,重置no_reply计数器") - else: # action_type == "no_reply" - self.context.no_reply_consecutive += 1 - self._determine_form_type() - - # 在一轮动作执行完毕后,增加睡眠压力 - if self.context.energy_manager and global_config.sleep_system.enable_insomnia_system: - if action_type not in ["no_reply", "no_action"]: - self.context.energy_manager.increase_sleep_pressure() - - # 如果成功观察,增加能量值并重置累积兴趣值 - self.context.energy_value += 1 / global_config.chat.focus_value - # 重置累积兴趣值,因为消息已经被成功处理 - self.context.breaking_accumulated_interest = 0.0 - logger.info( - f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值" - ) - - # 更新上一帧的睡眠状态 - self.context.was_sleeping = is_sleeping - - # --- 重新入睡逻辑 --- - # 如果被吵醒了,并且在一定时间内没有新消息,则尝试重新入睡 - if self.sleep_manager.get_current_sleep_state() == SleepState.WOKEN_UP and not has_new_messages: - re_sleep_delay = global_config.sleep_system.re_sleep_delay_minutes * 60 - # 使用 last_message_time 来判断空闲时间 - if time.time() - self.context.last_message_time > re_sleep_delay: - logger.info( - f"{self.context.log_prefix} 已被唤醒且超过 {re_sleep_delay / 60} 分钟无新消息,尝试重新入睡。" - ) - self.sleep_manager.reset_sleep_state_after_wakeup() - - # 保存HFC上下文状态 - self.context.save_context_state() - return has_new_messages - - def _handle_wakeup_messages(self, messages): - """ - 处理休眠状态下的消息,累积唤醒度 - - Args: - messages: 消息列表 - - 功能说明: - - 区分私聊和群聊消息 - - 检查群聊消息是否艾特了机器人 - - 调用唤醒度管理器累积唤醒度 - - 如果达到阈值则唤醒并进入愤怒状态 - """ - if not self.wakeup_manager: - return - - is_private_chat = self.context.chat_stream.group_info is None if self.context.chat_stream else False - - for message in messages: - is_mentioned = False - - # 检查群聊消息是否艾特了机器人 - if not is_private_chat: - # 最终修复:直接使用消息对象中由上游处理好的 is_mention 字段。 - # 该字段在 message.py 的 MessageRecv._process_single_segment 中被设置。 - if message.get("is_mentioned"): - is_mentioned = True - - # 累积唤醒度 - woke_up = self.wakeup_manager.add_wakeup_value(is_private_chat, is_mentioned) - - if woke_up: - logger.info(f"{self.context.log_prefix} 被消息吵醒,进入愤怒状态!") - break - - def _determine_form_type(self) -> str: - """判断使用哪种形式的no_reply""" - # 检查是否启用breaking模式 - if not getattr(global_config.chat, "enable_breaking_mode", False): - logger.info(f"{self.context.log_prefix} breaking模式已禁用,使用waiting形式") - self.context.focus_energy = 1 - return "waiting" - - # 如果连续no_reply次数少于3次,使用waiting形式 - if self.context.no_reply_consecutive <= 3: - self.context.focus_energy = 1 - return "waiting" - else: - # 使用累积兴趣值而不是最近3次的记录 - total_interest = self.context.breaking_accumulated_interest - - # 计算调整后的阈值 - adjusted_threshold = 1 / global_config.chat.get_current_talk_frequency(self.context.stream_id) - - logger.info( - f"{self.context.log_prefix} 累积兴趣值: {total_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}" - ) - - # 如果累积兴趣值小于阈值,进入breaking形式 - if total_interest < adjusted_threshold: - logger.info(f"{self.context.log_prefix} 累积兴趣度不足,进入breaking形式") - self.context.focus_energy = random.randint(3, 6) - return "breaking" - else: - logger.info(f"{self.context.log_prefix} 累积兴趣度充足,使用waiting形式") - self.context.focus_energy = 1 - return "waiting" - - async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool, float]: - """ - 统一判断是否应该处理消息的函数 - 根据当前循环模式和消息内容决定是否继续处理 - """ - if not new_message: - return False, 0.0 - - # 计算平均兴趣值 - total_interest = 0.0 - message_count = 0 - for msg_dict in new_message: - interest_value = msg_dict.get("interest_value", 0.0) - if msg_dict.get("processed_plain_text", ""): - total_interest += interest_value - message_count += 1 - - avg_interest = total_interest / message_count if message_count > 0 else 0.0 - - logger.info(f"{self.context.log_prefix} 收到 {len(new_message)} 条新消息,立即处理!平均兴趣值: {avg_interest:.2f}") - return True, avg_interest diff --git a/src/chat/chat_loop/hfc_context.py b/src/chat/chat_loop/hfc_context.py deleted file mode 100644 index 67606de12..000000000 --- a/src/chat/chat_loop/hfc_context.py +++ /dev/null @@ -1,82 +0,0 @@ -import time -from typing import List, Optional, TYPE_CHECKING - -from src.chat.chat_loop.hfc_utils import CycleDetail -from src.chat.express.expression_learner import ExpressionLearner -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.chat.planner_actions.action_manager import ActionManager -from src.config.config import global_config -from src.person_info.relationship_builder_manager import RelationshipBuilder - -if TYPE_CHECKING: - pass - - -class HfcContext: - def __init__(self, chat_id: str): - """ - 初始化HFC聊天上下文 - - Args: - chat_id: 聊天ID标识符 - - 功能说明: - - 存储和管理单个聊天会话的所有状态信息 - - 包含聊天流、关系构建器、表达学习器等核心组件 - - 管理聊天模式、能量值、时间戳等关键状态 - - 提供循环历史记录和当前循环详情的存储 - - 集成唤醒度管理器,处理休眠状态下的唤醒机制 - - Raises: - ValueError: 如果找不到对应的聊天流 - """ - self.stream_id: str = chat_id - self.chat_stream: Optional[ChatStream] = get_chat_manager().get_stream(self.stream_id) - if not self.chat_stream: - raise ValueError(f"无法找到聊天流: {self.stream_id}") - - self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]" - - self.relationship_builder: Optional[RelationshipBuilder] = None - self.expression_learner: Optional[ExpressionLearner] = None - - self.energy_value = self.chat_stream.energy_value - self.sleep_pressure = self.chat_stream.sleep_pressure - self.was_sleeping = False # 用于检测睡眠状态的切换 - - self.last_message_time = time.time() - self.last_read_time = time.time() - 10 - - # 从聊天流恢复breaking累积兴趣值 - self.breaking_accumulated_interest = getattr(self.chat_stream, "breaking_accumulated_interest", 0.0) - - self.action_manager = ActionManager() - - self.running: bool = False - - self.history_loop: List[CycleDetail] = [] - self.cycle_counter = 0 - self.current_cycle_detail: Optional[CycleDetail] = None - - # 唤醒度管理器 - 延迟初始化以避免循环导入 - self.wakeup_manager: Optional["WakeUpManager"] = None - self.energy_manager: Optional["EnergyManager"] = None - self.sleep_manager: Optional["SleepManager"] = None - - # 从聊天流获取focus_energy,如果没有则使用配置文件中的值 - self.focus_energy = getattr(self.chat_stream, "focus_energy", global_config.chat.focus_value) - self.no_reply_consecutive = 0 - self.total_interest = 0.0 - # breaking形式下的累积兴趣值 - self.breaking_accumulated_interest = 0.0 - # 引用HeartFChatting实例,以便其他组件可以调用其方法 - self.chat_instance: "HeartFChatting" - - def save_context_state(self): - """将当前状态保存到聊天流""" - if self.chat_stream: - self.chat_stream.energy_value = self.energy_value - self.chat_stream.sleep_pressure = self.sleep_pressure - self.chat_stream.focus_energy = self.focus_energy - self.chat_stream.no_reply_consecutive = self.no_reply_consecutive - self.chat_stream.breaking_accumulated_interest = self.breaking_accumulated_interest diff --git a/src/chat/chat_loop/hfc_utils.py b/src/chat/chat_loop/hfc_utils.py deleted file mode 100644 index 32d31fd52..000000000 --- a/src/chat/chat_loop/hfc_utils.py +++ /dev/null @@ -1,172 +0,0 @@ -import time -from typing import Optional, Dict, Any, Union - -from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import get_chat_manager -from src.plugin_system.apis import send_api -from maim_message.message_base import GroupInfo - - -logger = get_logger("hfc") - - -class CycleDetail: - """ - 循环信息记录类 - - 功能说明: - - 记录单次思考循环的详细信息 - - 包含循环ID、思考ID、时间戳等基本信息 - - 存储循环的规划信息和动作信息 - - 提供序列化和转换功能 - """ - - def __init__(self, cycle_id: Union[int, str]): - """ - 初始化循环详情记录 - - Args: - cycle_id: 循环ID,用于标识循环的顺序 - - 功能说明: - - 设置循环基本标识信息 - - 初始化时间戳和计时器 - - 准备循环信息存储容器 - """ - self.cycle_id = cycle_id - self.thinking_id = "" - self.start_time = time.time() - self.end_time: Optional[float] = None - self.timers: Dict[str, float] = {} - - self.loop_plan_info: Dict[str, Any] = {} - self.loop_action_info: Dict[str, Any] = {} - - def to_dict(self) -> Dict[str, Any]: - """ - 将循环信息转换为字典格式 - - Returns: - dict: 包含所有循环信息的字典,已处理循环引用和序列化问题 - - 功能说明: - - 递归转换复杂对象为可序列化格式 - - 防止循环引用导致的无限递归 - - 限制递归深度避免栈溢出 - - 只保留基本数据类型和可序列化的值 - """ - - def convert_to_serializable(obj, depth=0, seen=None): - if seen is None: - seen = set() - - # 防止递归过深 - if depth > 5: # 降低递归深度限制 - return str(obj) - - # 防止循环引用 - obj_id = id(obj) - if obj_id in seen: - return str(obj) - seen.add(obj_id) - - try: - if hasattr(obj, "to_dict"): - # 对于有to_dict方法的对象,直接调用其to_dict方法 - return obj.to_dict() - elif isinstance(obj, dict): - # 对于字典,只保留基本类型和可序列化的值 - return { - k: convert_to_serializable(v, depth + 1, seen) - for k, v in obj.items() - if isinstance(k, (str, int, float, bool)) - } - elif isinstance(obj, (list, tuple)): - # 对于列表和元组,只保留可序列化的元素 - return [ - convert_to_serializable(item, depth + 1, seen) - for item in obj - if not isinstance(item, (dict, list, tuple)) - or isinstance(item, (str, int, float, bool, type(None))) - ] - elif isinstance(obj, (str, int, float, bool, type(None))): - return obj - else: - return str(obj) - finally: - seen.remove(obj_id) - - return { - "cycle_id": self.cycle_id, - "start_time": self.start_time, - "end_time": self.end_time, - "timers": self.timers, - "thinking_id": self.thinking_id, - "loop_plan_info": convert_to_serializable(self.loop_plan_info), - "loop_action_info": convert_to_serializable(self.loop_action_info), - } - - def set_loop_info(self, loop_info: Dict[str, Any]): - """ - 设置循环信息 - - Args: - loop_info: 包含循环规划和动作信息的字典 - - 功能说明: - - 从传入的循环信息中提取规划和动作信息 - - 更新当前循环详情的相关字段 - """ - self.loop_plan_info = loop_info["loop_plan_info"] - self.loop_action_info = loop_info["loop_action_info"] - - -async def send_typing(user_id): - """ - 发送打字状态指示 - - 功能说明: - - 创建内心聊天流(用于状态显示) - - 发送typing状态消息 - - 不存储到消息记录中 - - 用于S4U功能的视觉反馈 - """ - group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心") - - chat = await get_chat_manager().get_or_create_stream( - platform="amaidesu_default", - user_info=None, - group_info=group_info, - ) - - from plugin_system.core.event_manager import event_manager - from src.plugins.built_in.napcat_adapter_plugin.event_types import NapcatEvent - # 设置正在输入状态 - await event_manager.trigger_event(NapcatEvent.PERSONAL.SET_INPUT_STATUS,user_id=user_id,event_type=1) - - await send_api.custom_to_stream( - message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False - ) - - -async def stop_typing(): - """ - 停止打字状态指示 - - 功能说明: - - 创建内心聊天流(用于状态显示) - - 发送stop_typing状态消息 - - 不存储到消息记录中 - - 结束S4U功能的视觉反馈 - """ - group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心") - - chat = await get_chat_manager().get_or_create_stream( - platform="amaidesu_default", - user_info=None, - group_info=group_info, - ) - - await send_api.custom_to_stream( - message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False - ) diff --git a/src/chat/chat_loop/proactive/events.py b/src/chat/chat_loop/proactive/events.py deleted file mode 100644 index 89a3bc7bb..000000000 --- a/src/chat/chat_loop/proactive/events.py +++ /dev/null @@ -1,14 +0,0 @@ -from dataclasses import dataclass, field -from typing import Optional, Dict, Any - - -@dataclass -class ProactiveTriggerEvent: - """ - 主动思考触发事件的数据类 - """ - - source: str # 触发源的标识,例如 "silence_monitor", "insomnia_manager" - reason: str # 触发的具体原因,例如 "聊天已沉默10分钟", "深夜emo" - metadata: Optional[Dict[str, Any]] = field(default_factory=dict) # 可选的元数据,用于传递额外信息 - related_message_id: Optional[str] = None # 关联的消息ID,用于加载上下文 diff --git a/src/chat/chat_loop/proactive/proactive_thinker.py b/src/chat/chat_loop/proactive/proactive_thinker.py deleted file mode 100644 index 34abf7803..000000000 --- a/src/chat/chat_loop/proactive/proactive_thinker.py +++ /dev/null @@ -1,264 +0,0 @@ -import time -import traceback -from typing import TYPE_CHECKING, Dict, Any - -from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages_with_id -from src.common.database.sqlalchemy_database_api import store_action_info -from src.common.logger import get_logger -from src.config.config import global_config -from src.mood.mood_manager import mood_manager -from src.plugin_system import tool_api -from src.plugin_system.apis import generator_api -from src.plugin_system.apis.generator_api import process_human_text -from src.plugin_system.base.component_types import ChatMode -from src.schedule.schedule_manager import schedule_manager -from .events import ProactiveTriggerEvent -from ..hfc_context import HfcContext - -if TYPE_CHECKING: - from ..cycle_processor import CycleProcessor - -logger = get_logger("hfc") - - -class ProactiveThinker: - """ - 主动思考器,负责处理和执行主动思考事件。 - 当接收到 ProactiveTriggerEvent 时,它会根据事件内容进行一系列决策和操作, - 例如调整情绪、调用规划器生成行动,并最终可能产生一个主动的回复。 - """ - - def __init__(self, context: HfcContext, cycle_processor: "CycleProcessor"): - """ - 初始化主动思考器。 - - Args: - context (HfcContext): HFC聊天上下文对象,提供了当前聊天会话的所有背景信息。 - cycle_processor (CycleProcessor): 循环处理器,用于执行主动思考后产生的动作。 - - 功能说明: - - 接收并处理主动思考事件 (ProactiveTriggerEvent)。 - - 在思考前根据事件类型执行预处理操作,如修改当前情绪状态。 - - 调用行动规划器 (Action Planner) 来决定下一步应该做什么。 - - 如果规划结果是发送消息,则调用生成器API生成回复并发送。 - """ - self.context = context - self.cycle_processor = cycle_processor - - async def think(self, trigger_event: ProactiveTriggerEvent): - """ - 主动思考的统一入口API。 - 这是外部触发主动思考时调用的主要方法。 - - Args: - trigger_event (ProactiveTriggerEvent): 描述触发上下文的事件对象,包含了思考的来源和原因。 - """ - logger.info( - f"{self.context.log_prefix} 接收到主动思考事件: " - f"来源='{trigger_event.source}', 原因='{trigger_event.reason}'" - ) - - try: - # 步骤 1: 根据事件类型执行思考前的准备工作,例如调整情绪。 - await self._prepare_for_thinking(trigger_event) - - # 步骤 2: 执行核心的思考和决策逻辑。 - await self._execute_proactive_thinking(trigger_event) - - except Exception as e: - # 捕获并记录在思考过程中发生的任何异常。 - logger.error(f"{self.context.log_prefix} 主动思考 think 方法执行异常: {e}") - logger.error(traceback.format_exc()) - - async def _prepare_for_thinking(self, trigger_event: ProactiveTriggerEvent): - """ - 根据事件类型,在正式思考前执行准备工作。 - 目前主要是处理来自失眠管理器的事件,并据此调整情绪。 - - Args: - trigger_event (ProactiveTriggerEvent): 触发事件。 - """ - # 目前只处理来自失眠管理器(insomnia_manager)的事件 - if trigger_event.source != "insomnia_manager": - return - - try: - # 获取当前聊天的情绪对象 - mood_obj = mood_manager.get_mood_by_chat_id(self.context.stream_id) - new_mood = None - - # 根据失眠的不同原因设置对应的情绪 - if trigger_event.reason == "low_pressure": - new_mood = "精力过剩,毫无睡意" - elif trigger_event.reason == "random": - new_mood = "深夜emo,胡思乱想" - elif trigger_event.reason == "goodnight": - new_mood = "有点困了,准备睡觉了" - - # 如果成功匹配到了新的情绪,则更新情绪状态 - if new_mood: - mood_obj.mood_state = new_mood - mood_obj.last_change_time = time.time() - logger.info( - f"{self.context.log_prefix} 因 '{trigger_event.reason}'," - f"情绪状态被强制更新为: {mood_obj.mood_state}" - ) - - except Exception as e: - logger.error(f"{self.context.log_prefix} 设置失眠情绪时出错: {e}") - - async def _execute_proactive_thinking(self, trigger_event: ProactiveTriggerEvent): - """ - 执行主动思考的核心逻辑。 - 它会调用规划器来决定是否要采取行动,以及采取什么行动。 - - Args: - trigger_event (ProactiveTriggerEvent): 触发事件。 - """ - try: - actions, _ = await self.cycle_processor.action_planner.plan(mode=ChatMode.PROACTIVE) - action_result = actions[0] if actions else {} - action_type = action_result.get("action_type") - - if action_type is None: - logger.info(f"{self.context.log_prefix} 主动思考决策: 规划器未返回有效动作") - return - - if action_type == "proactive_reply": - await self._generate_proactive_content_and_send(action_result, trigger_event) - elif action_type not in ["do_nothing", "no_action"]: - await self.cycle_processor._handle_action( - action=action_result["action_type"], - reasoning=action_result.get("reasoning", ""), - action_data=action_result.get("action_data", {}), - cycle_timers={}, - thinking_id="", - action_message=action_result.get("action_message") - ) - else: - logger.info(f"{self.context.log_prefix} 主动思考决策: 保持沉默") - - except Exception as e: - logger.error(f"{self.context.log_prefix} 主动思考执行异常: {e}") - logger.error(traceback.format_exc()) - - - async def _generate_proactive_content_and_send(self, action_result: Dict[str, Any], trigger_event: ProactiveTriggerEvent): - """ - 获取实时信息,构建最终的生成提示词,并生成和发送主动回复。 - - Args: - action_result (Dict[str, Any]): 规划器返回的动作结果。 - trigger_event (ProactiveTriggerEvent): 触发事件。 - """ - try: - topic = action_result.get("action_data", {}).get("topic", "随便聊聊") - logger.info(f"{self.context.log_prefix} 主动思考确定主题: '{topic}'") - - schedule_block = "你今天没有日程安排。" - if global_config.planning_system.schedule_enable: - if current_activity := schedule_manager.get_current_activity(): - schedule_block = f"你当前正在:{current_activity}。" - - news_block = "暂时没有获取到最新资讯。" - if trigger_event.source != "reminder_system": - try: - web_search_tool = tool_api.get_tool_instance("web_search") - if web_search_tool: - try: - search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10}) - except TypeError: - try: - search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10}) - except TypeError: - logger.warning(f"{self.context.log_prefix} 网络搜索工具参数不匹配,跳过搜索") - news_block = "跳过网络搜索。" - search_result_dict = None - - if search_result_dict and not search_result_dict.get("error"): - news_block = search_result_dict.get("content", "未能提取有效资讯。") - elif search_result_dict: - logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}") - else: - logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。") - except Exception as e: - logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}") - message_list = get_raw_msg_before_timestamp_with_chat( - chat_id=self.context.stream_id, - timestamp=time.time(), - limit=int(global_config.chat.max_context_size * 0.3), - ) - chat_context_block, _ = await build_readable_messages_with_id(messages=message_list) - bot_name = global_config.bot.nickname - personality = global_config.personality - identity_block = ( - f"你的名字是{bot_name}。\n" - f"关于你:{personality.personality_core},并且{personality.personality_side}。\n" - f"你的身份是{personality.identity},平时说话风格是{personality.reply_style}。" - ) - mood_block = f"你现在的心情是:{mood_manager.get_mood_by_chat_id(self.context.stream_id).mood_state}" - - final_prompt = f""" -## 你的角色 -{identity_block} - -## 你的心情 -{mood_block} - -## 你今天的日程安排 -{schedule_block} - -## 关于你准备讨论的话题"{topic}"的最新信息 -{news_block} - -## 最近的聊天内容 -{chat_context_block} - -## 任务 -你现在想要主动说些什么。话题是"{topic}",但这只是一个参考方向。 - -根据最近的聊天内容,你可以: -- 如果是想关心朋友,就自然地询问他们的情况 -- 如果想起了之前的话题,就问问后来怎么样了 -- 如果有什么想分享的想法,就自然地开启话题 -- 如果只是想闲聊,就随意地说些什么 - -**重要**:如果获取到了最新的网络信息(news_block不为空),请**自然地**将这些信息融入你的回复中,作为话题的补充或引子,而不是生硬地复述。 - -## 要求 -- 像真正的朋友一样,自然地表达关心或好奇 -- 不要过于正式,要口语化和亲切 -- 结合你的角色设定,保持温暖的风格 -- 直接输出你想说的话,不要解释为什么要说 - -请输出一条简短、自然的主动发言。 -""" - - response_text = await generator_api.generate_response_custom( - chat_stream=self.context.chat_stream, - prompt=final_prompt, - request_type="chat.replyer.proactive", - ) - - if response_text: - response_set = process_human_text( - content=response_text, - enable_splitter=global_config.response_splitter.enable, - enable_chinese_typo=global_config.chinese_typo.enable, - ) - await self.cycle_processor.response_handler.send_response( - response_set, time.time(), action_result.get("action_message") - ) - await store_action_info( - chat_stream=self.context.chat_stream, - action_name="proactive_reply", - action_data={"topic": topic, "response": response_text}, - action_prompt_display=f"主动发起对话: {topic}", - action_done=True, - ) - else: - logger.error(f"{self.context.log_prefix} 主动思考生成回复失败。") - - except Exception as e: - logger.error(f"{self.context.log_prefix} 生成主动回复内容时异常: {e}") - logger.error(traceback.format_exc()) diff --git a/src/chat/chat_loop/response_handler.py b/src/chat/chat_loop/response_handler.py deleted file mode 100644 index 99f065319..000000000 --- a/src/chat/chat_loop/response_handler.py +++ /dev/null @@ -1,184 +0,0 @@ -import time -import random -from typing import Dict, Any, Tuple - -from src.common.logger import get_logger -from src.plugin_system.apis import send_api, message_api, database_api -from src.person_info.person_info import get_person_info_manager -from .hfc_context import HfcContext - -# 导入反注入系统 - -# 日志记录器 -logger = get_logger("hfc") -anti_injector_logger = get_logger("anti_injector") - - -class ResponseHandler: - """ - 响应处理器类,负责生成和发送机器人的回复。 - """ - def __init__(self, context: HfcContext): - """ - 初始化响应处理器 - - Args: - context: HFC聊天上下文对象 - - 功能说明: - - 负责生成和发送机器人的回复 - - 处理回复的格式化和发送逻辑 - - 管理回复状态和日志记录 - """ - self.context = context - - async def generate_and_send_reply( - self, - response_set, - reply_to_str, - loop_start_time, - action_message, - cycle_timers: Dict[str, float], - thinking_id, - plan_result, - ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: - """ - 生成并发送回复的主方法 - - Args: - response_set: 生成的回复内容集合 - reply_to_str: 回复目标字符串 - loop_start_time: 循环开始时间 - action_message: 动作消息数据 - cycle_timers: 循环计时器 - thinking_id: 思考ID - plan_result: 规划结果 - - Returns: - tuple: (循环信息, 回复文本, 计时器信息) - - 功能说明: - - 发送生成的回复内容 - - 存储动作信息到数据库 - - 构建并返回完整的循环信息 - - 用于上级方法的状态跟踪 - """ - reply_text = await self.send_response(response_set, loop_start_time, action_message) - - person_info_manager = get_person_info_manager() - - # 获取平台信息 - platform = "default" - if self.context.chat_stream: - platform = ( - action_message.get("chat_info_platform") - or action_message.get("user_platform") - or self.context.chat_stream.platform - ) - - # 获取用户信息并生成回复提示 - user_id = action_message.get("user_id", "") - person_id = person_info_manager.get_person_id(platform, user_id) - person_name = await person_info_manager.get_value(person_id, "person_name") - action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" - - # 存储动作信息到数据库 - await database_api.store_action_info( - chat_stream=self.context.chat_stream, - action_build_into_prompt=False, - action_prompt_display=action_prompt_display, - action_done=True, - thinking_id=thinking_id, - action_data={"reply_text": reply_text, "reply_to": reply_to_str}, - action_name="reply", - ) - - # 构建循环信息 - loop_info: Dict[str, Any] = { - "loop_plan_info": { - "action_result": plan_result.get("action_result", {}), - }, - "loop_action_info": { - "action_taken": True, - "reply_text": reply_text, - "command": "", - "taken_time": time.time(), - }, - } - - return loop_info, reply_text, cycle_timers - - async def send_response(self, reply_set, thinking_start_time, message_data) -> str: - """ - 发送回复内容的具体实现 - - Args: - reply_set: 回复内容集合,包含多个回复段 - reply_to: 回复目标 - thinking_start_time: 思考开始时间 - message_data: 消息数据 - - Returns: - str: 完整的回复文本 - - 功能说明: - - 检查是否有新消息需要回复 - - 处理主动思考的"沉默"决定 - - 根据消息数量决定是否添加回复引用 - - 逐段发送回复内容,支持打字效果 - - 正确处理元组格式的回复段 - """ - current_time = time.time() - # 计算新消息数量 - new_message_count = await message_api.count_new_messages( - chat_id=self.context.stream_id, start_time=thinking_start_time, end_time=current_time - ) - - # 根据新消息数量决定是否需要引用回复 - need_reply = new_message_count >= random.randint(2, 4) - - reply_text = "" - is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True - - first_replied = False - for reply_seg in reply_set: - # 调试日志:验证reply_seg的格式 - logger.debug(f"Processing reply_seg type: {type(reply_seg)}, content: {reply_seg}") - - # 修正:正确处理元组格式 (格式为: (type, content)) - if isinstance(reply_seg, tuple) and len(reply_seg) >= 2: - _, data = reply_seg - else: - # 向下兼容:如果已经是字符串,则直接使用 - data = str(reply_seg) - - if isinstance(data, list): - data = "".join(map(str, data)) - reply_text += data - - # 如果是主动思考且内容为“沉默”,则不发送 - if is_proactive_thinking and data.strip() == "沉默": - logger.info(f"{self.context.log_prefix} 主动思考决定保持沉默,不发送消息") - continue - - # 发送第一段回复 - if not first_replied: - await send_api.text_to_stream( - text=data, - stream_id=self.context.stream_id, - reply_to_message=message_data, - set_reply=need_reply, - typing=False, - ) - first_replied = True - else: - # 发送后续回复 - sent_message = await send_api.text_to_stream( - text=data, - stream_id=self.context.stream_id, - reply_to_message=None, - set_reply=False, - typing=True, - ) - - return reply_text diff --git a/src/chat/chat_loop/sleep_manager/notification_sender.py b/src/chat/chat_loop/sleep_manager/notification_sender.py deleted file mode 100644 index 95ee304e9..000000000 --- a/src/chat/chat_loop/sleep_manager/notification_sender.py +++ /dev/null @@ -1,32 +0,0 @@ -from src.common.logger import get_logger -from ..hfc_context import HfcContext - -logger = get_logger("notification_sender") - - -class NotificationSender: - @staticmethod - async def send_goodnight_notification(context: HfcContext): - """发送晚安通知""" - try: - from ..proactive.events import ProactiveTriggerEvent - from ..proactive.proactive_thinker import ProactiveThinker - - event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight") - proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor) - await proactive_thinker.think(event) - except Exception as e: - logger.error(f"发送晚安通知失败: {e}") - - @staticmethod - async def send_insomnia_notification(context: HfcContext, reason: str): - """发送失眠通知""" - try: - from ..proactive.events import ProactiveTriggerEvent - from ..proactive.proactive_thinker import ProactiveThinker - - event = ProactiveTriggerEvent(source="sleep_manager", reason=reason) - proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor) - await proactive_thinker.think(event) - except Exception as e: - logger.error(f"发送失眠通知失败: {e}") \ No newline at end of file diff --git a/src/chat/chat_loop/sleep_manager/sleep_state.py b/src/chat/chat_loop/sleep_manager/sleep_state.py deleted file mode 100644 index 624521ea0..000000000 --- a/src/chat/chat_loop/sleep_manager/sleep_state.py +++ /dev/null @@ -1,110 +0,0 @@ -from enum import Enum, auto -from datetime import datetime -from src.common.logger import get_logger -from src.manager.local_store_manager import local_storage - -logger = get_logger("sleep_state") - - -class SleepState(Enum): - """ - 定义了角色可能处于的几种睡眠状态。 - 这是一个状态机,用于管理角色的睡眠周期。 - """ - - AWAKE = auto() # 清醒状态 - INSOMNIA = auto() # 失眠状态 - PREPARING_SLEEP = auto() # 准备入睡状态,一个短暂的过渡期 - SLEEPING = auto() # 正在睡觉状态 - WOKEN_UP = auto() # 被吵醒状态 - - -class SleepStateSerializer: - """ - 睡眠状态序列化器。 - 负责将内存中的睡眠状态对象持久化到本地存储(如JSON文件), - 以及在程序启动时从本地存储中恢复状态。 - 这样可以确保即使程序重启,角色的睡眠状态也能得以保留。 - """ - @staticmethod - def save(state_data: dict): - """ - 将当前的睡眠状态数据保存到本地存储。 - - Args: - state_data (dict): 包含睡眠状态信息的字典。 - datetime对象会被转换为时间戳,Enum成员会被转换为其名称字符串。 - """ - try: - # 准备要序列化的数据字典 - state = { - # 保存当前状态的枚举名称 - "current_state": state_data["_current_state"].name, - # 将datetime对象转换为Unix时间戳以便序列化 - "sleep_buffer_end_time_ts": state_data["_sleep_buffer_end_time"].timestamp() - if state_data["_sleep_buffer_end_time"] - else None, - "total_delayed_minutes_today": state_data["_total_delayed_minutes_today"], - # 将date对象转换为ISO格式的字符串 - "last_sleep_check_date_str": state_data["_last_sleep_check_date"].isoformat() - if state_data["_last_sleep_check_date"] - else None, - "re_sleep_attempt_time_ts": state_data["_re_sleep_attempt_time"].timestamp() - if state_data["_re_sleep_attempt_time"] - else None, - } - # 写入本地存储 - local_storage["schedule_sleep_state"] = state - logger.debug(f"已保存睡眠状态: {state}") - except Exception as e: - logger.error(f"保存睡眠状态失败: {e}") - - @staticmethod - def load() -> dict: - """ - 从本地存储加载并解析睡眠状态。 - - Returns: - dict: 包含恢复后睡眠状态信息的字典。 - 如果加载失败或没有找到数据,则返回一个默认的清醒状态。 - """ - # 定义一个默认的状态,以防加载失败 - state_data = { - "_current_state": SleepState.AWAKE, - "_sleep_buffer_end_time": None, - "_total_delayed_minutes_today": 0, - "_last_sleep_check_date": None, - "_re_sleep_attempt_time": None, - } - try: - # 从本地存储读取数据 - state = local_storage["schedule_sleep_state"] - if state and isinstance(state, dict): - # 恢复当前状态枚举 - state_name = state.get("current_state") - if state_name and hasattr(SleepState, state_name): - state_data["_current_state"] = SleepState[state_name] - - # 从时间戳恢复datetime对象 - end_time_ts = state.get("sleep_buffer_end_time_ts") - if end_time_ts: - state_data["_sleep_buffer_end_time"] = datetime.fromtimestamp(end_time_ts) - - # 恢复重新入睡尝试时间 - re_sleep_ts = state.get("re_sleep_attempt_time_ts") - if re_sleep_ts: - state_data["_re_sleep_attempt_time"] = datetime.fromtimestamp(re_sleep_ts) - - # 恢复今日延迟睡眠总分钟数 - state_data["_total_delayed_minutes_today"] = state.get("total_delayed_minutes_today", 0) - - # 从ISO格式字符串恢复date对象 - date_str = state.get("last_sleep_check_date_str") - if date_str: - state_data["_last_sleep_check_date"] = datetime.fromisoformat(date_str).date() - - logger.info(f"成功从本地存储加载睡眠状态: {state}") - except Exception as e: - # 如果加载过程中出现任何问题,记录警告并返回默认状态 - logger.warning(f"加载睡眠状态失败,将使用默认值: {e}") - return state_data \ No newline at end of file diff --git a/src/chat/chat_loop/sleep_manager/wakeup_manager.py b/src/chat/chat_loop/sleep_manager/wakeup_manager.py deleted file mode 100644 index 28c91dd3d..000000000 --- a/src/chat/chat_loop/sleep_manager/wakeup_manager.py +++ /dev/null @@ -1,232 +0,0 @@ -import asyncio -import time -from typing import Optional -from src.common.logger import get_logger -from src.config.config import global_config -from src.manager.local_store_manager import local_storage -from ..hfc_context import HfcContext - -logger = get_logger("wakeup") - - -class WakeUpManager: - def __init__(self, context: HfcContext): - """ - 初始化唤醒度管理器 - - Args: - context: HFC聊天上下文对象 - - 功能说明: - - 管理休眠状态下的唤醒度累积 - - 处理唤醒度的自然衰减 - - 控制愤怒状态的持续时间 - """ - self.context = context - self.wakeup_value = 0.0 # 当前唤醒度 - self.is_angry = False # 是否处于愤怒状态 - self.angry_start_time = 0.0 # 愤怒状态开始时间 - self.last_decay_time = time.time() # 上次衰减时间 - self._decay_task: Optional[asyncio.Task] = None - self.last_log_time = 0 - self.log_interval = 30 - - # 从配置文件获取参数 - sleep_config = global_config.sleep_system - self.wakeup_threshold = sleep_config.wakeup_threshold - self.private_message_increment = sleep_config.private_message_increment - self.group_mention_increment = sleep_config.group_mention_increment - self.decay_rate = sleep_config.decay_rate - self.decay_interval = sleep_config.decay_interval - self.angry_duration = sleep_config.angry_duration - self.enabled = sleep_config.enable - self.angry_prompt = sleep_config.angry_prompt - - self._load_wakeup_state() - - def _get_storage_key(self) -> str: - """获取当前聊天流的本地存储键""" - return f"wakeup_manager_state_{self.context.stream_id}" - - def _load_wakeup_state(self): - """从本地存储加载状态""" - state = local_storage[self._get_storage_key()] - if state and isinstance(state, dict): - self.wakeup_value = state.get("wakeup_value", 0.0) - self.is_angry = state.get("is_angry", False) - self.angry_start_time = state.get("angry_start_time", 0.0) - logger.info(f"{self.context.log_prefix} 成功从本地存储加载唤醒状态: {state}") - else: - logger.info(f"{self.context.log_prefix} 未找到本地唤醒状态,将使用默认值初始化。") - - def _save_wakeup_state(self): - """将当前状态保存到本地存储""" - state = { - "wakeup_value": self.wakeup_value, - "is_angry": self.is_angry, - "angry_start_time": self.angry_start_time, - } - local_storage[self._get_storage_key()] = state - logger.debug(f"{self.context.log_prefix} 已将唤醒状态保存到本地存储: {state}") - - async def start(self): - """启动唤醒度管理器""" - if not self.enabled: - logger.info(f"{self.context.log_prefix} 唤醒度系统已禁用,跳过启动") - return - - if not self._decay_task: - self._decay_task = asyncio.create_task(self._decay_loop()) - self._decay_task.add_done_callback(self._handle_decay_completion) - logger.info(f"{self.context.log_prefix} 唤醒度管理器已启动") - - async def stop(self): - """停止唤醒度管理器""" - if self._decay_task and not self._decay_task.done(): - self._decay_task.cancel() - await asyncio.sleep(0) - logger.info(f"{self.context.log_prefix} 唤醒度管理器已停止") - - def _handle_decay_completion(self, task: asyncio.Task): - """处理衰减任务完成""" - try: - if exception := task.exception(): - logger.error(f"{self.context.log_prefix} 唤醒度衰减任务异常: {exception}") - else: - logger.info(f"{self.context.log_prefix} 唤醒度衰减任务正常结束") - except asyncio.CancelledError: - logger.info(f"{self.context.log_prefix} 唤醒度衰减任务被取消") - - async def _decay_loop(self): - """唤醒度衰减循环""" - while self.context.running: - await asyncio.sleep(self.decay_interval) - - current_time = time.time() - - # 检查愤怒状态是否过期 - if self.is_angry and current_time - self.angry_start_time >= self.angry_duration: - self.is_angry = False - # 通知情绪管理系统清除愤怒状态 - from src.mood.mood_manager import mood_manager - - mood_manager.clear_angry_from_wakeup(self.context.stream_id) - logger.info(f"{self.context.log_prefix} 愤怒状态结束,恢复正常") - self._save_wakeup_state() - - # 唤醒度自然衰减 - if self.wakeup_value > 0: - old_value = self.wakeup_value - self.wakeup_value = max(0, self.wakeup_value - self.decay_rate) - if old_value != self.wakeup_value: - logger.debug(f"{self.context.log_prefix} 唤醒度衰减: {old_value:.1f} -> {self.wakeup_value:.1f}") - self._save_wakeup_state() - - def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False) -> bool: - """ - 增加唤醒度值 - - Args: - is_private_chat: 是否为私聊 - is_mentioned: 是否被艾特(仅群聊有效) - - Returns: - bool: 是否达到唤醒阈值 - """ - # 如果系统未启用,直接返回 - if not self.enabled: - return False - - # 只有在休眠且非失眠状态下才累积唤醒度 - from .sleep_state import SleepState - - sleep_manager = self.context.sleep_manager - if not sleep_manager: - return False - - current_sleep_state = sleep_manager.get_current_sleep_state() - if current_sleep_state != SleepState.SLEEPING: - return False - - old_value = self.wakeup_value - - if is_private_chat: - # 私聊每条消息都增加唤醒度 - self.wakeup_value += self.private_message_increment - logger.debug(f"{self.context.log_prefix} 私聊消息增加唤醒度: +{self.private_message_increment}") - elif is_mentioned: - # 群聊只有被艾特才增加唤醒度 - self.wakeup_value += self.group_mention_increment - logger.debug(f"{self.context.log_prefix} 群聊艾特增加唤醒度: +{self.group_mention_increment}") - else: - # 群聊未被艾特,不增加唤醒度 - return False - - current_time = time.time() - if current_time - self.last_log_time > self.log_interval: - logger.info( - f"{self.context.log_prefix} 唤醒度变化: {old_value:.1f} -> {self.wakeup_value:.1f} (阈值: {self.wakeup_threshold})" - ) - self.last_log_time = current_time - else: - logger.debug( - f"{self.context.log_prefix} 唤醒度变化: {old_value:.1f} -> {self.wakeup_value:.1f} (阈值: {self.wakeup_threshold})" - ) - - # 检查是否达到唤醒阈值 - if self.wakeup_value >= self.wakeup_threshold: - self._trigger_wakeup() - return True - - self._save_wakeup_state() - return False - - def _trigger_wakeup(self): - """触发唤醒,进入愤怒状态""" - self.is_angry = True - self.angry_start_time = time.time() - self.wakeup_value = 0.0 # 重置唤醒度 - - self._save_wakeup_state() - - # 通知情绪管理系统进入愤怒状态 - from src.mood.mood_manager import mood_manager - - mood_manager.set_angry_from_wakeup(self.context.stream_id) - - # 通知SleepManager重置睡眠状态 - if self.context.sleep_manager: - self.context.sleep_manager.reset_sleep_state_after_wakeup() - - logger.info(f"{self.context.log_prefix} 唤醒度达到阈值({self.wakeup_threshold}),被吵醒进入愤怒状态!") - - def get_angry_prompt_addition(self) -> str: - """获取愤怒状态下的提示词补充""" - if self.is_angry: - return self.angry_prompt - return "" - - def is_in_angry_state(self) -> bool: - """检查是否处于愤怒状态""" - if self.is_angry: - current_time = time.time() - if current_time - self.angry_start_time >= self.angry_duration: - self.is_angry = False - # 通知情绪管理系统清除愤怒状态 - from src.mood.mood_manager import mood_manager - - mood_manager.clear_angry_from_wakeup(self.context.stream_id) - logger.info(f"{self.context.log_prefix} 愤怒状态自动过期") - return False - return self.is_angry - - def get_status_info(self) -> dict: - """获取当前状态信息""" - return { - "wakeup_value": self.wakeup_value, - "wakeup_threshold": self.wakeup_threshold, - "is_angry": self.is_angry, - "angry_remaining_time": max(0, self.angry_duration - (time.time() - self.angry_start_time)) - if self.is_angry - else 0, - } diff --git a/src/chat/chatter_manager.py b/src/chat/chatter_manager.py new file mode 100644 index 000000000..be70f4969 --- /dev/null +++ b/src/chat/chatter_manager.py @@ -0,0 +1,145 @@ +from typing import Dict, List, Optional, Any +import time +from src.plugin_system.base.base_chatter import BaseChatter +from src.common.data_models.message_manager_data_model import StreamContext +from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner +from src.chat.planner_actions.action_manager import ChatterActionManager +from src.plugin_system.base.component_types import ChatType, ComponentType +from src.common.logger import get_logger + +logger = get_logger("chatter_manager") + +class ChatterManager: + def __init__(self, action_manager: ChatterActionManager): + self.action_manager = action_manager + self.chatter_classes: Dict[ChatType, List[type]] = {} + self.instances: Dict[str, BaseChatter] = {} + + # 管理器统计 + self.stats = { + "chatters_registered": 0, + "streams_processed": 0, + "successful_executions": 0, + "failed_executions": 0, + } + + def _auto_register_from_component_registry(self): + """从组件注册表自动注册已注册的chatter组件""" + try: + from src.plugin_system.core.component_registry import component_registry + # 获取所有CHATTER类型的组件 + chatter_components = component_registry.get_enabled_chatter_registry() + for chatter_name, chatter_class in chatter_components.items(): + self.register_chatter(chatter_class) + logger.info(f"自动注册chatter组件: {chatter_name}") + except Exception as e: + logger.warning(f"自动注册chatter组件时发生错误: {e}") + + def register_chatter(self, chatter_class: type): + """注册聊天处理器类""" + for chat_type in chatter_class.chat_types: + if chat_type not in self.chatter_classes: + self.chatter_classes[chat_type] = [] + self.chatter_classes[chat_type].append(chatter_class) + logger.info(f"注册聊天处理器 {chatter_class.__name__} 支持 {chat_type.value} 聊天类型") + + self.stats["chatters_registered"] += 1 + + def get_chatter_class(self, chat_type: ChatType) -> Optional[type]: + """获取指定聊天类型的聊天处理器类""" + if chat_type in self.chatter_classes: + return self.chatter_classes[chat_type][0] + return None + + def get_supported_chat_types(self) -> List[ChatType]: + """获取支持的聊天类型列表""" + return list(self.chatter_classes.keys()) + + def get_registered_chatters(self) -> Dict[ChatType, List[type]]: + """获取已注册的聊天处理器""" + return self.chatter_classes.copy() + + def get_stream_instance(self, stream_id: str) -> Optional[BaseChatter]: + """获取指定流的聊天处理器实例""" + return self.instances.get(stream_id) + + def cleanup_inactive_instances(self, max_inactive_minutes: int = 60): + """清理不活跃的实例""" + current_time = time.time() + max_inactive_seconds = max_inactive_minutes * 60 + + inactive_streams = [] + for stream_id, instance in self.instances.items(): + if hasattr(instance, 'get_activity_time'): + activity_time = instance.get_activity_time() + if (current_time - activity_time) > max_inactive_seconds: + inactive_streams.append(stream_id) + + for stream_id in inactive_streams: + del self.instances[stream_id] + logger.info(f"清理不活跃聊天流实例: {stream_id}") + + async def process_stream_context(self, stream_id: str, context: StreamContext) -> dict: + """处理流上下文""" + chat_type = context.chat_type + logger.debug(f"处理流 {stream_id},聊天类型: {chat_type.value}") + if not self.chatter_classes: + self._auto_register_from_component_registry() + + # 获取适合该聊天类型的chatter + chatter_class = self.get_chatter_class(chat_type) + if not chatter_class: + # 如果没有找到精确匹配,尝试查找支持ALL类型的chatter + from src.plugin_system.base.component_types import ChatType + all_chatter_class = self.get_chatter_class(ChatType.ALL) + if all_chatter_class: + chatter_class = all_chatter_class + logger.info(f"流 {stream_id} 使用通用chatter (类型: {chat_type.value})") + else: + raise ValueError(f"No chatter registered for chat type {chat_type}") + + if stream_id not in self.instances: + self.instances[stream_id] = chatter_class(stream_id=stream_id, action_manager=self.action_manager) + logger.info(f"创建新的聊天流实例: {stream_id} 使用 {chatter_class.__name__} (类型: {chat_type.value})") + + self.stats["streams_processed"] += 1 + try: + result = await self.instances[stream_id].execute(context) + self.stats["successful_executions"] += 1 + + # 从 mood_manager 获取最新的 chat_stream 并同步回 StreamContext + try: + from src.mood.mood_manager import mood_manager + mood = mood_manager.get_mood_by_chat_id(stream_id) + if mood and mood.chat_stream: + context.chat_stream = mood.chat_stream + logger.debug(f"已将最新的 chat_stream 同步回流 {stream_id} 的 StreamContext") + except Exception as sync_e: + logger.error(f"同步 chat_stream 回 StreamContext 失败: {sync_e}") + + # 记录处理结果 + success = result.get("success", False) + actions_count = result.get("actions_count", 0) + logger.debug(f"流 {stream_id} 处理完成: 成功={success}, 动作数={actions_count}") + + return result + except Exception as e: + self.stats["failed_executions"] += 1 + logger.error(f"处理流 {stream_id} 时发生错误: {e}") + raise + + def get_stats(self) -> Dict[str, Any]: + """获取管理器统计信息""" + stats = self.stats.copy() + stats["active_instances"] = len(self.instances) + stats["registered_chatter_types"] = len(self.chatter_classes) + return stats + + def reset_stats(self): + """重置统计信息""" + self.stats = { + "chatters_registered": 0, + "streams_processed": 0, + "successful_executions": 0, + "failed_executions": 0, + } \ No newline at end of file diff --git a/src/chat/emoji_system/emoji_history.py b/src/chat/emoji_system/emoji_history.py index d0e2ca856..804f61e0a 100644 --- a/src/chat/emoji_system/emoji_history.py +++ b/src/chat/emoji_system/emoji_history.py @@ -2,8 +2,10 @@ """ 表情包发送历史记录模块 """ -from collections import deque + +import os from typing import List, Dict +from collections import deque from src.common.logger import get_logger @@ -25,15 +27,15 @@ def add_emoji_to_history(chat_id: str, emoji_description: str): """ if not chat_id or not emoji_description: return - + # 如果当前聊天还没有历史记录,则创建一个新的 deque if chat_id not in _history_cache: _history_cache[chat_id] = deque(maxlen=MAX_HISTORY_SIZE) - + # 添加新表情到历史记录 history = _history_cache[chat_id] history.append(emoji_description) - + logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中") @@ -49,10 +51,10 @@ def get_recent_emojis(chat_id: str, limit: int = 5) -> List[str]: return [] history = _history_cache[chat_id] - + # 从 deque 的右侧(即最近添加的)开始取 num_to_get = min(limit, len(history)) recent_emojis = [history[-i] for i in range(1, num_to_get + 1)] - + logger.debug(f"为聊天 {chat_id} 从内存中获取到最近 {len(recent_emojis)} 个表情: {recent_emojis}") return recent_emojis diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index e2a6eb7f1..b614345f0 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -149,7 +149,7 @@ class MaiEmoji: # --- 数据库操作 --- try: # 准备数据库记录 for emoji collection - async with get_db_session() as session: + with get_db_session() as session: emotion_str = ",".join(self.emotion) if self.emotion else "" emoji = Emoji( @@ -167,7 +167,7 @@ class MaiEmoji: last_used_time=self.last_used_time, ) session.add(emoji) - await session.commit() + session.commit() logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") @@ -203,17 +203,17 @@ class MaiEmoji: # 2. 删除数据库记录 try: - async with get_db_session() as session: - will_delete_emoji = ( - await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)) + with get_db_session() as session: + will_delete_emoji = session.execute( + select(Emoji).where(Emoji.emoji_hash == self.hash) ).scalar_one_or_none() if will_delete_emoji is None: logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") - result = 0 + result = 0 # Indicate no DB record was deleted else: - await session.delete(will_delete_emoji) - result = 1 - await session.commit() + session.delete(will_delete_emoji) + result = 1 # Successfully deleted one record + session.commit() except Exception as e: logger.error(f"[错误] 删除数据库记录时出错: {str(e)}") result = 0 @@ -424,19 +424,17 @@ class EmojiManager: # if not self._initialized: # raise RuntimeError("EmojiManager not initialized") - @staticmethod - async def record_usage(emoji_hash: str) -> None: + def record_usage(self, emoji_hash: str) -> None: """记录表情使用次数""" try: - async with get_db_session() as session: - emoji_update = ( - await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)) - ).scalar_one_or_none() + with get_db_session() as session: + emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none() if emoji_update is None: logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") else: emoji_update.usage_count += 1 - emoji_update.last_used_time = time.time() + emoji_update.last_used_time = time.time() # Update last used time + session.commit() except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") @@ -479,7 +477,7 @@ class EmojiManager: emoji_options_str = "" for i, emoji in enumerate(candidate_emojis): # 为每个表情包创建一个编号和它的详细描述 - emoji_options_str += f"编号: {i+1}\n描述: {emoji.description}\n\n" + emoji_options_str += f"编号: {i + 1}\n描述: {emoji.description}\n\n" # 精心设计的prompt,引导LLM做出选择 prompt = f""" @@ -523,13 +521,11 @@ class EmojiManager: # 7. 获取选中的表情包并更新使用记录 selected_emoji = candidate_emojis[selected_index] - await self.record_usage(selected_emoji.emoji_hash) + self.record_usage(selected_emoji.hash) _time_end = time.time() - logger.info( - f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s" - ) - + logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s") + # 8. 返回选中的表情包信息 return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion @@ -629,8 +625,9 @@ class EmojiManager: # 无论steal_emoji是否开启,都检查emoji文件夹以支持手动注册 # 只有在需要腾出空间或填充表情库时,才真正执行注册 - if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or \ - (self.emoji_num < self.emoji_num_max): + if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or ( + self.emoji_num < self.emoji_num_max + ): try: # 获取目录下所有图片文件 files_to_process = [ @@ -660,11 +657,10 @@ class EmojiManager: async def get_all_emoji_from_db(self) -> None: """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects""" try: - async with get_db_session() as session: + with get_db_session() as session: logger.debug("[数据库] 开始加载所有表情包记录 ...") - result = await session.execute(select(Emoji)) - emoji_instances = result.scalars().all() + emoji_instances = session.execute(select(Emoji)).scalars().all() emoji_objects, load_errors = _to_emoji_objects(emoji_instances) # 更新内存中的列表和数量 @@ -680,8 +676,7 @@ class EmojiManager: self.emoji_objects = [] # 加载失败则清空列表 self.emoji_num = 0 - @staticmethod - async def get_emoji_from_db(emoji_hash: Optional[str] = None) -> List["MaiEmoji"]: + async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]: """获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找) 参数: @@ -691,16 +686,14 @@ class EmojiManager: list[MaiEmoji]: 表情包对象列表 """ try: - async with get_db_session() as session: + with get_db_session() as session: if emoji_hash: - result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)) - query = result.scalars().all() + query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all() else: logger.warning( "[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。" ) - result = await session.execute(select(Emoji)) - query = result.scalars().all() + query = session.execute(select(Emoji)).scalars().all() emoji_instances = query emoji_objects, load_errors = _to_emoji_objects(emoji_instances) @@ -748,8 +741,8 @@ class EmojiManager: try: emoji_record = await self.get_emoji_from_db(emoji_hash) if emoji_record and emoji_record[0].emotion: - logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record[0].emotion[:50]}...") - return emoji_record[0].emotion + logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...") + return emoji_record.emotion except Exception as e: logger.error(f"从数据库查询表情包描述时出错: {e}") @@ -777,11 +770,10 @@ class EmojiManager: # 如果内存中没有,从数据库查找 try: - async with get_db_session() as session: - result = await session.execute( + with get_db_session() as session: + emoji_record = session.execute( select(Emoji).where(Emoji.emoji_hash == emoji_hash) - ) - emoji_record = result.scalar_one_or_none() + ).scalar_one_or_none() if emoji_record and emoji_record.description: logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...") return emoji_record.description @@ -938,19 +930,21 @@ class EmojiManager: image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() if Image.open(io.BytesIO(image_bytes)).format else "jpeg" - + image_format = ( + Image.open(io.BytesIO(image_bytes)).format.lower() + if Image.open(io.BytesIO(image_bytes)).format + else "jpeg" + ) # 2. 检查数据库中是否已存在该表情包的描述,实现复用 existing_description = None try: - async with get_db_session() as session: - result = await session.execute( - select(Images).filter( - (Images.emoji_hash == image_hash) & (Images.type == "emoji") - ) + with get_db_session() as session: + existing_image = ( + session.query(Images) + .filter((Images.emoji_hash == image_hash) & (Images.type == "emoji")) + .one_or_none() ) - existing_image = result.scalar_one_or_none() if existing_image and existing_image.description: existing_description = existing_image.description logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...") diff --git a/src/chat/energy_system/__init__.py b/src/chat/energy_system/__init__.py new file mode 100644 index 000000000..0addfd070 --- /dev/null +++ b/src/chat/energy_system/__init__.py @@ -0,0 +1,28 @@ +""" +能量系统模块 +提供稳定、高效的聊天流能量计算和管理功能 +""" + +from .energy_manager import ( + EnergyManager, + EnergyLevel, + EnergyComponent, + EnergyCalculator, + InterestEnergyCalculator, + ActivityEnergyCalculator, + RecencyEnergyCalculator, + RelationshipEnergyCalculator, + energy_manager +) + +__all__ = [ + "EnergyManager", + "EnergyLevel", + "EnergyComponent", + "EnergyCalculator", + "InterestEnergyCalculator", + "ActivityEnergyCalculator", + "RecencyEnergyCalculator", + "RelationshipEnergyCalculator", + "energy_manager" +] \ No newline at end of file diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py new file mode 100644 index 000000000..8ee2017cb --- /dev/null +++ b/src/chat/energy_system/energy_manager.py @@ -0,0 +1,473 @@ +""" +重构后的 focus_energy 管理系统 +提供稳定、高效的聊天流能量计算和管理功能 +""" + +import time +from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict +from dataclasses import dataclass, field +from enum import Enum +from abc import ABC, abstractmethod + +from src.common.logger import get_logger +from src.config.config import global_config + +logger = get_logger("energy_system") + + +class EnergyLevel(Enum): + """能量等级""" + VERY_LOW = 0.1 # 非常低 + LOW = 0.3 # 低 + NORMAL = 0.5 # 正常 + HIGH = 0.7 # 高 + VERY_HIGH = 0.9 # 非常高 + + +@dataclass +class EnergyComponent: + """能量组件""" + name: str + value: float + weight: float = 1.0 + decay_rate: float = 0.05 # 衰减率 + last_updated: float = field(default_factory=time.time) + + def get_current_value(self) -> float: + """获取当前值(考虑时间衰减)""" + age = time.time() - self.last_updated + decay_factor = max(0.1, 1.0 - (age * self.decay_rate / (24 * 3600))) # 按天衰减 + return self.value * decay_factor + + def update_value(self, new_value: float) -> None: + """更新值""" + self.value = max(0.0, min(1.0, new_value)) + self.last_updated = time.time() + + +class EnergyContext(TypedDict): + """能量计算上下文""" + stream_id: str + messages: List[Any] + user_id: Optional[str] + + +class EnergyResult(TypedDict): + """能量计算结果""" + energy: float + level: EnergyLevel + distribution_interval: float + component_scores: Dict[str, float] + cached: bool + + +class EnergyCalculator(ABC): + """能量计算器抽象基类""" + + @abstractmethod + def calculate(self, context: Dict[str, Any]) -> float: + """计算能量值""" + pass + + @abstractmethod + def get_weight(self) -> float: + """获取权重""" + pass + + +class InterestEnergyCalculator(EnergyCalculator): + """兴趣度能量计算器""" + + def calculate(self, context: Dict[str, Any]) -> float: + """基于消息兴趣度计算能量""" + messages = context.get("messages", []) + if not messages: + return 0.3 + + # 计算平均兴趣度 + total_interest = 0.0 + valid_messages = 0 + + for msg in messages: + interest_value = getattr(msg, "interest_value", None) + if interest_value is not None: + try: + interest_float = float(interest_value) + if 0.0 <= interest_float <= 1.0: + total_interest += interest_float + valid_messages += 1 + except (ValueError, TypeError): + continue + + if valid_messages > 0: + avg_interest = total_interest / valid_messages + logger.debug(f"平均消息兴趣度: {avg_interest:.3f} (基于 {valid_messages} 条消息)") + return avg_interest + else: + return 0.3 + + def get_weight(self) -> float: + return 0.5 + + +class ActivityEnergyCalculator(EnergyCalculator): + """活跃度能量计算器""" + + def __init__(self): + self.action_weights = { + "reply": 0.4, + "react": 0.3, + "mention": 0.2, + "other": 0.1 + } + + def calculate(self, context: Dict[str, Any]) -> float: + """基于活跃度计算能量""" + messages = context.get("messages", []) + if not messages: + return 0.2 + + total_score = 0.0 + max_possible_score = len(messages) * 0.4 # 最高可能分数 + + for msg in messages: + actions = getattr(msg, "actions", []) + if isinstance(actions, list) and actions: + for action in actions: + weight = self.action_weights.get(action, self.action_weights["other"]) + total_score += weight + + if max_possible_score > 0: + activity_score = min(1.0, total_score / max_possible_score) + logger.debug(f"活跃度分数: {activity_score:.3f}") + return activity_score + else: + return 0.2 + + def get_weight(self) -> float: + return 0.3 + + +class RecencyEnergyCalculator(EnergyCalculator): + """最近性能量计算器""" + + def calculate(self, context: Dict[str, Any]) -> float: + """基于最近性计算能量""" + messages = context.get("messages", []) + if not messages: + return 0.1 + + # 获取最新消息时间 + latest_time = 0.0 + for msg in messages: + msg_time = getattr(msg, "time", None) + if msg_time and msg_time > latest_time: + latest_time = msg_time + + if latest_time == 0.0: + return 0.1 + + # 计算时间衰减 + current_time = time.time() + age = current_time - latest_time + + # 时间衰减策略: + # 1小时内:1.0 + # 1-6小时:0.8 + # 6-24小时:0.5 + # 1-7天:0.3 + # 7天以上:0.1 + if age < 3600: # 1小时内 + recency_score = 1.0 + elif age < 6 * 3600: # 6小时内 + recency_score = 0.8 + elif age < 24 * 3600: # 24小时内 + recency_score = 0.5 + elif age < 7 * 24 * 3600: # 7天内 + recency_score = 0.3 + else: + recency_score = 0.1 + + logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age/3600:.1f}小时)") + return recency_score + + def get_weight(self) -> float: + return 0.2 + + +class RelationshipEnergyCalculator(EnergyCalculator): + """关系能量计算器""" + + def calculate(self, context: Dict[str, Any]) -> float: + """基于关系计算能量""" + user_id = context.get("user_id") + if not user_id: + return 0.3 + + # 使用插件内部的兴趣度评分系统获取关系分 + try: + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system + + relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id) + logger.debug(f"使用插件内部系统计算关系分: {relationship_score:.3f}") + return max(0.0, min(1.0, relationship_score)) + + except Exception as e: + logger.warning(f"插件内部关系分计算失败,使用默认值: {e}") + return 0.3 # 默认基础分 + + def get_weight(self) -> float: + return 0.1 + + +class EnergyManager: + """能量管理器 - 统一管理所有能量计算""" + + def __init__(self) -> None: + self.calculators: List[EnergyCalculator] = [ + InterestEnergyCalculator(), + ActivityEnergyCalculator(), + RecencyEnergyCalculator(), + RelationshipEnergyCalculator(), + ] + + # 能量缓存 + self.energy_cache: Dict[str, Tuple[float, float]] = {} # stream_id -> (energy, timestamp) + self.cache_ttl: int = 60 # 1分钟缓存 + + # AFC阈值配置 + self.thresholds: Dict[str, float] = { + "high_match": 0.8, + "reply": 0.4, + "non_reply": 0.2 + } + + # 统计信息 + self.stats: Dict[str, Union[int, float, str]] = { + "total_calculations": 0, + "cache_hits": 0, + "cache_misses": 0, + "average_calculation_time": 0.0, + "last_threshold_update": time.time(), + } + + # 从配置加载阈值 + self._load_thresholds_from_config() + + logger.info("能量管理器初始化完成") + + def _load_thresholds_from_config(self) -> None: + """从配置加载AFC阈值""" + try: + if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None: + self.thresholds["high_match"] = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8) + self.thresholds["reply"] = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4) + self.thresholds["non_reply"] = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2) + + # 确保阈值关系合理 + self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1) + self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1) + + self.stats["last_threshold_update"] = time.time() + logger.info(f"加载AFC阈值: {self.thresholds}") + except Exception as e: + logger.warning(f"加载AFC阈值失败,使用默认值: {e}") + + def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float: + """计算聊天流的focus_energy""" + start_time = time.time() + + # 更新统计 + self.stats["total_calculations"] += 1 + + # 检查缓存 + if stream_id in self.energy_cache: + cached_energy, cached_time = self.energy_cache[stream_id] + if time.time() - cached_time < self.cache_ttl: + self.stats["cache_hits"] += 1 + logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}") + return cached_energy + else: + self.stats["cache_misses"] += 1 + + # 构建计算上下文 + context: EnergyContext = { + "stream_id": stream_id, + "messages": messages, + "user_id": user_id, + } + + # 计算各组件能量 + component_scores: Dict[str, float] = {} + total_weight = 0.0 + + for calculator in self.calculators: + try: + score = calculator.calculate(context) + weight = calculator.get_weight() + + component_scores[calculator.__class__.__name__] = score + total_weight += weight + + logger.debug(f"{calculator.__class__.__name__} 能量: {score:.3f} (权重: {weight:.3f})") + + except Exception as e: + logger.warning(f"计算 {calculator.__class__.__name__} 能量失败: {e}") + + # 加权计算总能量 + if total_weight > 0: + total_energy = 0.0 + for calculator in self.calculators: + if calculator.__class__.__name__ in component_scores: + score = component_scores[calculator.__class__.__name__] + weight = calculator.get_weight() + total_energy += score * (weight / total_weight) + else: + total_energy = 0.5 + + # 应用阈值调整和变换 + final_energy = self._apply_threshold_adjustment(total_energy) + + # 缓存结果 + self.energy_cache[stream_id] = (final_energy, time.time()) + + # 清理过期缓存 + self._cleanup_cache() + + # 更新平均计算时间 + calculation_time = time.time() - start_time + total_calculations = self.stats["total_calculations"] + self.stats["average_calculation_time"] = ( + (self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time) + / total_calculations + ) + + logger.info(f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)") + return final_energy + + def _apply_threshold_adjustment(self, energy: float) -> float: + """应用阈值调整和变换""" + # 获取参考阈值 + high_threshold = self.thresholds["high_match"] + reply_threshold = self.thresholds["reply"] + + # 计算与阈值的相对位置 + if energy >= high_threshold: + # 高能量区域:指数增强 + adjusted = 0.7 + (energy - 0.7) ** 0.8 + elif energy >= reply_threshold: + # 中等能量区域:线性保持 + adjusted = energy + else: + # 低能量区域:对数压缩 + adjusted = 0.4 * (energy / 0.4) ** 1.2 + + # 确保在合理范围内 + return max(0.1, min(1.0, adjusted)) + + def get_energy_level(self, energy: float) -> EnergyLevel: + """获取能量等级""" + if energy >= EnergyLevel.VERY_HIGH.value: + return EnergyLevel.VERY_HIGH + elif energy >= EnergyLevel.HIGH.value: + return EnergyLevel.HIGH + elif energy >= EnergyLevel.NORMAL.value: + return EnergyLevel.NORMAL + elif energy >= EnergyLevel.LOW.value: + return EnergyLevel.LOW + else: + return EnergyLevel.VERY_LOW + + def get_distribution_interval(self, energy: float) -> float: + """基于能量等级获取分发周期""" + energy_level = self.get_energy_level(energy) + + # 根据能量等级确定基础分发周期 + if energy_level == EnergyLevel.VERY_HIGH: + base_interval = 1.0 # 1秒 + elif energy_level == EnergyLevel.HIGH: + base_interval = 3.0 # 3秒 + elif energy_level == EnergyLevel.NORMAL: + base_interval = 8.0 # 8秒 + elif energy_level == EnergyLevel.LOW: + base_interval = 15.0 # 15秒 + else: + base_interval = 30.0 # 30秒 + + # 添加随机扰动避免同步 + import random + jitter = random.uniform(0.8, 1.2) + final_interval = base_interval * jitter + + # 确保在配置范围内 + min_interval = getattr(global_config.chat, "dynamic_distribution_min_interval", 1.0) + max_interval = getattr(global_config.chat, "dynamic_distribution_max_interval", 60.0) + + return max(min_interval, min(max_interval, final_interval)) + + def invalidate_cache(self, stream_id: str) -> None: + """失效指定流的缓存""" + if stream_id in self.energy_cache: + del self.energy_cache[stream_id] + logger.debug(f"已清除聊天流 {stream_id} 的能量缓存") + + def _cleanup_cache(self) -> None: + """清理过期缓存""" + current_time = time.time() + expired_keys = [ + stream_id for stream_id, (_, timestamp) in self.energy_cache.items() + if current_time - timestamp > self.cache_ttl + ] + + for key in expired_keys: + del self.energy_cache[key] + + if expired_keys: + logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存") + + def get_statistics(self) -> Dict[str, Any]: + """获取统计信息""" + return { + "cache_size": len(self.energy_cache), + "calculators": [calc.__class__.__name__ for calc in self.calculators], + "thresholds": self.thresholds, + "performance_stats": self.stats.copy(), + } + + def update_thresholds(self, new_thresholds: Dict[str, float]) -> None: + """更新阈值""" + self.thresholds.update(new_thresholds) + + # 确保阈值关系合理 + self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1) + self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1) + + self.stats["last_threshold_update"] = time.time() + logger.info(f"更新AFC阈值: {self.thresholds}") + + def add_calculator(self, calculator: EnergyCalculator) -> None: + """添加计算器""" + self.calculators.append(calculator) + logger.info(f"添加能量计算器: {calculator.__class__.__name__}") + + def remove_calculator(self, calculator: EnergyCalculator) -> None: + """移除计算器""" + if calculator in self.calculators: + self.calculators.remove(calculator) + logger.info(f"移除能量计算器: {calculator.__class__.__name__}") + + def clear_cache(self) -> None: + """清空缓存""" + self.energy_cache.clear() + logger.info("清空能量缓存") + + def get_cache_hit_rate(self) -> float: + """获取缓存命中率""" + total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0) + if total_requests == 0: + return 0.0 + return self.stats["cache_hits"] / total_requests + + +# 全局能量管理器实例 +energy_manager = EnergyManager() \ No newline at end of file diff --git a/src/chat/frequency_analyzer/analyzer.py b/src/chat/frequency_analyzer/analyzer.py index f888b9737..1493c47ea 100644 --- a/src/chat/frequency_analyzer/analyzer.py +++ b/src/chat/frequency_analyzer/analyzer.py @@ -14,6 +14,7 @@ Chat Frequency Analyzer - MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。 - MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。 """ + import time as time_module from datetime import datetime, timedelta, time from typing import List, Tuple, Optional @@ -72,12 +73,14 @@ class ChatFrequencyAnalyzer: current_window_end = datetimes[i] # 合并重叠或相邻的高峰时段 - if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(hours=MIN_GAP_BETWEEN_PEAKS_HOURS): + if peak_windows and current_window_start - peak_windows[-1][1] < timedelta( + hours=MIN_GAP_BETWEEN_PEAKS_HOURS + ): # 扩展上一个窗口的结束时间 peak_windows[-1] = (peak_windows[-1][0], current_window_end) else: peak_windows.append((current_window_start, current_window_end)) - + return peak_windows def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]: @@ -100,7 +103,7 @@ class ChatFrequencyAnalyzer: return [] peak_datetime_windows = self._find_peak_windows(timestamps) - + # 将 datetime 窗口转换为 time 窗口,并进行归一化处理 peak_time_windows = [] for start_dt, end_dt in peak_datetime_windows: @@ -110,7 +113,7 @@ class ChatFrequencyAnalyzer: # 更新缓存 self._analysis_cache[chat_id] = (time_module.time(), peak_time_windows) - + return peak_time_windows def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool: @@ -126,7 +129,7 @@ class ChatFrequencyAnalyzer: """ if now is None: now = datetime.now() - + now_time = now.time() peak_times = self.get_peak_chat_times(chat_id) @@ -137,7 +140,7 @@ class ChatFrequencyAnalyzer: else: # 跨天 if now_time >= start_time or now_time <= end_time: return True - + return False diff --git a/src/chat/frequency_analyzer/tracker.py b/src/chat/frequency_analyzer/tracker.py index 178435528..3621cb5b4 100644 --- a/src/chat/frequency_analyzer/tracker.py +++ b/src/chat/frequency_analyzer/tracker.py @@ -56,7 +56,7 @@ class ChatFrequencyTracker: now = time.time() if chat_id not in self._timestamps: self._timestamps[chat_id] = [] - + self._timestamps[chat_id].append(now) logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}") self._save_timestamps() diff --git a/src/chat/frequency_analyzer/trigger.py b/src/chat/frequency_analyzer/trigger.py index d62547306..2d8e8b56f 100644 --- a/src/chat/frequency_analyzer/trigger.py +++ b/src/chat/frequency_analyzer/trigger.py @@ -14,15 +14,16 @@ Frequency-Based Proactive Trigger - TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。 - COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。 """ + import asyncio import time from datetime import datetime from typing import Dict, Optional from src.common.logger import get_logger -from src.chat.chat_loop.proactive.events import ProactiveTriggerEvent -from src.chat.heart_flow.heartflow import heartflow -from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager +# AFC manager has been moved to chatter plugin + +# TODO: 需要重新实现主动思考和睡眠管理功能 from .analyzer import chat_frequency_analyzer logger = get_logger("FrequencyBasedTrigger") @@ -39,8 +40,8 @@ class FrequencyBasedTrigger: 一个周期性任务,根据聊天频率分析结果来触发主动思考。 """ - def __init__(self, sleep_manager: SleepManager): - self._sleep_manager = sleep_manager + def __init__(self): + # TODO: 需要重新实现睡眠管理器 self._task: Optional[asyncio.Task] = None # 记录上次为用户触发的时间,用于冷却控制 # 格式: { "chat_id": timestamp } @@ -53,19 +54,21 @@ class FrequencyBasedTrigger: await asyncio.sleep(TRIGGER_CHECK_INTERVAL_SECONDS) logger.debug("开始执行频率触发器检查...") - # 1. 检查角色是否清醒 - if self._sleep_manager.is_sleeping(): - logger.debug("角色正在睡眠,跳过本次频率触发检查。") - continue + # 1. TODO: 检查角色是否清醒 - 需要重新实现睡眠状态检查 + # 暂时跳过睡眠检查 + # if self._sleep_manager.is_sleeping(): + # logger.debug("角色正在睡眠,跳过本次频率触发检查。") + # continue # 2. 获取所有已知的聊天ID - # 【注意】这里我们假设所有 subheartflow 的 ID 就是 chat_id - all_chat_ids = list(heartflow.subheartflows.keys()) + # 注意:AFC管理器已移至chatter插件,此功能暂时禁用 + # all_chat_ids = list(afc_manager.affinity_flow_chatters.keys()) + all_chat_ids = [] # 暂时禁用此功能 if not all_chat_ids: continue now = datetime.now() - + for chat_id in all_chat_ids: # 3. 检查是否处于冷却时间内 last_triggered_time = self._last_triggered.get(chat_id, 0) @@ -74,29 +77,11 @@ class FrequencyBasedTrigger: # 4. 检查当前是否是该用户的高峰聊天时间 if chat_frequency_analyzer.is_in_peak_time(chat_id, now): - - sub_heartflow = await heartflow.get_or_create_subheartflow(chat_id) - if not sub_heartflow: - logger.warning(f"无法为 {chat_id} 获取或创建 sub_heartflow。") - continue - - # 5. 检查用户当前是否已有活跃的思考或回复任务 - cycle_detail = sub_heartflow.heart_fc_instance.context.current_cycle_detail - if cycle_detail and not cycle_detail.end_time: - logger.debug(f"用户 {chat_id} 的聊天循环正忙(仍在周期 {cycle_detail.cycle_id} 中),本次不触发。") - continue - - logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,且聊天循环空闲,准备触发主动思考。") - - # 6. 直接调用 proactive_thinker - event = ProactiveTriggerEvent( - source="frequency_analyzer", - reason="User is in a high-frequency chat period." - ) - await sub_heartflow.heart_fc_instance.proactive_thinker.think(event) - - # 7. 更新触发时间,进入冷却 - self._last_triggered[chat_id] = time.time() + # 5. 检查用户当前是否已有活跃的处理任务 + # 注意:AFC管理器已移至chatter插件,此功能暂时禁用 + # chatter = afc_manager.get_or_create_chatter(chat_id) + logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,但AFC功能已移至chatter插件") + continue except asyncio.CancelledError: logger.info("频率触发器任务被取消。") diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py deleted file mode 100644 index 111b37e64..000000000 --- a/src/chat/heart_flow/heartflow.py +++ /dev/null @@ -1,40 +0,0 @@ -import traceback -from typing import Any, Optional, Dict - -from src.common.logger import get_logger -from src.chat.heart_flow.sub_heartflow import SubHeartflow -from src.chat.message_receive.chat_stream import get_chat_manager - -logger = get_logger("heartflow") - - -class Heartflow: - """主心流协调器,负责初始化并协调聊天""" - - def __init__(self): - self.subheartflows: Dict[Any, "SubHeartflow"] = {} - - async def get_or_create_subheartflow(self, subheartflow_id: Any) -> Optional["SubHeartflow"]: - """获取或创建一个新的SubHeartflow实例""" - if subheartflow_id in self.subheartflows: - if subflow := self.subheartflows.get(subheartflow_id): - return subflow - - try: - new_subflow = SubHeartflow(subheartflow_id) - - await new_subflow.initialize() - - # 注册子心流 - self.subheartflows[subheartflow_id] = new_subflow - heartflow_name = get_chat_manager().get_stream_name(subheartflow_id) or subheartflow_id - logger.info(f"[{heartflow_name}] 开始接收消息") - - return new_subflow - except Exception as e: - logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True) - traceback.print_exc() - return None - - -heartflow = Heartflow() diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py deleted file mode 100644 index 958bc9096..000000000 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ /dev/null @@ -1,178 +0,0 @@ -import asyncio -import math -import re -import traceback -from typing import Tuple, TYPE_CHECKING - -from src.chat.heart_flow.heartflow import heartflow -from src.chat.memory_system.Hippocampus import hippocampus_manager -from src.chat.message_receive.message import MessageRecv -from src.chat.message_receive.storage import MessageStorage -from src.chat.utils.chat_message_builder import replace_user_references_sync -from src.chat.utils.timer_calculator import Timer -from src.chat.utils.utils import is_mentioned_bot_in_message -from src.common.logger import get_logger -from src.config.config import global_config -from src.mood.mood_manager import mood_manager -from src.person_info.relationship_manager import get_relationship_manager - -if TYPE_CHECKING: - from src.chat.heart_flow.sub_heartflow import SubHeartflow - -logger = get_logger("chat") - - -async def _process_relationship(message: MessageRecv) -> None: - """处理用户关系逻辑 - - Args: - message: 消息对象,包含用户信息 - """ - platform = message.message_info.platform - user_id = message.message_info.user_info.user_id # type: ignore - nickname = message.message_info.user_info.user_nickname # type: ignore - cardname = message.message_info.user_info.user_cardname or nickname # type: ignore - - relationship_manager = get_relationship_manager() - is_known = await relationship_manager.is_known_some_one(platform, user_id) - - if not is_known: - logger.info(f"首次认识用户: {nickname}") - await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore - - -async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool, list[str]]: - """计算消息的兴趣度 - - Args: - message: 待处理的消息对象 - - Returns: - Tuple[float, bool, list[str]]: (兴趣度, 是否被提及, 关键词) - """ - is_mentioned, _ = is_mentioned_bot_in_message(message) - interested_rate = 0.0 - - with Timer("记忆激活"): - interested_rate, keywords = await hippocampus_manager.get_activate_from_text( - message.processed_plain_text, - max_depth=4, - fast_retrieval=False, - ) - message.key_words = keywords - message.key_words_lite = keywords - logger.debug(f"记忆激活率: {interested_rate:.2f}, 关键词: {keywords}") - - text_len = len(message.processed_plain_text) - # 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算 - # 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%) - - if text_len == 0: - base_interest = 0.01 # 空消息最低兴趣度 - elif text_len <= 5: - # 1-5字符:线性增长 0.01 -> 0.03 - base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4 - elif text_len <= 10: - # 6-10字符:线性增长 0.03 -> 0.06 - base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5 - elif text_len <= 20: - # 11-20字符:线性增长 0.06 -> 0.12 - base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10 - elif text_len <= 30: - # 21-30字符:线性增长 0.12 -> 0.18 - base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10 - elif text_len <= 50: - # 31-50字符:线性增长 0.18 -> 0.22 - base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20 - elif text_len <= 100: - # 51-100字符:线性增长 0.22 -> 0.26 - base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50 - else: - # 100+字符:对数增长 0.26 -> 0.3,增长率递减 - base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901 - - # 确保在范围内 - base_interest = min(max(base_interest, 0.01), 0.3) - - interested_rate += base_interest - - if is_mentioned: - interest_increase_on_mention = 1 - interested_rate += interest_increase_on_mention - - return interested_rate, is_mentioned, keywords - - -class HeartFCMessageReceiver: - """心流处理器,负责处理接收到的消息并计算兴趣度""" - - def __init__(self): - """初始化心流处理器,创建消息存储实例""" - self.storage = MessageStorage() - - async def process_message(self, message: MessageRecv) -> None: - """处理接收到的原始消息数据 - - 主要流程: - 1. 消息解析与初始化 - 2. 消息缓冲处理 - 4. 过滤检查 - 5. 兴趣度计算 - 6. 关系处理 - - Args: - message_data: 原始消息字符串 - """ - try: - # 1. 消息解析与初始化 - userinfo = message.message_info.user_info - chat = message.chat_stream - - # 2. 兴趣度计算与更新 - interested_rate, is_mentioned, keywords = await _calculate_interest(message) - message.interest_value = interested_rate - message.is_mentioned = is_mentioned - - await self.storage.store_message(message, chat) - - subheartflow: SubHeartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) # type: ignore - - await subheartflow.heart_fc_instance.add_message(message.to_dict()) - if global_config.mood.enable_mood: - chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) - asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate)) - - # 3. 日志记录 - mes_name = chat.group_info.group_name if chat.group_info else "私聊" - # current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time)) - current_talk_frequency = global_config.chat.get_current_talk_frequency(chat.stream_id) - - # 如果消息中包含图片标识,则将 [picid:...] 替换为 [图片] - picid_pattern = r"\[picid:([^\]]+)\]" - processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text) - - # 应用用户引用格式替换,将回复和@格式转换为可读格式 - processed_plain_text = replace_user_references_sync( - processed_plain_text, - message.message_info.platform, # type: ignore - replace_bot_name=True, - ) - - if keywords: - logger.info( - f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}][关键词:{keywords}]" - ) # type: ignore - else: - logger.info( - f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}]" - ) # type: ignore - - logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]") - - # 4. 关系处理 - if global_config.relationship.enable_relationship: - await _process_relationship(message) - - except Exception as e: - logger.error(f"消息处理失败: {e}") - print(traceback.format_exc()) diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py deleted file mode 100644 index 136b1cb41..000000000 --- a/src/chat/heart_flow/sub_heartflow.py +++ /dev/null @@ -1,42 +0,0 @@ -from rich.traceback import install - -from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import get_chat_manager -from src.chat.chat_loop.heartFC_chat import HeartFChatting -from src.chat.utils.utils import get_chat_type_and_target_info - -logger = get_logger("sub_heartflow") - -install(extra_lines=3) - - -class SubHeartflow: - def __init__( - self, - subheartflow_id, - ): - """子心流初始化函数 - - Args: - subheartflow_id: 子心流唯一标识符 - """ - # 基础属性,两个值是一样的 - self.subheartflow_id = subheartflow_id - self.chat_id = subheartflow_id - - self.is_group_chat, self.chat_target_info = (None, None) - self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id - - # focus模式退出冷却时间管理 - self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间 - - # 随便水群 normal_chat 和 认真水群 focus_chat 实例 - # CHAT模式激活 随便水群 FOCUS模式激活 认真水群 - self.heart_fc_instance: HeartFChatting = HeartFChatting( - chat_id=self.subheartflow_id, - ) # 该sub_heartflow的HeartFChatting实例 - - async def initialize(self): - """异步初始化方法,创建兴趣流并确定聊天类型""" - self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_id) - await self.heart_fc_instance.start() diff --git a/src/chat/interest_system/__init__.py b/src/chat/interest_system/__init__.py new file mode 100644 index 000000000..e05cbeebf --- /dev/null +++ b/src/chat/interest_system/__init__.py @@ -0,0 +1,15 @@ +""" +兴趣度系统模块 +提供机器人兴趣标签和智能匹配功能 +""" + +from .bot_interest_manager import BotInterestManager, bot_interest_manager +from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult + +__all__ = [ + "BotInterestManager", + "bot_interest_manager", + "BotInterestTag", + "BotPersonalityInterests", + "InterestMatchResult", +] diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py new file mode 100644 index 000000000..be04dd065 --- /dev/null +++ b/src/chat/interest_system/bot_interest_manager.py @@ -0,0 +1,805 @@ +""" +机器人兴趣标签管理系统 +基于人设生成兴趣标签,并使用embedding计算匹配度 +""" + +import orjson +import traceback +from typing import List, Dict, Optional, Any +from datetime import datetime +import numpy as np + +from src.common.logger import get_logger +from src.config.config import global_config +from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult + +logger = get_logger("bot_interest_manager") + + +class BotInterestManager: + """机器人兴趣标签管理器""" + + def __init__(self): + self.current_interests: Optional[BotPersonalityInterests] = None + self.embedding_cache: Dict[str, List[float]] = {} # embedding缓存 + self._initialized = False + + # Embedding客户端配置 + self.embedding_request = None + self.embedding_config = None + self.embedding_dimension = 1024 # 默认BGE-M3 embedding维度 + + @property + def is_initialized(self) -> bool: + """检查兴趣系统是否已初始化""" + return self._initialized + + async def initialize(self, personality_description: str, personality_id: str = "default"): + """初始化兴趣标签系统""" + try: + logger.info("机器人兴趣系统开始初始化...") + logger.info(f"人设ID: {personality_id}, 描述长度: {len(personality_description)}") + + # 初始化embedding模型 + await self._initialize_embedding_model() + + # 检查embedding客户端是否成功初始化 + if not self.embedding_request: + raise RuntimeError("Embedding客户端初始化失败") + + # 生成或加载兴趣标签 + await self._load_or_generate_interests(personality_description, personality_id) + + self._initialized = True + + # 检查是否成功获取兴趣标签 + if self.current_interests and len(self.current_interests.get_active_tags()) > 0: + active_tags_count = len(self.current_interests.get_active_tags()) + logger.info("机器人兴趣系统初始化完成!") + logger.info(f"当前已激活 {active_tags_count} 个兴趣标签, Embedding缓存 {len(self.embedding_cache)} 个") + else: + raise RuntimeError("未能成功加载或生成兴趣标签") + + except Exception as e: + logger.error(f"机器人兴趣系统初始化失败: {e}") + traceback.print_exc() + raise # 重新抛出异常,不允许降级初始化 + + async def _initialize_embedding_model(self): + """初始化embedding模型""" + logger.info("🔧 正在配置embedding客户端...") + + # 使用项目配置的embedding模型 + from src.config.config import model_config + from src.llm_models.utils_model import LLMRequest + + logger.debug("✅ 成功导入embedding相关模块") + + # 检查embedding配置是否存在 + if not hasattr(model_config.model_task_config, "embedding"): + raise RuntimeError("❌ 未找到embedding模型配置") + + logger.info("📋 找到embedding模型配置") + self.embedding_config = model_config.model_task_config.embedding + self.embedding_dimension = 1024 # BGE-M3的维度 + logger.info(f"📐 使用模型维度: {self.embedding_dimension}") + + # 创建LLMRequest实例用于embedding + self.embedding_request = LLMRequest(model_set=self.embedding_config, request_type="interest_embedding") + logger.info("✅ Embedding请求客户端初始化成功") + logger.info(f"🔗 客户端类型: {type(self.embedding_request).__name__}") + + # 获取第一个embedding模型的ModelInfo + if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list: + first_model_name = self.embedding_config.model_list[0] + logger.info(f"🎯 使用embedding模型: {first_model_name}") + else: + logger.warning("⚠️ 未找到embedding模型列表") + + logger.info("✅ Embedding模型初始化完成") + + async def _load_or_generate_interests(self, personality_description: str, personality_id: str): + """加载或生成兴趣标签""" + logger.info(f"📚 正在为 '{personality_id}' 加载或生成兴趣标签...") + + # 首先尝试从数据库加载 + logger.info("尝试从数据库加载兴趣标签...") + loaded_interests = await self._load_interests_from_database(personality_id) + + if loaded_interests: + self.current_interests = loaded_interests + active_count = len(loaded_interests.get_active_tags()) + logger.info(f"成功从数据库加载 {active_count} 个兴趣标签 (版本: {loaded_interests.version})") + tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in loaded_interests.get_active_tags()] + tags_str = "\n".join(tags_info) + logger.info(f"当前兴趣标签:\n{tags_str}") + else: + # 生成新的兴趣标签 + logger.info("数据库中未找到兴趣标签,开始生成...") + generated_interests = await self._generate_interests_from_personality( + personality_description, personality_id + ) + + if generated_interests: + self.current_interests = generated_interests + active_count = len(generated_interests.get_active_tags()) + logger.info(f"成功生成 {active_count} 个新兴趣标签。") + tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in generated_interests.get_active_tags()] + tags_str = "\n".join(tags_info) + logger.info(f"当前兴趣标签:\n{tags_str}") + + # 保存到数据库 + logger.info("正在保存至数据库...") + await self._save_interests_to_database(generated_interests) + else: + raise RuntimeError("❌ 兴趣标签生成失败") + + async def _generate_interests_from_personality( + self, personality_description: str, personality_id: str + ) -> Optional[BotPersonalityInterests]: + """根据人设生成兴趣标签""" + try: + logger.info("🎨 开始根据人设生成兴趣标签...") + logger.info(f"📝 人设长度: {len(personality_description)} 字符") + + # 检查embedding客户端是否可用 + if not hasattr(self, "embedding_request"): + raise RuntimeError("❌ Embedding客户端未初始化,无法生成兴趣标签") + + # 构建提示词 + logger.info("📝 构建LLM提示词...") + prompt = f""" +基于以下机器人人设描述,生成一套合适的兴趣标签: + +人设描述: +{personality_description} + +请生成一系列兴趣关键词标签,要求: +1. 标签应该符合人设特点和性格 +2. 每个标签都有权重(0.1-1.0),表示对该兴趣的喜好程度 +3. 生成15-25个不等的标签 +4. 标签应该是具体的关键词,而不是抽象概念 + +请以JSON格式返回,格式如下: +{{ + "interests": [ + {{"name": "标签名", "weight": 0.8}}, + {{"name": "标签名", "weight": 0.6}}, + {{"name": "标签名", "weight": 0.9}} + ] +}} + +注意: +- 权重范围0.1-1.0,权重越高表示越感兴趣 +- 标签要具体,如"编程"、"游戏"、"旅行"等 +- 根据人设生成个性化的标签 +""" + + # 调用LLM生成兴趣标签 + logger.info("🤖 正在调用LLM生成兴趣标签...") + response = await self._call_llm_for_interest_generation(prompt) + + if not response: + raise RuntimeError("❌ LLM未返回有效响应") + + logger.info("✅ LLM响应成功,开始解析兴趣标签...") + interests_data = orjson.loads(response) + + bot_interests = BotPersonalityInterests( + personality_id=personality_id, personality_description=personality_description + ) + + # 解析生成的兴趣标签 + interests_list = interests_data.get("interests", []) + logger.info(f"📋 解析到 {len(interests_list)} 个兴趣标签") + + for i, tag_data in enumerate(interests_list): + tag_name = tag_data.get("name", f"标签_{i}") + weight = tag_data.get("weight", 0.5) + + tag = BotInterestTag(tag_name=tag_name, weight=weight) + bot_interests.interest_tags.append(tag) + + logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})") + + # 为所有标签生成embedding + logger.info("🧠 开始为兴趣标签生成embedding向量...") + await self._generate_embeddings_for_tags(bot_interests) + + logger.info("✅ 兴趣标签生成完成") + return bot_interests + + except orjson.JSONDecodeError as e: + logger.error(f"❌ 解析LLM响应JSON失败: {e}") + raise + except Exception as e: + logger.error(f"❌ 根据人设生成兴趣标签失败: {e}") + traceback.print_exc() + raise + + async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]: + """调用LLM生成兴趣标签""" + try: + logger.info("🔧 配置LLM客户端...") + + # 使用llm_api来处理请求 + from src.plugin_system.apis import llm_api + from src.config.config import model_config + + # 构建完整的提示词,明确要求只返回纯JSON + full_prompt = f"""你是一个专业的机器人人设分析师,擅长根据人设描述生成合适的兴趣标签。 + +{prompt} + +请确保返回格式为有效的JSON,不要包含任何额外的文本、解释或代码块标记。只返回JSON对象本身。""" + + # 使用replyer模型配置 + replyer_config = model_config.model_task_config.replyer + + # 调用LLM API + logger.info("🚀 正在通过LLM API发送请求...") + success, response, reasoning_content, model_name = await llm_api.generate_with_model( + prompt=full_prompt, + model_config=replyer_config, + request_type="interest_generation", + temperature=0.7, + max_tokens=2000, + ) + + if success and response: + logger.info(f"✅ LLM响应成功,模型: {model_name}, 响应长度: {len(response)} 字符") + logger.debug( + f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}" + ) + if reasoning_content: + logger.debug(f"🧠 推理内容: {reasoning_content[:100]}...") + + # 清理响应内容,移除可能的代码块标记 + cleaned_response = self._clean_llm_response(response) + return cleaned_response + else: + logger.warning("⚠️ LLM返回空响应或调用失败") + return None + + except Exception as e: + logger.error(f"❌ 调用LLM生成兴趣标签失败: {e}") + logger.error("🔍 错误详情:") + traceback.print_exc() + return None + + def _clean_llm_response(self, response: str) -> str: + """清理LLM响应,移除代码块标记和其他非JSON内容""" + import re + + # 移除 ```json 和 ``` 标记 + cleaned = re.sub(r"```json\s*", "", response) + cleaned = re.sub(r"\s*```", "", cleaned) + + # 移除可能的多余空格和换行 + cleaned = cleaned.strip() + + # 尝试提取JSON对象(如果响应中有其他文本) + json_match = re.search(r"\{.*\}", cleaned, re.DOTALL) + if json_match: + cleaned = json_match.group(0) + + logger.debug(f"🧹 清理后的响应: {cleaned[:200]}..." if len(cleaned) > 200 else f"🧹 清理后的响应: {cleaned}") + return cleaned + + async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests): + """为所有兴趣标签生成embedding""" + if not hasattr(self, "embedding_request"): + raise RuntimeError("❌ Embedding客户端未初始化,无法生成embedding") + + total_tags = len(interests.interest_tags) + logger.info(f"🧠 开始为 {total_tags} 个兴趣标签生成embedding向量...") + + cached_count = 0 + generated_count = 0 + failed_count = 0 + + for i, tag in enumerate(interests.interest_tags, 1): + if tag.tag_name in self.embedding_cache: + # 使用缓存的embedding + tag.embedding = self.embedding_cache[tag.tag_name] + cached_count += 1 + logger.debug(f" [{i}/{total_tags}] 🏷️ '{tag.tag_name}' - 使用缓存") + else: + # 生成新的embedding + embedding_text = tag.tag_name + + logger.debug(f" [{i}/{total_tags}] 🔄 正在为 '{tag.tag_name}' 生成embedding...") + embedding = await self._get_embedding(embedding_text) + + if embedding: + tag.embedding = embedding + self.embedding_cache[tag.tag_name] = embedding + generated_count += 1 + logger.debug(f" ✅ '{tag.tag_name}' embedding生成成功") + else: + failed_count += 1 + logger.warning(f" ❌ '{tag.tag_name}' embedding生成失败") + + if failed_count > 0: + raise RuntimeError(f"❌ 有 {failed_count} 个兴趣标签embedding生成失败") + + interests.last_updated = datetime.now() + logger.info("=" * 50) + logger.info("✅ Embedding生成完成!") + logger.info(f"📊 总标签数: {total_tags}") + logger.info(f"💾 缓存命中: {cached_count}") + logger.info(f"🆕 新生成: {generated_count}") + logger.info(f"❌ 失败: {failed_count}") + logger.info(f"🗃️ 总缓存大小: {len(self.embedding_cache)}") + logger.info("=" * 50) + + async def _get_embedding(self, text: str) -> List[float]: + """获取文本的embedding向量""" + if not hasattr(self, "embedding_request"): + raise RuntimeError("❌ Embedding请求客户端未初始化") + + # 检查缓存 + if text in self.embedding_cache: + logger.debug(f"💾 使用缓存的embedding: '{text[:30]}...'") + return self.embedding_cache[text] + + # 使用LLMRequest获取embedding + logger.debug(f"🔄 正在获取embedding: '{text[:30]}...'") + embedding, model_name = await self.embedding_request.get_embedding(text) + + if embedding and len(embedding) > 0: + self.embedding_cache[text] = embedding + logger.debug(f"✅ Embedding获取成功,维度: {len(embedding)}, 模型: {model_name}") + return embedding + else: + raise RuntimeError(f"❌ 返回的embedding为空: {embedding}") + + async def _generate_message_embedding(self, message_text: str, keywords: List[str]) -> List[float]: + """为消息生成embedding向量""" + # 组合消息文本和关键词作为embedding输入 + if keywords: + combined_text = f"{message_text} {' '.join(keywords)}" + else: + combined_text = message_text + + logger.debug(f"🔄 正在为消息生成embedding,输入长度: {len(combined_text)}") + + # 生成embedding + embedding = await self._get_embedding(combined_text) + logger.debug(f"✅ 消息embedding生成成功,维度: {len(embedding)}") + return embedding + + async def _calculate_similarity_scores( + self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str] + ): + """计算消息与兴趣标签的相似度分数""" + try: + if not self.current_interests: + return + + active_tags = self.current_interests.get_active_tags() + if not active_tags: + return + + logger.debug(f"🔍 开始计算与 {len(active_tags)} 个兴趣标签的相似度") + + for tag in active_tags: + if tag.embedding: + # 计算余弦相似度 + similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding) + weighted_score = similarity * tag.weight + + # 设置相似度阈值为0.3 + if similarity > 0.3: + result.add_match(tag.tag_name, weighted_score, keywords) + logger.debug( + f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}" + ) + + except Exception as e: + logger.error(f"❌ 计算相似度分数失败: {e}") + + async def calculate_interest_match(self, message_text: str, keywords: List[str] = None) -> InterestMatchResult: + """计算消息与机器人兴趣的匹配度""" + if not self.current_interests or not self._initialized: + raise RuntimeError("❌ 兴趣标签系统未初始化") + + logger.debug(f"开始计算兴趣匹配度: 消息长度={len(message_text)}, 关键词数={len(keywords) if keywords else 0}") + + message_id = f"msg_{datetime.now().timestamp()}" + result = InterestMatchResult(message_id=message_id) + + # 获取活跃的兴趣标签 + active_tags = self.current_interests.get_active_tags() + if not active_tags: + raise RuntimeError("没有检测到活跃的兴趣标签") + + logger.debug(f"正在与 {len(active_tags)} 个兴趣标签进行匹配...") + + # 生成消息的embedding + logger.debug("正在生成消息 embedding...") + message_embedding = await self._get_embedding(message_text) + logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}") + + # 计算与每个兴趣标签的相似度 + match_count = 0 + high_similarity_count = 0 + medium_similarity_count = 0 + low_similarity_count = 0 + + # 分级相似度阈值 + affinity_config = global_config.affinity_flow + high_threshold = affinity_config.high_match_interest_threshold + medium_threshold = affinity_config.medium_match_interest_threshold + low_threshold = affinity_config.low_match_interest_threshold + + logger.debug(f"🔍 使用分级相似度阈值: 高={high_threshold}, 中={medium_threshold}, 低={low_threshold}") + + for tag in active_tags: + if tag.embedding: + similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding) + + # 基础加权分数 + weighted_score = similarity * tag.weight + + # 根据相似度等级应用不同的加成 + if similarity > high_threshold: + # 高相似度:强加成 + enhanced_score = weighted_score * affinity_config.high_match_keyword_multiplier + match_count += 1 + high_similarity_count += 1 + result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) + + elif similarity > medium_threshold: + # 中相似度:中等加成 + enhanced_score = weighted_score * affinity_config.medium_match_keyword_multiplier + match_count += 1 + medium_similarity_count += 1 + result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) + + elif similarity > low_threshold: + # 低相似度:轻微加成 + enhanced_score = weighted_score * affinity_config.low_match_keyword_multiplier + match_count += 1 + low_similarity_count += 1 + result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) + + logger.debug( + f"匹配统计: {match_count}/{len(active_tags)} 个标签命中 | " + f"高(>{high_threshold}): {high_similarity_count}, " + f"中(>{medium_threshold}): {medium_similarity_count}, " + f"低(>{low_threshold}): {low_similarity_count}" + ) + + # 添加直接关键词匹配奖励 + keyword_bonus = self._calculate_keyword_match_bonus(keywords, result.matched_tags) + logger.debug(f"🎯 关键词直接匹配奖励: {keyword_bonus}") + + # 应用关键词奖励到匹配分数 + for tag_name in result.matched_tags: + if tag_name in keyword_bonus: + original_score = result.match_scores[tag_name] + bonus = keyword_bonus[tag_name] + result.match_scores[tag_name] = original_score + bonus + logger.debug( + f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}" + ) + + # 计算总体分数 + result.calculate_overall_score() + + # 确定最佳匹配标签 + if result.matched_tags: + top_tag_name = max(result.match_scores.items(), key=lambda x: x[1])[0] + result.top_tag = top_tag_name + logger.debug(f"最佳匹配: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})") + + logger.debug( + f"最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}" + ) + return result + + def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]: + """计算关键词直接匹配奖励""" + if not keywords or not matched_tags: + return {} + + affinity_config = global_config.affinity_flow + bonus_dict = {} + + for tag_name in matched_tags: + bonus = 0.0 + + # 检查关键词与标签的直接匹配 + for keyword in keywords: + keyword_lower = keyword.lower().strip() + tag_name_lower = tag_name.lower() + + # 完全匹配 + if keyword_lower == tag_name_lower: + bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励 + logger.debug( + f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})" + ) + + # 包含匹配 + elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower: + bonus += ( + affinity_config.medium_match_interest_threshold * 0.3 + ) # 使用中匹配阈值的30%作为包含匹配奖励 + logger.debug( + f" 🎯 关键词包含匹配: '{keyword}' ⊃ '{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})" + ) + + # 部分匹配(编辑距离) + elif self._calculate_partial_match(keyword_lower, tag_name_lower): + bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励 + logger.debug( + f" 🎯 关键词部分匹配: '{keyword}' ≈ '{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})" + ) + + if bonus > 0: + bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制 + + return bonus_dict + + def _calculate_partial_match(self, text1: str, text2: str) -> bool: + """计算部分匹配(基于编辑距离)""" + try: + # 简单的编辑距离计算 + max_len = max(len(text1), len(text2)) + if max_len == 0: + return False + + # 计算编辑距离 + distance = self._levenshtein_distance(text1, text2) + + # 如果编辑距离小于较短字符串长度的一半,认为是部分匹配 + min_len = min(len(text1), len(text2)) + return distance <= min_len // 2 + + except Exception: + return False + + def _levenshtein_distance(self, s1: str, s2: str) -> int: + """计算莱文斯坦距离""" + if len(s1) < len(s2): + return self._levenshtein_distance(s2, s1) + + if len(s2) == 0: + return len(s1) + + previous_row = range(len(s2) + 1) + for i, c1 in enumerate(s1): + current_row = [i + 1] + for j, c2 in enumerate(s2): + insertions = previous_row[j + 1] + 1 + deletions = current_row[j] + 1 + substitutions = previous_row[j] + (c1 != c2) + current_row.append(min(insertions, deletions, substitutions)) + previous_row = current_row + + return previous_row[-1] + + def _calculate_cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float: + """计算余弦相似度""" + try: + vec1 = np.array(vec1) + vec2 = np.array(vec2) + + dot_product = np.dot(vec1, vec2) + norm1 = np.linalg.norm(vec1) + norm2 = np.linalg.norm(vec2) + + if norm1 == 0 or norm2 == 0: + return 0.0 + + return dot_product / (norm1 * norm2) + + except Exception as e: + logger.error(f"计算余弦相似度失败: {e}") + return 0.0 + + async def _load_interests_from_database(self, personality_id: str) -> Optional[BotPersonalityInterests]: + """从数据库加载兴趣标签""" + try: + logger.debug(f"从数据库加载兴趣标签, personality_id: {personality_id}") + + # 导入SQLAlchemy相关模块 + from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests + from src.common.database.sqlalchemy_database_api import get_db_session + import orjson + + with get_db_session() as session: + # 查询最新的兴趣标签配置 + db_interests = ( + session.query(DBBotPersonalityInterests) + .filter(DBBotPersonalityInterests.personality_id == personality_id) + .order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc()) + .first() + ) + + if db_interests: + logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}") + logger.debug(f"📅 最后更新时间: {db_interests.last_updated}") + logger.debug(f"🧠 使用的embedding模型: {db_interests.embedding_model}") + + # 解析JSON格式的兴趣标签 + try: + tags_data = orjson.loads(db_interests.interest_tags) + logger.debug(f"🏷️ 解析到 {len(tags_data)} 个兴趣标签") + + # 创建BotPersonalityInterests对象 + interests = BotPersonalityInterests( + personality_id=db_interests.personality_id, + personality_description=db_interests.personality_description, + embedding_model=db_interests.embedding_model, + version=db_interests.version, + last_updated=db_interests.last_updated, + ) + + # 解析兴趣标签 + for tag_data in tags_data: + tag = BotInterestTag( + tag_name=tag_data.get("tag_name", ""), + weight=tag_data.get("weight", 0.5), + created_at=datetime.fromisoformat( + tag_data.get("created_at", datetime.now().isoformat()) + ), + updated_at=datetime.fromisoformat( + tag_data.get("updated_at", datetime.now().isoformat()) + ), + is_active=tag_data.get("is_active", True), + embedding=tag_data.get("embedding"), + ) + interests.interest_tags.append(tag) + + logger.debug(f"成功解析 {len(interests.interest_tags)} 个兴趣标签") + return interests + + except (orjson.JSONDecodeError, Exception) as e: + logger.error(f"❌ 解析兴趣标签JSON失败: {e}") + logger.debug(f"🔍 原始JSON数据: {db_interests.interest_tags[:200]}...") + return None + else: + logger.info(f"ℹ️ 数据库中未找到personality_id为 '{personality_id}' 的兴趣标签配置") + return None + + except Exception as e: + logger.error(f"❌ 从数据库加载兴趣标签失败: {e}") + logger.error("🔍 错误详情:") + traceback.print_exc() + return None + + async def _save_interests_to_database(self, interests: BotPersonalityInterests): + """保存兴趣标签到数据库""" + try: + logger.info("💾 正在保存兴趣标签到数据库...") + logger.info(f"📋 personality_id: {interests.personality_id}") + logger.info(f"🏷️ 兴趣标签数量: {len(interests.interest_tags)}") + logger.info(f"🔄 版本: {interests.version}") + + # 导入SQLAlchemy相关模块 + from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests + from src.common.database.sqlalchemy_database_api import get_db_session + import orjson + + # 将兴趣标签转换为JSON格式 + tags_data = [] + for tag in interests.interest_tags: + tag_dict = { + "tag_name": tag.tag_name, + "weight": tag.weight, + "created_at": tag.created_at.isoformat(), + "updated_at": tag.updated_at.isoformat(), + "is_active": tag.is_active, + "embedding": tag.embedding, + } + tags_data.append(tag_dict) + + # 序列化为JSON + json_data = orjson.dumps(tags_data) + + with get_db_session() as session: + # 检查是否已存在相同personality_id的记录 + existing_record = ( + session.query(DBBotPersonalityInterests) + .filter(DBBotPersonalityInterests.personality_id == interests.personality_id) + .first() + ) + + if existing_record: + # 更新现有记录 + logger.info("🔄 更新现有的兴趣标签配置") + existing_record.interest_tags = json_data + existing_record.personality_description = interests.personality_description + existing_record.embedding_model = interests.embedding_model + existing_record.version = interests.version + existing_record.last_updated = interests.last_updated + + logger.info(f"✅ 成功更新兴趣标签配置,版本: {interests.version}") + + else: + # 创建新记录 + logger.info("🆕 创建新的兴趣标签配置") + new_record = DBBotPersonalityInterests( + personality_id=interests.personality_id, + personality_description=interests.personality_description, + interest_tags=json_data, + embedding_model=interests.embedding_model, + version=interests.version, + last_updated=interests.last_updated, + ) + session.add(new_record) + session.commit() + logger.info(f"✅ 成功创建兴趣标签配置,版本: {interests.version}") + + logger.info("✅ 兴趣标签已成功保存到数据库") + + # 验证保存是否成功 + with get_db_session() as session: + saved_record = ( + session.query(DBBotPersonalityInterests) + .filter(DBBotPersonalityInterests.personality_id == interests.personality_id) + .first() + ) + session.commit() + if saved_record: + logger.info(f"✅ 验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录") + logger.info(f" 版本: {saved_record.version}") + logger.info(f" 最后更新: {saved_record.last_updated}") + else: + logger.error(f"❌ 验证失败:数据库中未找到personality_id为 {interests.personality_id} 的记录") + + except Exception as e: + logger.error(f"❌ 保存兴趣标签到数据库失败: {e}") + logger.error("🔍 错误详情:") + traceback.print_exc() + + def get_current_interests(self) -> Optional[BotPersonalityInterests]: + """获取当前的兴趣标签配置""" + return self.current_interests + + def get_interest_stats(self) -> Dict[str, Any]: + """获取兴趣系统统计信息""" + if not self.current_interests: + return {"initialized": False} + + active_tags = self.current_interests.get_active_tags() + + return { + "initialized": self._initialized, + "total_tags": len(active_tags), + "embedding_model": self.current_interests.embedding_model, + "last_updated": self.current_interests.last_updated.isoformat(), + "cache_size": len(self.embedding_cache), + } + + async def update_interest_tags(self, new_personality_description: str = None): + """更新兴趣标签""" + try: + if not self.current_interests: + logger.warning("没有当前的兴趣标签配置,无法更新") + return + + if new_personality_description: + self.current_interests.personality_description = new_personality_description + + # 重新生成兴趣标签 + new_interests = await self._generate_interests_from_personality( + self.current_interests.personality_description, self.current_interests.personality_id + ) + + if new_interests: + new_interests.version = self.current_interests.version + 1 + self.current_interests = new_interests + await self._save_interests_to_database(new_interests) + logger.info(f"兴趣标签已更新,版本: {new_interests.version}") + + except Exception as e: + logger.error(f"更新兴趣标签失败: {e}") + traceback.print_exc() + + +# 创建全局实例(重新创建以包含新的属性) +bot_interest_manager = BotInterestManager() diff --git a/src/chat/message_manager/__init__.py b/src/chat/message_manager/__init__.py new file mode 100644 index 000000000..2f623fbd0 --- /dev/null +++ b/src/chat/message_manager/__init__.py @@ -0,0 +1,26 @@ +""" +消息管理器模块 +提供统一的消息管理、上下文管理和分发调度功能 +""" + +from .message_manager import MessageManager, message_manager +from .context_manager import StreamContextManager, context_manager +from .distribution_manager import ( + DistributionManager, + DistributionPriority, + DistributionTask, + StreamDistributionState, + distribution_manager +) + +__all__ = [ + "MessageManager", + "message_manager", + "StreamContextManager", + "context_manager", + "DistributionManager", + "DistributionPriority", + "DistributionTask", + "StreamDistributionState", + "distribution_manager" +] \ No newline at end of file diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py new file mode 100644 index 000000000..982b8a8a5 --- /dev/null +++ b/src/chat/message_manager/context_manager.py @@ -0,0 +1,653 @@ +""" +重构后的聊天上下文管理器 +提供统一、稳定的聊天上下文管理功能 +""" + +import asyncio +import time +from typing import Dict, List, Optional, Any, Union, Tuple +from abc import ABC, abstractmethod + +from src.common.data_models.message_manager_data_model import StreamContext +from src.common.logger import get_logger +from src.config.config import global_config +from src.common.data_models.database_data_model import DatabaseMessages +from src.chat.energy_system import energy_manager +from .distribution_manager import distribution_manager + +logger = get_logger("context_manager") + +class StreamContextManager: + """流上下文管理器 - 统一管理所有聊天流上下文""" + + def __init__(self, max_context_size: Optional[int] = None, context_ttl: Optional[int] = None): + # 上下文存储 + self.stream_contexts: Dict[str, Any] = {} + self.context_metadata: Dict[str, Dict[str, Any]] = {} + + # 统计信息 + self.stats: Dict[str, Union[int, float, str, Dict]] = { + "total_messages": 0, + "total_streams": 0, + "active_streams": 0, + "inactive_streams": 0, + "last_activity": time.time(), + "creation_time": time.time(), + } + + # 配置参数 + self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100) + self.context_ttl = context_ttl or getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时 + self.cleanup_interval = getattr(global_config.chat, "context_cleanup_interval", 3600) # 1小时 + self.auto_cleanup = getattr(global_config.chat, "auto_cleanup_contexts", True) + self.enable_validation = getattr(global_config.chat, "enable_context_validation", True) + + # 清理任务 + self.cleanup_task: Optional[Any] = None + self.is_running = False + + logger.info(f"上下文管理器初始化完成 (最大上下文: {self.max_context_size}, TTL: {self.context_ttl}s)") + + def add_stream_context(self, stream_id: str, context: Any, metadata: Optional[Dict[str, Any]] = None) -> bool: + """添加流上下文 + + Args: + stream_id: 流ID + context: 上下文对象 + metadata: 上下文元数据 + + Returns: + bool: 是否成功添加 + """ + if stream_id in self.stream_contexts: + logger.warning(f"流上下文已存在: {stream_id}") + return False + + # 添加上下文 + self.stream_contexts[stream_id] = context + + # 初始化元数据 + self.context_metadata[stream_id] = { + "created_time": time.time(), + "last_access_time": time.time(), + "access_count": 0, + "last_validation_time": 0.0, + "custom_metadata": metadata or {}, + } + + # 更新统计 + self.stats["total_streams"] += 1 + self.stats["active_streams"] += 1 + self.stats["last_activity"] = time.time() + + logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})") + return True + + def remove_stream_context(self, stream_id: str) -> bool: + """移除流上下文 + + Args: + stream_id: 流ID + + Returns: + bool: 是否成功移除 + """ + if stream_id in self.stream_contexts: + context = self.stream_contexts[stream_id] + metadata = self.context_metadata.get(stream_id, {}) + + del self.stream_contexts[stream_id] + if stream_id in self.context_metadata: + del self.context_metadata[stream_id] + + self.stats["active_streams"] = max(0, self.stats["active_streams"] - 1) + self.stats["inactive_streams"] += 1 + self.stats["last_activity"] = time.time() + + logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})") + return True + return False + + def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[StreamContext]: + """获取流上下文 + + Args: + stream_id: 流ID + update_access: 是否更新访问统计 + + Returns: + Optional[Any]: 上下文对象 + """ + context = self.stream_contexts.get(stream_id) + if context and update_access: + # 更新访问统计 + if stream_id in self.context_metadata: + metadata = self.context_metadata[stream_id] + metadata["last_access_time"] = time.time() + metadata["access_count"] = metadata.get("access_count", 0) + 1 + return context + + def get_context_metadata(self, stream_id: str) -> Optional[Dict[str, Any]]: + """获取上下文元数据 + + Args: + stream_id: 流ID + + Returns: + Optional[Dict[str, Any]]: 元数据 + """ + return self.context_metadata.get(stream_id) + + def update_context_metadata(self, stream_id: str, updates: Dict[str, Any]) -> bool: + """更新上下文元数据 + + Args: + stream_id: 流ID + updates: 更新的元数据 + + Returns: + bool: 是否成功更新 + """ + if stream_id not in self.context_metadata: + return False + + self.context_metadata[stream_id].update(updates) + return True + + def add_message_to_context(self, stream_id: str, message: DatabaseMessages, skip_energy_update: bool = False) -> bool: + """添加消息到上下文 + + Args: + stream_id: 流ID + message: 消息对象 + skip_energy_update: 是否跳过能量更新 + + Returns: + bool: 是否成功添加 + """ + context = self.get_stream_context(stream_id) + if not context: + logger.warning(f"流上下文不存在: {stream_id}") + return False + + try: + # 添加消息到上下文 + context.add_message(message) + + # 计算消息兴趣度 + interest_value = self._calculate_message_interest(message) + message.interest_value = interest_value + + # 更新统计 + self.stats["total_messages"] += 1 + self.stats["last_activity"] = time.time() + + # 更新能量和分发 + if not skip_energy_update: + self._update_stream_energy(stream_id) + distribution_manager.add_stream_message(stream_id, 1) + + logger.debug(f"添加消息到上下文: {stream_id} (兴趣度: {interest_value:.3f})") + return True + + except Exception as e: + logger.error(f"添加消息到上下文失败 {stream_id}: {e}", exc_info=True) + return False + + def update_message_in_context(self, stream_id: str, message_id: str, updates: Dict[str, Any]) -> bool: + """更新上下文中的消息 + + Args: + stream_id: 流ID + message_id: 消息ID + updates: 更新的属性 + + Returns: + bool: 是否成功更新 + """ + context = self.get_stream_context(stream_id) + if not context: + logger.warning(f"流上下文不存在: {stream_id}") + return False + + try: + # 更新消息信息 + context.update_message_info(message_id, **updates) + + # 如果更新了兴趣度,重新计算能量 + if "interest_value" in updates: + self._update_stream_energy(stream_id) + + logger.debug(f"更新上下文消息: {stream_id}/{message_id}") + return True + + except Exception as e: + logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True) + return False + + def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]: + """获取上下文消息 + + Args: + stream_id: 流ID + limit: 消息数量限制 + include_unread: 是否包含未读消息 + + Returns: + List[Any]: 消息列表 + """ + context = self.get_stream_context(stream_id) + if not context: + return [] + + try: + messages = [] + if include_unread: + messages.extend(context.get_unread_messages()) + + if limit: + messages.extend(context.get_history_messages(limit=limit)) + else: + messages.extend(context.get_history_messages()) + + # 按时间排序 + messages.sort(key=lambda msg: getattr(msg, 'time', 0)) + + # 应用限制 + if limit and len(messages) > limit: + messages = messages[-limit:] + + return messages + + except Exception as e: + logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True) + return [] + + def get_unread_messages(self, stream_id: str) -> List[DatabaseMessages]: + """获取未读消息 + + Args: + stream_id: 流ID + + Returns: + List[Any]: 未读消息列表 + """ + context = self.get_stream_context(stream_id) + if not context: + return [] + + try: + return context.get_unread_messages() + except Exception as e: + logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True) + return [] + + def mark_messages_as_read(self, stream_id: str, message_ids: List[str]) -> bool: + """标记消息为已读 + + Args: + stream_id: 流ID + message_ids: 消息ID列表 + + Returns: + bool: 是否成功标记 + """ + context = self.get_stream_context(stream_id) + if not context: + logger.warning(f"流上下文不存在: {stream_id}") + return False + + try: + if not hasattr(context, 'mark_message_as_read'): + logger.error(f"上下文对象缺少 mark_message_as_read 方法: {stream_id}") + return False + + marked_count = 0 + for message_id in message_ids: + try: + context.mark_message_as_read(message_id) + marked_count += 1 + except Exception as e: + logger.warning(f"标记消息已读失败 {message_id}: {e}") + + logger.debug(f"标记消息为已读: {stream_id} ({marked_count}/{len(message_ids)}条)") + return marked_count > 0 + + except Exception as e: + logger.error(f"标记消息已读失败 {stream_id}: {e}", exc_info=True) + return False + + def clear_context(self, stream_id: str) -> bool: + """清空上下文 + + Args: + stream_id: 流ID + + Returns: + bool: 是否成功清空 + """ + context = self.get_stream_context(stream_id) + if not context: + logger.warning(f"流上下文不存在: {stream_id}") + return False + + try: + # 清空消息 + if hasattr(context, 'unread_messages'): + context.unread_messages.clear() + if hasattr(context, 'history_messages'): + context.history_messages.clear() + + # 重置状态 + reset_attrs = ['interruption_count', 'afc_threshold_adjustment', 'last_check_time'] + for attr in reset_attrs: + if hasattr(context, attr): + if attr in ['interruption_count', 'afc_threshold_adjustment']: + setattr(context, attr, 0) + else: + setattr(context, attr, time.time()) + + # 重新计算能量 + self._update_stream_energy(stream_id) + + logger.info(f"清空上下文: {stream_id}") + return True + + except Exception as e: + logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True) + return False + + def _calculate_message_interest(self, message: DatabaseMessages) -> float: + """计算消息兴趣度""" + try: + # 使用插件内部的兴趣度评分系统 + try: + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system + + # 使用插件内部的兴趣度评分系统计算(同步方式) + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + interest_score = loop.run_until_complete( + chatter_interest_scoring_system._calculate_single_message_score( + message=message, + bot_nickname=global_config.bot.nickname + ) + ) + interest_value = interest_score.total_score + + logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}") + + except Exception as e: + logger.warning(f"插件内部兴趣度计算失败,使用默认值: {e}") + interest_value = 0.5 # 默认中等兴趣度 + + return interest_value + + except Exception as e: + logger.error(f"计算消息兴趣度失败: {e}") + return 0.5 + + def _update_stream_energy(self, stream_id: str): + """更新流能量""" + try: + # 获取所有消息 + all_messages = self.get_context_messages(stream_id, self.max_context_size) + unread_messages = self.get_unread_messages(stream_id) + combined_messages = all_messages + unread_messages + + # 获取用户ID + user_id = None + if combined_messages: + last_message = combined_messages[-1] + user_id = last_message.user_info.user_id + + # 计算能量 + energy = energy_manager.calculate_focus_energy( + stream_id=stream_id, + messages=combined_messages, + user_id=user_id + ) + + # 更新分发管理器 + distribution_manager.update_stream_energy(stream_id, energy) + + except Exception as e: + logger.error(f"更新流能量失败 {stream_id}: {e}") + + def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]: + """获取流统计信息 + + Args: + stream_id: 流ID + + Returns: + Optional[Dict[str, Any]]: 统计信息 + """ + context = self.get_stream_context(stream_id, update_access=False) + if not context: + return None + + try: + metadata = self.context_metadata.get(stream_id, {}) + current_time = time.time() + created_time = metadata.get("created_time", current_time) + last_access_time = metadata.get("last_access_time", current_time) + access_count = metadata.get("access_count", 0) + + unread_messages = getattr(context, "unread_messages", []) + history_messages = getattr(context, "history_messages", []) + + return { + "stream_id": stream_id, + "context_type": type(context).__name__, + "total_messages": len(history_messages) + len(unread_messages), + "unread_messages": len(unread_messages), + "history_messages": len(history_messages), + "is_active": getattr(context, "is_active", True), + "last_check_time": getattr(context, "last_check_time", current_time), + "interruption_count": getattr(context, "interruption_count", 0), + "afc_threshold_adjustment": getattr(context, "afc_threshold_adjustment", 0.0), + "created_time": created_time, + "last_access_time": last_access_time, + "access_count": access_count, + "uptime_seconds": current_time - created_time, + "idle_seconds": current_time - last_access_time, + } + except Exception as e: + logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True) + return None + + def get_manager_statistics(self) -> Dict[str, Any]: + """获取管理器统计信息 + + Returns: + Dict[str, Any]: 管理器统计信息 + """ + current_time = time.time() + uptime = current_time - self.stats.get("creation_time", current_time) + + return { + **self.stats, + "uptime_hours": uptime / 3600, + "stream_count": len(self.stream_contexts), + "metadata_count": len(self.context_metadata), + "auto_cleanup_enabled": self.auto_cleanup, + "cleanup_interval": self.cleanup_interval, + } + + def cleanup_inactive_contexts(self, max_inactive_hours: int = 24) -> int: + """清理不活跃的上下文 + + Args: + max_inactive_hours: 最大不活跃小时数 + + Returns: + int: 清理的上下文数量 + """ + current_time = time.time() + max_inactive_seconds = max_inactive_hours * 3600 + + inactive_streams = [] + for stream_id, context in self.stream_contexts.items(): + try: + # 获取最后活动时间 + metadata = self.context_metadata.get(stream_id, {}) + last_activity = metadata.get("last_access_time", metadata.get("created_time", 0)) + context_last_activity = getattr(context, "last_check_time", 0) + actual_last_activity = max(last_activity, context_last_activity) + + # 检查是否不活跃 + unread_count = len(getattr(context, "unread_messages", [])) + history_count = len(getattr(context, "history_messages", [])) + total_messages = unread_count + history_count + + if (current_time - actual_last_activity > max_inactive_seconds and + total_messages == 0): + inactive_streams.append(stream_id) + except Exception as e: + logger.warning(f"检查上下文活跃状态失败 {stream_id}: {e}") + continue + + # 清理不活跃上下文 + cleaned_count = 0 + for stream_id in inactive_streams: + if self.remove_stream_context(stream_id): + cleaned_count += 1 + + if cleaned_count > 0: + logger.info(f"清理了 {cleaned_count} 个不活跃上下文") + + return cleaned_count + + def validate_context_integrity(self, stream_id: str) -> bool: + """验证上下文完整性 + + Args: + stream_id: 流ID + + Returns: + bool: 是否完整 + """ + context = self.get_stream_context(stream_id) + if not context: + return False + + try: + # 检查基本属性 + required_attrs = ["stream_id", "unread_messages", "history_messages"] + for attr in required_attrs: + if not hasattr(context, attr): + logger.warning(f"上下文缺少必要属性: {attr}") + return False + + # 检查消息ID唯一性 + all_messages = getattr(context, "unread_messages", []) + getattr(context, "history_messages", []) + message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")] + if len(message_ids) != len(set(message_ids)): + logger.warning(f"上下文中存在重复消息ID: {stream_id}") + return False + + return True + + except Exception as e: + logger.error(f"验证上下文完整性失败 {stream_id}: {e}") + return False + + async def start(self) -> None: + """启动上下文管理器""" + if self.is_running: + logger.warning("上下文管理器已经在运行") + return + + await self.start_auto_cleanup() + logger.info("上下文管理器已启动") + + async def stop(self) -> None: + """停止上下文管理器""" + if not self.is_running: + return + + await self.stop_auto_cleanup() + logger.info("上下文管理器已停止") + + async def start_auto_cleanup(self, interval: Optional[float] = None) -> None: + """启动自动清理 + + Args: + interval: 清理间隔(秒) + """ + if not self.auto_cleanup: + logger.info("自动清理已禁用") + return + + if self.is_running: + logger.warning("自动清理已在运行") + return + + self.is_running = True + cleanup_interval = interval or self.cleanup_interval + logger.info(f"启动自动清理(间隔: {cleanup_interval}s)") + + import asyncio + self.cleanup_task = asyncio.create_task(self._cleanup_loop(cleanup_interval)) + + async def stop_auto_cleanup(self) -> None: + """停止自动清理""" + self.is_running = False + if self.cleanup_task and not self.cleanup_task.done(): + self.cleanup_task.cancel() + try: + await self.cleanup_task + except Exception: + pass + logger.info("自动清理已停止") + + async def _cleanup_loop(self, interval: float) -> None: + """清理循环 + + Args: + interval: 清理间隔 + """ + while self.is_running: + try: + await asyncio.sleep(interval) + self.cleanup_inactive_contexts() + self._cleanup_expired_contexts() + logger.debug("自动清理完成") + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"清理循环出错: {e}", exc_info=True) + await asyncio.sleep(interval) + + def _cleanup_expired_contexts(self) -> None: + """清理过期上下文""" + current_time = time.time() + expired_contexts = [] + + for stream_id, metadata in self.context_metadata.items(): + created_time = metadata.get("created_time", current_time) + if current_time - created_time > self.context_ttl: + expired_contexts.append(stream_id) + + for stream_id in expired_contexts: + self.remove_stream_context(stream_id) + + if expired_contexts: + logger.info(f"清理了 {len(expired_contexts)} 个过期上下文") + + def get_active_streams(self) -> List[str]: + """获取活跃流列表 + + Returns: + List[str]: 活跃流ID列表 + """ + return list(self.stream_contexts.keys()) + + +# 全局上下文管理器实例 +context_manager = StreamContextManager() \ No newline at end of file diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py new file mode 100644 index 000000000..ab3579589 --- /dev/null +++ b/src/chat/message_manager/distribution_manager.py @@ -0,0 +1,1004 @@ +""" +重构后的动态消息分发管理器 +提供高效、智能的消息分发调度功能 +""" + +import asyncio +import time +from typing import Dict, List, Optional, Set, Any, Callable +from dataclasses import dataclass, field +from enum import Enum +from heapq import heappush, heappop +from abc import ABC, abstractmethod + +from src.common.logger import get_logger +from src.config.config import global_config +from src.chat.energy_system import energy_manager + +logger = get_logger("distribution_manager") + + +class DistributionPriority(Enum): + """分发优先级""" + CRITICAL = 0 # 关键(立即处理) + HIGH = 1 # 高优先级 + NORMAL = 2 # 正常优先级 + LOW = 3 # 低优先级 + BACKGROUND = 4 # 后台优先级 + + def __lt__(self, other: 'DistributionPriority') -> bool: + """用于优先级比较""" + return self.value < other.value + + +@dataclass +class DistributionTask: + """分发任务""" + stream_id: str + priority: DistributionPriority + energy: float + message_count: int + created_time: float = field(default_factory=time.time) + retry_count: int = 0 + max_retries: int = 3 + task_id: str = field(default_factory=lambda: f"task_{time.time()}_{id(object())}") + metadata: Dict[str, Any] = field(default_factory=dict) + + def __lt__(self, other: 'DistributionTask') -> bool: + """用于优先队列排序""" + # 首先按优先级排序 + if self.priority.value != other.priority.value: + return self.priority.value < other.priority.value + + # 相同优先级按能量排序(能量高的优先) + if abs(self.energy - other.energy) > 0.01: + return self.energy > other.energy + + # 最后按创建时间排序(先创建的优先) + return self.created_time < other.created_time + + def can_retry(self) -> bool: + """检查是否可以重试""" + return self.retry_count < self.max_retries + + def get_retry_delay(self, base_delay: float = 5.0) -> float: + """获取重试延迟""" + return base_delay * (2 ** min(self.retry_count, 3)) + + +@dataclass +class StreamDistributionState: + """流分发状态""" + stream_id: str + energy: float + last_distribution_time: float + next_distribution_time: float + message_count: int + consecutive_failures: int = 0 + is_active: bool = True + total_distributions: int = 0 + total_failures: int = 0 + average_distribution_time: float = 0.0 + metadata: Dict[str, Any] = field(default_factory=dict) + + def should_distribute(self, current_time: float) -> bool: + """检查是否应该分发""" + return (self.is_active and + current_time >= self.next_distribution_time and + self.message_count > 0) + + def update_distribution_stats(self, distribution_time: float, success: bool) -> None: + """更新分发统计""" + if success: + self.total_distributions += 1 + self.consecutive_failures = 0 + else: + self.total_failures += 1 + self.consecutive_failures += 1 + + # 更新平均分发时间 + total_attempts = self.total_distributions + self.total_failures + if total_attempts > 0: + self.average_distribution_time = ( + (self.average_distribution_time * (total_attempts - 1) + distribution_time) + / total_attempts + ) + + +class DistributionExecutor(ABC): + """分发执行器抽象基类""" + + @abstractmethod + async def execute(self, stream_id: str, context: Dict[str, Any]) -> bool: + """执行分发 + + Args: + stream_id: 流ID + context: 分发上下文 + + Returns: + bool: 是否执行成功 + """ + pass + + @abstractmethod + def get_priority(self, stream_id: str) -> DistributionPriority: + """获取流优先级 + + Args: + stream_id: 流ID + + Returns: + DistributionPriority: 优先级 + """ + pass + + +class DistributionManager: + """分发管理器 - 统一管理消息分发调度""" + + def __init__(self, max_concurrent_tasks: Optional[int] = None, retry_delay: Optional[float] = None): + # 流状态管理 + self.stream_states: Dict[str, StreamDistributionState] = {} + + # 任务队列 + self.task_queue: List[DistributionTask] = [] + self.processing_tasks: Set[str] = set() # 正在处理的stream_id + self.completed_tasks: List[DistributionTask] = [] + self.failed_tasks: List[DistributionTask] = [] + + # 统计信息 + self.stats: Dict[str, Any] = { + "total_distributed": 0, + "total_failed": 0, + "avg_distribution_time": 0.0, + "current_queue_size": 0, + "total_created_tasks": 0, + "total_completed_tasks": 0, + "total_failed_tasks": 0, + "total_retry_attempts": 0, + "peak_queue_size": 0, + "start_time": time.time(), + "last_activity_time": time.time(), + } + + # 配置参数 + self.max_concurrent_tasks = ( + max_concurrent_tasks or + getattr(global_config.chat, "max_concurrent_distributions", 3) + ) + self.retry_delay = ( + retry_delay or + getattr(global_config.chat, "distribution_retry_delay", 5.0) + ) + self.max_queue_size = getattr(global_config.chat, "max_distribution_queue_size", 1000) + self.max_history_size = getattr(global_config.chat, "max_task_history_size", 100) + + # 分发执行器 + self.executor: Optional[DistributionExecutor] = None + self.executor_callbacks: Dict[str, Callable] = {} + + # 事件循环 + self.is_running = False + self.distribution_task: Optional[asyncio.Task] = None + self.cleanup_task: Optional[asyncio.Task] = None + + # 性能监控 + self.performance_metrics: Dict[str, List[float]] = { + "distribution_times": [], + "queue_sizes": [], + "processing_counts": [], + } + self.max_metrics_size = 1000 + + logger.info(f"分发管理器初始化完成 (并发: {self.max_concurrent_tasks}, 重试延迟: {self.retry_delay}s)") + + async def start(self, cleanup_interval: float = 3600.0) -> None: + """启动分发管理器 + + Args: + cleanup_interval: 清理间隔(秒) + """ + if self.is_running: + logger.warning("分发管理器已经在运行") + return + + self.is_running = True + self.distribution_task = asyncio.create_task(self._distribution_loop()) + self.cleanup_task = asyncio.create_task(self._cleanup_loop(cleanup_interval)) + + logger.info("分发管理器已启动") + + async def stop(self) -> None: + """停止分发管理器""" + if not self.is_running: + return + + self.is_running = False + + # 取消分发任务 + if self.distribution_task and not self.distribution_task.done(): + self.distribution_task.cancel() + try: + await self.distribution_task + except asyncio.CancelledError: + pass + + # 取消清理任务 + if self.cleanup_task and not self.cleanup_task.done(): + self.cleanup_task.cancel() + try: + await self.cleanup_task + except asyncio.CancelledError: + pass + + # 取消所有处理中的任务 + for stream_id in list(self.processing_tasks): + self._cancel_stream_processing(stream_id) + + logger.info("分发管理器已停止") + + def add_stream_message(self, stream_id: str, message_count: int = 1, + priority: Optional[DistributionPriority] = None) -> bool: + """添加流消息 + + Args: + stream_id: 流ID + message_count: 消息数量 + priority: 指定优先级(可选) + + Returns: + bool: 是否成功添加 + """ + current_time = time.time() + self.stats["last_activity_time"] = current_time + + # 检查队列大小限制 + if len(self.task_queue) >= self.max_queue_size: + logger.warning(f"分发队列已满,拒绝添加: {stream_id}") + return False + + # 获取或创建流状态 + if stream_id not in self.stream_states: + self.stream_states[stream_id] = StreamDistributionState( + stream_id=stream_id, + energy=0.5, # 默认能量 + last_distribution_time=current_time, + next_distribution_time=current_time, + message_count=0, + ) + + # 更新流状态 + state = self.stream_states[stream_id] + state.message_count += message_count + + # 计算优先级 + if priority is None: + priority = self._calculate_priority(state) + + # 创建分发任务 + task = DistributionTask( + stream_id=stream_id, + priority=priority, + energy=state.energy, + message_count=state.message_count, + ) + + # 添加到任务队列 + heappush(self.task_queue, task) + self.stats["current_queue_size"] = len(self.task_queue) + self.stats["peak_queue_size"] = max(self.stats["peak_queue_size"], len(self.task_queue)) + self.stats["total_created_tasks"] += 1 + + # 记录性能指标 + self._record_performance_metric("queue_sizes", len(self.task_queue)) + + logger.debug(f"添加分发任务: {stream_id} (优先级: {priority.name}, 消息数: {message_count})") + return True + + def update_stream_energy(self, stream_id: str, energy: float) -> None: + """更新流能量 + + Args: + stream_id: 流ID + energy: 新的能量值 + """ + if stream_id in self.stream_states: + self.stream_states[stream_id].energy = max(0.0, min(1.0, energy)) + + # 失效能量管理器缓存 + energy_manager.invalidate_cache(stream_id) + + logger.debug(f"更新流能量: {stream_id} = {energy:.3f}") + + def _calculate_priority(self, state: StreamDistributionState) -> DistributionPriority: + """计算分发优先级 + + Args: + state: 流状态 + + Returns: + DistributionPriority: 优先级 + """ + energy = state.energy + message_count = state.message_count + consecutive_failures = state.consecutive_failures + total_distributions = state.total_distributions + + # 使用执行器获取优先级(如果设置) + if self.executor: + try: + return self.executor.get_priority(state.stream_id) + except Exception as e: + logger.warning(f"获取执行器优先级失败: {e}") + + # 失败次数过多,降低优先级 + if consecutive_failures >= 3: + return DistributionPriority.BACKGROUND + + # 高分发次数降低优先级 + if total_distributions > 50 and message_count < 2: + return DistributionPriority.LOW + + # 基于能量和消息数计算优先级 + if energy >= 0.8 and message_count >= 3: + return DistributionPriority.CRITICAL + elif energy >= 0.6 or message_count >= 5: + return DistributionPriority.HIGH + elif energy >= 0.3 or message_count >= 2: + return DistributionPriority.NORMAL + else: + return DistributionPriority.LOW + + async def _distribution_loop(self): + """分发主循环""" + while self.is_running: + try: + # 处理任务队列 + await self._process_task_queue() + + # 更新统计信息 + self._update_statistics() + + # 记录性能指标 + self._record_performance_metric("processing_counts", len(self.processing_tasks)) + + # 动态调整循环间隔 + queue_size = len(self.task_queue) + processing_count = len(self.processing_tasks) + sleep_time = 0.05 if queue_size > 10 or processing_count > 0 else 0.2 + + # 短暂休眠 + await asyncio.sleep(sleep_time) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"分发循环出错: {e}", exc_info=True) + await asyncio.sleep(1.0) + + async def _process_task_queue(self): + """处理任务队列""" + current_time = time.time() + + # 检查是否有可用的处理槽位 + available_slots = self.max_concurrent_tasks - len(self.processing_tasks) + if available_slots <= 0: + return + + # 处理队列中的任务 + processed_count = 0 + while (self.task_queue and + processed_count < available_slots and + len(self.processing_tasks) < self.max_concurrent_tasks): + + task = heappop(self.task_queue) + self.stats["current_queue_size"] = len(self.task_queue) + + # 检查任务是否仍然有效 + if not self._is_task_valid(task, current_time): + self._handle_invalid_task(task) + continue + + # 开始处理任务 + await self._start_task_processing(task) + processed_count += 1 + + # 记录处理统计 + if processed_count > 0: + logger.debug(f"处理了 {processed_count} 个分发任务") + + def _is_task_valid(self, task: DistributionTask, current_time: float) -> bool: + """检查任务是否有效 + + Args: + task: 分发任务 + current_time: 当前时间 + + Returns: + bool: 任务是否有效 + """ + state = self.stream_states.get(task.stream_id) + if not state or not state.is_active: + return False + + # 检查任务是否已过期 + if current_time - task.created_time > 3600: # 1小时 + return False + + # 检查是否达到了分发时间 + return state.should_distribute(current_time) + + def _handle_invalid_task(self, task: DistributionTask) -> None: + """处理无效任务 + + Args: + task: 无效的任务 + """ + logger.debug(f"任务无效,丢弃: {task.stream_id} (创建时间: {task.created_time})") + # 可以添加到历史记录中用于分析 + if len(self.failed_tasks) < self.max_history_size: + self.failed_tasks.append(task) + + async def _start_task_processing(self, task: DistributionTask) -> None: + """开始处理任务 + + Args: + task: 分发任务 + """ + stream_id = task.stream_id + state = self.stream_states[stream_id] + current_time = time.time() + + # 标记为处理中 + self.processing_tasks.add(stream_id) + state.last_distribution_time = current_time + + # 计算下次分发时间 + interval = energy_manager.get_distribution_interval(state.energy) + state.next_distribution_time = current_time + interval + + # 记录开始处理 + logger.info(f"开始处理分发任务: {stream_id} " + f"(能量: {state.energy:.3f}, " + f"消息数: {state.message_count}, " + f"周期: {interval:.1f}s, " + f"重试次数: {task.retry_count})") + + # 创建处理任务 + asyncio.create_task(self._process_distribution_task(task)) + + async def _process_distribution_task(self, task: DistributionTask) -> None: + """处理分发任务 + + Args: + task: 分发任务 + """ + stream_id = task.stream_id + start_time = time.time() + + try: + # 调用外部处理函数 + success = await self._execute_distribution(stream_id) + + if success: + # 处理成功 + self._handle_task_success(task, start_time) + else: + # 处理失败 + await self._handle_task_failure(task) + + except Exception as e: + logger.error(f"处理分发任务失败 {stream_id}: {e}", exc_info=True) + await self._handle_task_failure(task) + + finally: + # 清理处理状态 + self.processing_tasks.discard(stream_id) + self.stats["last_activity_time"] = time.time() + + async def _execute_distribution(self, stream_id: str) -> bool: + """执行分发(需要外部实现) + + Args: + stream_id: 流ID + + Returns: + bool: 是否执行成功 + """ + # 使用执行器处理分发 + if self.executor: + try: + state = self.stream_states.get(stream_id) + context = { + "stream_id": stream_id, + "energy": state.energy if state else 0.5, + "message_count": state.message_count if state else 0, + "task_metadata": {}, + } + return await self.executor.execute(stream_id, context) + except Exception as e: + logger.error(f"执行器分发失败 {stream_id}: {e}") + return False + + # 回退到回调函数 + callback = self.executor_callbacks.get(stream_id) + if callback: + try: + result = callback(stream_id) + if asyncio.iscoroutine(result): + return await result + return bool(result) + except Exception as e: + logger.error(f"回调分发失败 {stream_id}: {e}") + return False + + # 默认处理 + logger.debug(f"执行分发: {stream_id}") + return True + + def _handle_task_success(self, task: DistributionTask, start_time: float) -> None: + """处理任务成功 + + Args: + task: 成功的任务 + start_time: 开始时间 + """ + stream_id = task.stream_id + state = self.stream_states.get(stream_id) + distribution_time = time.time() - start_time + + if state: + # 更新流状态 + state.update_distribution_stats(distribution_time, True) + state.message_count = 0 # 清空消息计数 + + # 更新全局统计 + self.stats["total_distributed"] += 1 + self.stats["total_completed_tasks"] += 1 + + # 更新平均分发时间 + if self.stats["total_distributed"] > 0: + self.stats["avg_distribution_time"] = ( + (self.stats["avg_distribution_time"] * (self.stats["total_distributed"] - 1) + distribution_time) + / self.stats["total_distributed"] + ) + + # 记录性能指标 + self._record_performance_metric("distribution_times", distribution_time) + + # 添加到成功任务历史 + if len(self.completed_tasks) < self.max_history_size: + self.completed_tasks.append(task) + + logger.info(f"分发任务成功: {stream_id} (耗时: {distribution_time:.2f}s, 重试: {task.retry_count})") + + async def _handle_task_failure(self, task: DistributionTask) -> None: + """处理任务失败 + + Args: + task: 失败的任务 + """ + stream_id = task.stream_id + state = self.stream_states.get(stream_id) + distribution_time = time.time() - task.created_time + + if state: + # 更新流状态 + state.update_distribution_stats(distribution_time, False) + + # 增加失败计数 + state.consecutive_failures += 1 + + # 计算重试延迟 + retry_delay = task.get_retry_delay(self.retry_delay) + task.retry_count += 1 + self.stats["total_retry_attempts"] += 1 + + # 如果还有重试机会,重新添加到队列 + if task.can_retry(): + # 等待重试延迟 + await asyncio.sleep(retry_delay) + + # 重新计算优先级(失败后降低优先级) + task.priority = DistributionPriority.LOW + + # 重新添加到队列 + heappush(self.task_queue, task) + self.stats["current_queue_size"] = len(self.task_queue) + + logger.warning(f"分发任务失败,准备重试: {stream_id} " + f"(重试次数: {task.retry_count}/{task.max_retries}, " + f"延迟: {retry_delay:.1f}s)") + else: + # 超过重试次数,标记为不活跃 + state.is_active = False + self.stats["total_failed"] += 1 + self.stats["total_failed_tasks"] += 1 + + # 添加到失败任务历史 + if len(self.failed_tasks) < self.max_history_size: + self.failed_tasks.append(task) + + logger.error(f"分发任务最终失败: {stream_id} (重试次数: {task.retry_count})") + + def _cancel_stream_processing(self, stream_id: str) -> None: + """取消流处理 + + Args: + stream_id: 流ID + """ + # 从处理集合中移除 + self.processing_tasks.discard(stream_id) + + # 更新流状态 + if stream_id in self.stream_states: + self.stream_states[stream_id].is_active = False + + logger.info(f"取消流处理: {stream_id}") + + def _update_statistics(self) -> None: + """更新统计信息""" + # 更新当前队列大小 + self.stats["current_queue_size"] = len(self.task_queue) + + # 更新运行时间 + if self.is_running: + self.stats["uptime"] = time.time() - self.stats["start_time"] + + # 更新性能统计 + self.stats["avg_queue_size"] = ( + sum(self.performance_metrics["queue_sizes"]) / + max(1, len(self.performance_metrics["queue_sizes"])) + ) + + self.stats["avg_processing_count"] = ( + sum(self.performance_metrics["processing_counts"]) / + max(1, len(self.performance_metrics["processing_counts"])) + ) + + def _record_performance_metric(self, metric_name: str, value: float) -> None: + """记录性能指标 + + Args: + metric_name: 指标名称 + value: 指标值 + """ + if metric_name in self.performance_metrics: + metrics = self.performance_metrics[metric_name] + metrics.append(value) + # 保持大小限制 + if len(metrics) > self.max_metrics_size: + metrics.pop(0) + + async def _cleanup_loop(self, interval: float) -> None: + """清理循环 + + Args: + interval: 清理间隔 + """ + while self.is_running: + try: + await asyncio.sleep(interval) + self._cleanup_expired_data() + logger.debug(f"清理完成,保留 {len(self.completed_tasks)} 个成功任务,{len(self.failed_tasks)} 个失败任务") + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"清理循环出错: {e}") + + def _cleanup_expired_data(self) -> None: + """清理过期数据""" + current_time = time.time() + max_age = 24 * 3600 # 24小时 + + # 清理过期的成功任务 + self.completed_tasks = [ + task for task in self.completed_tasks + if current_time - task.created_time < max_age + ] + + # 清理过期的失败任务 + self.failed_tasks = [ + task for task in self.failed_tasks + if current_time - task.created_time < max_age + ] + + # 清理性能指标 + for metric_name in self.performance_metrics: + if len(self.performance_metrics[metric_name]) > self.max_metrics_size: + self.performance_metrics[metric_name] = ( + self.performance_metrics[metric_name][-self.max_metrics_size:] + ) + + def get_stream_status(self, stream_id: str) -> Optional[Dict[str, Any]]: + """获取流状态 + + Args: + stream_id: 流ID + + Returns: + Optional[Dict[str, Any]]: 流状态信息 + """ + if stream_id not in self.stream_states: + return None + + state = self.stream_states[stream_id] + current_time = time.time() + time_until_next = max(0, state.next_distribution_time - current_time) + + return { + "stream_id": state.stream_id, + "energy": state.energy, + "message_count": state.message_count, + "last_distribution_time": state.last_distribution_time, + "next_distribution_time": state.next_distribution_time, + "time_until_next_distribution": time_until_next, + "consecutive_failures": state.consecutive_failures, + "total_distributions": state.total_distributions, + "total_failures": state.total_failures, + "average_distribution_time": state.average_distribution_time, + "is_active": state.is_active, + "is_processing": stream_id in self.processing_tasks, + "uptime": current_time - state.last_distribution_time, + } + + def get_queue_status(self) -> Dict[str, Any]: + """获取队列状态 + + Returns: + Dict[str, Any]: 队列状态信息 + """ + current_time = time.time() + uptime = current_time - self.stats["start_time"] if self.is_running else 0 + + # 分析任务优先级分布 + priority_counts = {} + for task in self.task_queue: + priority_name = task.priority.name + priority_counts[priority_name] = priority_counts.get(priority_name, 0) + 1 + + return { + "queue_size": len(self.task_queue), + "processing_count": len(self.processing_tasks), + "max_concurrent": self.max_concurrent_tasks, + "max_queue_size": self.max_queue_size, + "is_running": self.is_running, + "uptime": uptime, + "priority_distribution": priority_counts, + "stats": self.stats.copy(), + "performance_metrics": { + name: { + "count": len(metrics), + "avg": sum(metrics) / max(1, len(metrics)), + "min": min(metrics) if metrics else 0, + "max": max(metrics) if metrics else 0, + } + for name, metrics in self.performance_metrics.items() + }, + } + + def deactivate_stream(self, stream_id: str) -> bool: + """停用流 + + Args: + stream_id: 流ID + + Returns: + bool: 是否成功停用 + """ + if stream_id in self.stream_states: + self.stream_states[stream_id].is_active = False + # 取消正在处理的任务 + if stream_id in self.processing_tasks: + self._cancel_stream_processing(stream_id) + logger.info(f"停用流: {stream_id}") + return True + return False + + def activate_stream(self, stream_id: str) -> bool: + """激活流 + + Args: + stream_id: 流ID + + Returns: + bool: 是否成功激活 + """ + if stream_id in self.stream_states: + self.stream_states[stream_id].is_active = True + self.stream_states[stream_id].consecutive_failures = 0 + self.stream_states[stream_id].next_distribution_time = time.time() + logger.info(f"激活流: {stream_id}") + return True + return False + + def cleanup_inactive_streams(self, max_inactive_hours: int = 24) -> int: + """清理不活跃的流 + + Args: + max_inactive_hours: 最大不活跃小时数 + + Returns: + int: 清理的流数量 + """ + current_time = time.time() + max_inactive_seconds = max_inactive_hours * 3600 + + inactive_streams = [] + for stream_id, state in self.stream_states.items(): + if (not state.is_active and + current_time - state.last_distribution_time > max_inactive_seconds and + state.message_count == 0): + inactive_streams.append(stream_id) + + for stream_id in inactive_streams: + del self.stream_states[stream_id] + # 同时清理处理中的任务 + self.processing_tasks.discard(stream_id) + logger.debug(f"清理不活跃流: {stream_id}") + + if inactive_streams: + logger.info(f"清理了 {len(inactive_streams)} 个不活跃流") + + return len(inactive_streams) + + def set_executor(self, executor: DistributionExecutor) -> None: + """设置分发执行器 + + Args: + executor: 分发执行器实例 + """ + self.executor = executor + logger.info(f"设置分发执行器: {executor.__class__.__name__}") + + def register_callback(self, stream_id: str, callback: Callable) -> None: + """注册分发回调 + + Args: + stream_id: 流ID + callback: 回调函数 + """ + self.executor_callbacks[stream_id] = callback + logger.debug(f"注册分发回调: {stream_id}") + + def unregister_callback(self, stream_id: str) -> bool: + """注销分发回调 + + Args: + stream_id: 流ID + + Returns: + bool: 是否成功注销 + """ + if stream_id in self.executor_callbacks: + del self.executor_callbacks[stream_id] + logger.debug(f"注销分发回调: {stream_id}") + return True + return False + + def get_task_history(self, limit: int = 50) -> Dict[str, List[Dict[str, Any]]]: + """获取任务历史 + + Args: + limit: 返回数量限制 + + Returns: + Dict[str, List[Dict[str, Any]]]: 任务历史 + """ + def task_to_dict(task: DistributionTask) -> Dict[str, Any]: + return { + "task_id": task.task_id, + "stream_id": task.stream_id, + "priority": task.priority.name, + "energy": task.energy, + "message_count": task.message_count, + "created_time": task.created_time, + "retry_count": task.retry_count, + "max_retries": task.max_retries, + "metadata": task.metadata, + } + + return { + "completed_tasks": [task_to_dict(task) for task in self.completed_tasks[-limit:]], + "failed_tasks": [task_to_dict(task) for task in self.failed_tasks[-limit:]], + } + + def get_performance_summary(self) -> Dict[str, Any]: + """获取性能摘要 + + Returns: + Dict[str, Any]: 性能摘要 + """ + current_time = time.time() + uptime = current_time - self.stats["start_time"] + + # 计算成功率 + total_attempts = self.stats["total_completed_tasks"] + self.stats["total_failed_tasks"] + success_rate = ( + self.stats["total_completed_tasks"] / max(1, total_attempts) + ) if total_attempts > 0 else 0.0 + + # 计算吞吐量 + throughput = ( + self.stats["total_completed_tasks"] / max(1, uptime / 3600) + ) # 每小时完成任务数 + + return { + "uptime_hours": uptime / 3600, + "success_rate": success_rate, + "throughput_per_hour": throughput, + "avg_distribution_time": self.stats["avg_distribution_time"], + "total_retry_attempts": self.stats["total_retry_attempts"], + "peak_queue_size": self.stats["peak_queue_size"], + "active_streams": len(self.stream_states), + "processing_tasks": len(self.processing_tasks), + } + + def reset_statistics(self) -> None: + """重置统计信息""" + self.stats.update({ + "total_distributed": 0, + "total_failed": 0, + "avg_distribution_time": 0.0, + "current_queue_size": len(self.task_queue), + "total_created_tasks": 0, + "total_completed_tasks": 0, + "total_failed_tasks": 0, + "total_retry_attempts": 0, + "peak_queue_size": 0, + "start_time": time.time(), + "last_activity_time": time.time(), + }) + + # 清空性能指标 + for metrics in self.performance_metrics.values(): + metrics.clear() + + logger.info("分发管理器统计信息已重置") + + def get_all_stream_states(self) -> Dict[str, Dict[str, Any]]: + """获取所有流状态 + + Returns: + Dict[str, Dict[str, Any]]: 所有流状态 + """ + return { + stream_id: self.get_stream_status(stream_id) + for stream_id in self.stream_states.keys() + } + + def force_process_stream(self, stream_id: str) -> bool: + """强制处理指定流 + + Args: + stream_id: 流ID + + Returns: + bool: 是否成功触发处理 + """ + if stream_id not in self.stream_states: + return False + + state = self.stream_states[stream_id] + if not state.is_active: + return False + + # 创建高优先级任务 + task = DistributionTask( + stream_id=stream_id, + priority=DistributionPriority.CRITICAL, + energy=state.energy, + message_count=state.message_count, + ) + + # 添加到队列 + heappush(self.task_queue, task) + self.stats["current_queue_size"] = len(self.task_queue) + + logger.info(f"强制处理流: {stream_id}") + return True + + +# 全局分发管理器实例 +distribution_manager = DistributionManager() \ No newline at end of file diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py new file mode 100644 index 000000000..7c0d77828 --- /dev/null +++ b/src/chat/message_manager/message_manager.py @@ -0,0 +1,558 @@ +""" +消息管理模块 +管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息 +""" + +import asyncio +import random +import time +import traceback +from typing import Dict, Optional, Any, TYPE_CHECKING + +from src.common.logger import get_logger +from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats +from src.chat.chatter_manager import ChatterManager +from src.chat.planner_actions.action_manager import ChatterActionManager +from src.plugin_system.base.component_types import ChatMode +from .sleep_manager.sleep_manager import SleepManager +from .sleep_manager.wakeup_manager import WakeUpManager +from src.config.config import global_config +from .context_manager import context_manager + +if TYPE_CHECKING: + from src.common.data_models.message_manager_data_model import StreamContext + +logger = get_logger("message_manager") + + +class MessageManager: + """消息管理器""" + + def __init__(self, check_interval: float = 5.0): + self.check_interval = check_interval # 检查间隔(秒) + self.is_running = False + self.manager_task: Optional[asyncio.Task] = None + + # 统计信息 + self.stats = MessageManagerStats() + + # 初始化chatter manager + self.action_manager = ChatterActionManager() + self.chatter_manager = ChatterManager(self.action_manager) + + # 初始化睡眠和唤醒管理器 + self.sleep_manager = SleepManager() + self.wakeup_manager = WakeUpManager(self.sleep_manager) + + # 初始化上下文管理器 + self.context_manager = context_manager + + async def start(self): + """启动消息管理器""" + if self.is_running: + logger.warning("消息管理器已经在运行") + return + + self.is_running = True + self.manager_task = asyncio.create_task(self._manager_loop()) + await self.wakeup_manager.start() + await self.context_manager.start() + logger.info("消息管理器已启动") + + async def stop(self): + """停止消息管理器""" + if not self.is_running: + return + + self.is_running = False + + # 停止所有流处理任务 + # 注意:context_manager 会自己清理任务 + if self.manager_task and not self.manager_task.done(): + self.manager_task.cancel() + + await self.wakeup_manager.stop() + await self.context_manager.stop() + + logger.info("消息管理器已停止") + + def add_message(self, stream_id: str, message: DatabaseMessages): + """添加消息到指定聊天流""" + # 检查流上下文是否存在,不存在则创建 + context = self.context_manager.get_stream_context(stream_id) + if not context: + # 创建新的流上下文 + from src.common.data_models.message_manager_data_model import StreamContext + context = StreamContext(stream_id=stream_id) + # 将创建的上下文添加到管理器 + self.context_manager.add_stream_context(stream_id, context) + + # 使用 context_manager 添加消息 + success = self.context_manager.add_message_to_context(stream_id, message) + + if success: + logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}") + else: + logger.warning(f"添加消息到聊天流 {stream_id} 失败") + + def update_message( + self, + stream_id: str, + message_id: str, + interest_value: float = None, + actions: list = None, + should_reply: bool = None, + ): + """更新消息信息""" + # 使用 context_manager 更新消息信息 + context = self.context_manager.get_stream_context(stream_id) + if context: + context.update_message_info(message_id, interest_value, actions, should_reply) + + def add_action(self, stream_id: str, message_id: str, action: str): + """添加动作到消息""" + # 使用 context_manager 添加动作到消息 + context = self.context_manager.get_stream_context(stream_id) + if context: + context.add_action_to_message(message_id, action) + + async def _manager_loop(self): + """管理器主循环 - 独立聊天流分发周期版本""" + while self.is_running: + try: + # 更新睡眠状态 + await self.sleep_manager.update_sleep_state(self.wakeup_manager) + + # 执行独立分发周期的检查 + await self._check_streams_with_individual_intervals() + + # 计算下次检查时间(使用最小间隔或固定间隔) + if global_config.chat.dynamic_distribution_enabled: + next_check_delay = self._calculate_next_manager_delay() + else: + next_check_delay = self.check_interval + + await asyncio.sleep(next_check_delay) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"消息管理器循环出错: {e}") + traceback.print_exc() + + async def _check_all_streams(self): + """检查所有聊天流""" + active_streams = 0 + total_unread = 0 + + # 使用 context_manager 获取活跃的流 + active_stream_ids = self.context_manager.get_active_streams() + + for stream_id in active_stream_ids: + context = self.context_manager.get_stream_context(stream_id) + if not context: + continue + + active_streams += 1 + + # 检查是否有未读消息 + unread_messages = self.context_manager.get_unread_messages(stream_id) + if unread_messages: + total_unread += len(unread_messages) + + # 如果没有处理任务,创建一个 + if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done(): + context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) + + # 更新统计 + self.stats.active_streams = active_streams + self.stats.total_unread_messages = total_unread + + async def _process_stream_messages(self, stream_id: str): + """处理指定聊天流的消息""" + context = self.context_manager.get_stream_context(stream_id) + if not context: + return + + try: + # 获取未读消息 + unread_messages = self.context_manager.get_unread_messages(stream_id) + if not unread_messages: + return + + # 检查是否需要打断现有处理 + await self._check_and_handle_interruption(context, stream_id) + + # --- 睡眠状态检查 --- + if self.sleep_manager.is_sleeping(): + logger.info(f"Bot正在睡觉,检查聊天流 {stream_id} 是否有唤醒触发器。") + + was_woken_up = False + is_private = context.is_private_chat() + + for message in unread_messages: + is_mentioned = message.is_mentioned or False + if not is_mentioned and not is_private: + bot_names = [global_config.bot.nickname] + global_config.bot.alias_names + if any(name in message.processed_plain_text for name in bot_names): + is_mentioned = True + logger.debug(f"通过关键词 '{next((name for name in bot_names if name in message.processed_plain_text), '')}' 匹配将消息标记为 'is_mentioned'") + + if is_private or is_mentioned: + if self.wakeup_manager.add_wakeup_value(is_private, is_mentioned, chat_id=stream_id): + was_woken_up = True + break # 一旦被吵醒,就跳出循环并处理消息 + + if not was_woken_up: + logger.debug(f"聊天流 {stream_id} 中没有唤醒触发器,保持消息未读状态。") + return # 退出,不处理消息 + + logger.info(f"Bot被聊天流 {stream_id} 中的消息吵醒,继续处理。") + elif self.sleep_manager.is_woken_up(): + angry_chat_id = self.wakeup_manager.angry_chat_id + if stream_id != angry_chat_id: + logger.debug(f"Bot处于WOKEN_UP状态,但当前流 {stream_id} 不是触发唤醒的流 {angry_chat_id},跳过处理。") + return # 退出,不处理此流的消息 + logger.info(f"Bot处于WOKEN_UP状态,处理触发唤醒的流 {stream_id}。") + # --- 睡眠状态检查结束 --- + + logger.debug(f"开始处理聊天流 {stream_id} 的 {len(unread_messages)} 条未读消息") + + # 直接使用StreamContext对象进行处理 + if unread_messages: + try: + # 记录当前chat type用于调试 + logger.debug(f"聊天流 {stream_id} 检测到的chat type: {context.chat_type.value}") + + # 发送到chatter manager,传递StreamContext对象 + results = await self.chatter_manager.process_stream_context(stream_id, context) + + # 处理结果,标记消息为已读 + if results.get("success", False): + self._clear_all_unread_messages(stream_id) + logger.debug(f"聊天流 {stream_id} 处理成功,清除了 {len(unread_messages)} 条未读消息") + else: + logger.warning(f"聊天流 {stream_id} 处理失败: {results.get('error_message', '未知错误')}") + + except Exception as e: + logger.error(f"处理聊天流 {stream_id} 时发生异常,将清除所有未读消息: {e}") + # 出现异常时也清除未读消息,避免重复处理 + self._clear_all_unread_messages(stream_id) + raise + + logger.debug(f"聊天流 {stream_id} 消息处理完成") + + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"处理聊天流 {stream_id} 消息时出错: {e}") + traceback.print_exc() + + def deactivate_stream(self, stream_id: str): + """停用聊天流""" + context = self.context_manager.get_stream_context(stream_id) + if context: + context.is_active = False + + # 取消处理任务 + if hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done(): + context.processing_task.cancel() + + logger.info(f"停用聊天流: {stream_id}") + + def activate_stream(self, stream_id: str): + """激活聊天流""" + context = self.context_manager.get_stream_context(stream_id) + if context: + context.is_active = True + logger.info(f"激活聊天流: {stream_id}") + + def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]: + """获取聊天流统计""" + context = self.context_manager.get_stream_context(stream_id) + if not context: + return None + + return StreamStats( + stream_id=stream_id, + is_active=context.is_active, + unread_count=len(self.context_manager.get_unread_messages(stream_id)), + history_count=len(context.history_messages), + last_check_time=context.last_check_time, + has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()), + ) + + def get_manager_stats(self) -> Dict[str, Any]: + """获取管理器统计""" + return { + "total_streams": self.stats.total_streams, + "active_streams": self.stats.active_streams, + "total_unread_messages": self.stats.total_unread_messages, + "total_processed_messages": self.stats.total_processed_messages, + "uptime": self.stats.uptime, + "start_time": self.stats.start_time, + } + + def cleanup_inactive_streams(self, max_inactive_hours: int = 24): + """清理不活跃的聊天流""" + # 使用 context_manager 的自动清理功能 + self.context_manager.cleanup_inactive_contexts(max_inactive_hours * 3600) + logger.info("已启动不活跃聊天流清理") + + async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str): + """检查并处理消息打断""" + if not global_config.chat.interruption_enabled: + return + + # 检查是否有正在进行的处理任务 + if context.processing_task and not context.processing_task.done(): + # 计算打断概率 + interruption_probability = context.calculate_interruption_probability( + global_config.chat.interruption_max_limit, global_config.chat.interruption_probability_factor + ) + + # 检查是否已达到最大打断次数 + if context.interruption_count >= global_config.chat.interruption_max_limit: + logger.debug( + f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit},跳过打断检查" + ) + return + + # 根据概率决定是否打断 + if random.random() < interruption_probability: + logger.info(f"聊天流 {stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}") + + # 取消现有任务 + context.processing_task.cancel() + try: + await context.processing_task + except asyncio.CancelledError: + pass + + # 增加打断计数并应用afc阈值降低 + context.increment_interruption_count() + context.apply_interruption_afc_reduction(global_config.chat.interruption_afc_reduction) + + # 检查是否已达到最大次数 + if context.interruption_count >= global_config.chat.interruption_max_limit: + logger.warning( + f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit},后续消息将不再打断" + ) + else: + logger.info( + f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}, afc阈值调整: {context.get_afc_threshold_adjustment()}" + ) + else: + logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}") + + def _calculate_stream_distribution_interval(self, context: StreamContext) -> float: + """计算单个聊天流的分发周期 - 使用重构后的能量管理器""" + if not global_config.chat.dynamic_distribution_enabled: + return self.check_interval # 使用固定间隔 + + try: + from src.chat.energy_system import energy_manager + from src.plugin_system.apis.chat_api import get_chat_manager + + # 获取聊天流和能量 + chat_stream = get_chat_manager().get_stream(context.stream_id) + if chat_stream: + focus_energy = chat_stream.focus_energy + # 使用能量管理器获取分发周期 + interval = energy_manager.get_distribution_interval(focus_energy) + logger.debug(f"流 {context.stream_id} 分发周期: {interval:.2f}s (能量: {focus_energy:.3f})") + return interval + else: + # 默认间隔 + return self.check_interval + + except Exception as e: + logger.error(f"计算分发周期失败: {e}") + return self.check_interval + + def _calculate_next_manager_delay(self) -> float: + """计算管理器下次检查的延迟时间""" + current_time = time.time() + min_delay = float("inf") + + # 找到最近需要检查的流 + active_stream_ids = self.context_manager.get_active_streams() + for stream_id in active_stream_ids: + context = self.context_manager.get_stream_context(stream_id) + if not context or not context.is_active: + continue + + time_until_check = context.next_check_time - current_time + if time_until_check > 0: + min_delay = min(min_delay, time_until_check) + else: + min_delay = 0.1 # 立即检查 + break + + # 如果没有活跃流,使用默认间隔 + if min_delay == float("inf"): + return self.check_interval + + # 确保最小延迟 + return max(0.1, min(min_delay, self.check_interval)) + + async def _check_streams_with_individual_intervals(self): + """检查所有达到检查时间的聊天流""" + current_time = time.time() + processed_streams = 0 + + # 使用 context_manager 获取活跃的流 + active_stream_ids = self.context_manager.get_active_streams() + + for stream_id in active_stream_ids: + context = self.context_manager.get_stream_context(stream_id) + if not context or not context.is_active: + continue + + # 检查是否达到检查时间 + if current_time >= context.next_check_time: + # 更新检查时间 + context.last_check_time = current_time + + # 计算下次检查时间和分发周期 + if global_config.chat.dynamic_distribution_enabled: + context.distribution_interval = self._calculate_stream_distribution_interval(context) + else: + context.distribution_interval = self.check_interval + + # 设置下次检查时间 + context.next_check_time = current_time + context.distribution_interval + + # 检查未读消息 + unread_messages = self.context_manager.get_unread_messages(stream_id) + if unread_messages: + processed_streams += 1 + self.stats.total_unread_messages = len(unread_messages) + + # 如果没有处理任务,创建一个 + if not context.processing_task or context.processing_task.done(): + from src.plugin_system.apis.chat_api import get_chat_manager + + chat_stream = get_chat_manager().get_stream(context.stream_id) + focus_energy = chat_stream.focus_energy if chat_stream else 0.5 + + # 根据优先级记录日志 + if focus_energy >= 0.7: + logger.info( + f"高优先级流 {stream_id} 开始处理 | " + f"focus_energy: {focus_energy:.3f} | " + f"分发周期: {context.distribution_interval:.2f}s | " + f"未读消息: {len(unread_messages)}" + ) + else: + logger.debug( + f"流 {stream_id} 开始处理 | " + f"focus_energy: {focus_energy:.3f} | " + f"分发周期: {context.distribution_interval:.2f}s" + ) + + context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) + + # 更新活跃流计数 + active_count = len(self.context_manager.get_active_streams()) + self.stats.active_streams = active_count + + if processed_streams > 0: + logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}") + + async def _check_all_streams_with_priority(self): + """按优先级检查所有聊天流,高focus_energy的流优先处理""" + if not self.context_manager.get_active_streams(): + return + + # 获取活跃的聊天流并按focus_energy排序 + active_streams = [] + active_stream_ids = self.context_manager.get_active_streams() + + for stream_id in active_stream_ids: + context = self.context_manager.get_stream_context(stream_id) + if not context or not context.is_active: + continue + + # 获取focus_energy,如果不存在则使用默认值 + from src.plugin_system.apis.chat_api import get_chat_manager + + chat_stream = get_chat_manager().get_stream(context.stream_id) + focus_energy = 0.5 + if chat_stream: + focus_energy = chat_stream.focus_energy + + # 计算流优先级分数 + priority_score = self._calculate_stream_priority(context, focus_energy) + active_streams.append((priority_score, stream_id, context)) + + # 按优先级降序排序 + active_streams.sort(reverse=True, key=lambda x: x[0]) + + # 处理排序后的流 + active_stream_count = 0 + total_unread = 0 + + for priority_score, stream_id, context in active_streams: + active_stream_count += 1 + + # 检查是否有未读消息 + unread_messages = self.context_manager.get_unread_messages(stream_id) + if unread_messages: + total_unread += len(unread_messages) + + # 如果没有处理任务,创建一个 + if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done(): + context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) + + # 高优先级流的额外日志 + if priority_score > 0.7: + logger.info( + f"高优先级流 {stream_id} 开始处理 | " + f"优先级: {priority_score:.3f} | " + f"未读消息: {len(unread_messages)}" + ) + + # 更新统计 + self.stats.active_streams = active_stream_count + self.stats.total_unread_messages = total_unread + + def _calculate_stream_priority(self, context: StreamContext, focus_energy: float) -> float: + """计算聊天流的优先级分数 - 简化版本,主要使用focus_energy""" + # 使用重构后的能量管理器,主要依赖focus_energy + base_priority = focus_energy + + # 简单的未读消息加权 + unread_count = len(context.get_unread_messages()) + message_bonus = min(unread_count * 0.05, 0.2) # 最多20%加成 + + # 简单的时间加权 + current_time = time.time() + time_since_active = current_time - context.last_check_time + time_bonus = max(0, 1.0 - time_since_active / 7200.0) * 0.1 # 2小时内衰减 + + final_priority = base_priority + message_bonus + time_bonus + return max(0.0, min(1.0, final_priority)) + + def _clear_all_unread_messages(self, stream_id: str): + """清除指定上下文中的所有未读消息,防止意外情况导致消息一直未读""" + unread_messages = self.context_manager.get_unread_messages(stream_id) + if not unread_messages: + return + + logger.warning(f"正在清除 {len(unread_messages)} 条未读消息") + + # 将所有未读消息标记为已读 + context = self.context_manager.get_stream_context(stream_id) + if context: + for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表 + try: + context.mark_message_as_read(msg.message_id) + self.stats.total_processed_messages += 1 + logger.debug(f"强制清除消息 {msg.message_id},标记为已读") + except Exception as e: + logger.error(f"清除消息 {msg.message_id} 时出错: {e}") + + +# 创建全局消息管理器实例 +message_manager = MessageManager() diff --git a/src/chat/message_manager/sleep_manager/notification_sender.py b/src/chat/message_manager/sleep_manager/notification_sender.py new file mode 100644 index 000000000..07e8b09d4 --- /dev/null +++ b/src/chat/message_manager/sleep_manager/notification_sender.py @@ -0,0 +1,33 @@ +from src.common.logger import get_logger + +#from ..hfc_context import HfcContext + +logger = get_logger("notification_sender") + + +class NotificationSender: + @staticmethod + async def send_goodnight_notification(context): # type: ignore + """发送晚安通知""" + #try: + #from ..proactive.events import ProactiveTriggerEvent + #from ..proactive.proactive_thinker import ProactiveThinker + + #event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight") + #proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor) + #await proactive_thinker.think(event) + #except Exception as e: + #logger.error(f"发送晚安通知失败: {e}") + + @staticmethod + async def send_insomnia_notification(context, reason: str): # type: ignore + """发送失眠通知""" + #try: + #from ..proactive.events import ProactiveTriggerEvent + #from ..proactive.proactive_thinker import ProactiveThinker + + #event = ProactiveTriggerEvent(source="sleep_manager", reason=reason) + #proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor) + #await proactive_thinker.think(event) + #except Exception as e: + #logger.error(f"发送失眠通知失败: {e}") \ No newline at end of file diff --git a/src/chat/chat_loop/sleep_manager/sleep_manager.py b/src/chat/message_manager/sleep_manager/sleep_manager.py similarity index 68% rename from src/chat/chat_loop/sleep_manager/sleep_manager.py rename to src/chat/message_manager/sleep_manager/sleep_manager.py index ad4aa1ced..0ed21e685 100644 --- a/src/chat/chat_loop/sleep_manager/sleep_manager.py +++ b/src/chat/message_manager/sleep_manager/sleep_manager.py @@ -6,11 +6,11 @@ from typing import Optional, TYPE_CHECKING from src.common.logger import get_logger from src.config.config import global_config from .notification_sender import NotificationSender -from .sleep_state import SleepState, SleepStateSerializer +from .sleep_state import SleepState, SleepContext from .time_checker import TimeChecker if TYPE_CHECKING: - pass + from .wakeup_manager import WakeUpManager logger = get_logger("sleep_manager") @@ -25,28 +25,23 @@ class SleepManager: """ 初始化睡眠管理器。 """ - self.time_checker = TimeChecker() # 时间检查器,用于判断当前是否处于理论睡眠时间 + self.context = SleepContext() # 睡眠上下文,管理所有状态 + self.time_checker = TimeChecker() # 时间检查器 self.last_sleep_log_time = 0 # 上次记录睡眠日志的时间戳 self.sleep_log_interval = 35 # 睡眠日志记录间隔(秒) - - # --- 统一睡眠状态管理 --- - self._current_state: SleepState = SleepState.AWAKE # 当前睡眠状态 - self._sleep_buffer_end_time: Optional[datetime] = None # 睡眠缓冲结束时间,用于状态转换 - self._total_delayed_minutes_today: float = 0.0 # 今天总共延迟入睡的分钟数 - self._last_sleep_check_date: Optional[date] = None # 上次检查睡眠状态的日期 self._last_fully_slept_log_time: float = 0 # 上次完全进入睡眠状态的时间戳 - self._re_sleep_attempt_time: Optional[datetime] = None # 被吵醒后,尝试重新入睡的时间点 - - # 从本地存储加载上一次的睡眠状态 - self._load_sleep_state() def get_current_sleep_state(self) -> SleepState: """获取当前的睡眠状态。""" - return self._current_state + return self.context.current_state def is_sleeping(self) -> bool: """判断当前是否处于正在睡觉的状态。""" - return self._current_state == SleepState.SLEEPING + return self.context.current_state == SleepState.SLEEPING + + def is_woken_up(self) -> bool: + """判断当前是否处于被吵醒的状态。""" + return self.context.current_state == SleepState.WOKEN_UP async def update_sleep_state(self, wakeup_manager: Optional["WakeUpManager"] = None): """ @@ -58,41 +53,42 @@ class SleepManager: """ # 如果全局禁用了睡眠系统,则强制设置为清醒状态并返回 if not global_config.sleep_system.enable: - if self._current_state != SleepState.AWAKE: + if self.context.current_state != SleepState.AWAKE: logger.debug("睡眠系统禁用,强制设为 AWAKE") - self._current_state = SleepState.AWAKE + self.context.current_state = SleepState.AWAKE return now = datetime.now() today = now.date() # 跨天处理:如果日期变化,重置每日相关的睡眠状态 - if self._last_sleep_check_date != today: + if self.context.last_sleep_check_date != today: logger.info(f"新的一天 ({today}),重置睡眠状态。") - self._total_delayed_minutes_today = 0 - self._current_state = SleepState.AWAKE - self._sleep_buffer_end_time = None - self._last_sleep_check_date = today - self._save_sleep_state() + self.context.total_delayed_minutes_today = 0 + self.context.current_state = SleepState.AWAKE + self.context.sleep_buffer_end_time = None + self.context.last_sleep_check_date = today + self.context.save() # 检查当前是否处于理论上的睡眠时间段 is_in_theoretical_sleep, activity = self.time_checker.is_in_theoretical_sleep_time(now.time()) # --- 状态机核心处理逻辑 --- - if self._current_state == SleepState.AWAKE: + current_state = self.context.current_state + if current_state == SleepState.AWAKE: if is_in_theoretical_sleep: self._handle_awake_to_sleep(now, activity, wakeup_manager) - elif self._current_state == SleepState.PREPARING_SLEEP: + elif current_state == SleepState.PREPARING_SLEEP: self._handle_preparing_sleep(now, is_in_theoretical_sleep, wakeup_manager) - elif self._current_state == SleepState.SLEEPING: + elif current_state == SleepState.SLEEPING: self._handle_sleeping(now, is_in_theoretical_sleep, activity, wakeup_manager) - elif self._current_state == SleepState.INSOMNIA: + elif current_state == SleepState.INSOMNIA: self._handle_insomnia(now, is_in_theoretical_sleep) - elif self._current_state == SleepState.WOKEN_UP: + elif current_state == SleepState.WOKEN_UP: self._handle_woken_up(now, is_in_theoretical_sleep, wakeup_manager) def _handle_awake_to_sleep(self, now: datetime, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]): @@ -118,13 +114,13 @@ class SleepManager: delay_minutes = int(pressure_diff * max_delay_minutes) # 确保总延迟不超过当日最大值 - remaining_delay = max_delay_minutes - self._total_delayed_minutes_today + remaining_delay = max_delay_minutes - self.context.total_delayed_minutes_today delay_minutes = min(delay_minutes, remaining_delay) if delay_minutes > 0: # 增加一些随机性 buffer_seconds = random.randint(int(delay_minutes * 0.8 * 60), int(delay_minutes * 1.2 * 60)) - self._total_delayed_minutes_today += buffer_seconds / 60.0 + self.context.total_delayed_minutes_today += buffer_seconds / 60.0 logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 较低,延迟 {buffer_seconds / 60:.1f} 分钟入睡。") else: # 延迟额度已用完,设置一个较短的准备时间 @@ -139,22 +135,22 @@ class SleepManager: if global_config.sleep_system.enable_pre_sleep_notification: asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context)) - self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds) - self._current_state = SleepState.PREPARING_SLEEP + self.context.sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds) + self.context.current_state = SleepState.PREPARING_SLEEP logger.info(f"进入准备入睡状态,将在 {buffer_seconds / 60:.1f} 分钟内入睡。") - self._save_sleep_state() + self.context.save() else: # 无法获取 wakeup_manager,退回旧逻辑 buffer_seconds = random.randint(1 * 60, 3 * 60) - self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds) - self._current_state = SleepState.PREPARING_SLEEP + self.context.sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds) + self.context.current_state = SleepState.PREPARING_SLEEP logger.warning("无法获取 WakeUpManager,弹性睡眠采用默认1-3分钟延迟。") - self._save_sleep_state() + self.context.save() else: # 非弹性睡眠模式 if wakeup_manager and global_config.sleep_system.enable_pre_sleep_notification: asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context)) - self._current_state = SleepState.SLEEPING + self.context.current_state = SleepState.SLEEPING def _handle_preparing_sleep(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]): @@ -162,32 +158,32 @@ class SleepManager: # 如果在准备期间离开了理论睡眠时间,则取消入睡 if not is_in_theoretical_sleep: logger.info("准备入睡期间离开理论休眠时间,取消入睡,恢复清醒。") - self._current_state = SleepState.AWAKE - self._sleep_buffer_end_time = None - self._save_sleep_state() + self.context.current_state = SleepState.AWAKE + self.context.sleep_buffer_end_time = None + self.context.save() # 如果缓冲时间结束,则正式进入睡眠状态 - elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time: + elif self.context.sleep_buffer_end_time and now >= self.context.sleep_buffer_end_time: logger.info("睡眠缓冲期结束,正式进入休眠状态。") - self._current_state = SleepState.SLEEPING + self.context.current_state = SleepState.SLEEPING self._last_fully_slept_log_time = now.timestamp() # 设置一个随机的延迟,用于触发“睡后失眠”检查 delay_minutes_range = global_config.sleep_system.insomnia_trigger_delay_minutes delay_minutes = random.randint(delay_minutes_range[0], delay_minutes_range[1]) - self._sleep_buffer_end_time = now + timedelta(minutes=delay_minutes) + self.context.sleep_buffer_end_time = now + timedelta(minutes=delay_minutes) logger.info(f"已设置睡后失眠检查,将在 {delay_minutes} 分钟后触发。") - self._save_sleep_state() + self.context.save() def _handle_sleeping(self, now: datetime, is_in_theoretical_sleep: bool, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]): """处理“正在睡觉”状态下的逻辑。""" # 如果理论睡眠时间结束,则自然醒来 if not is_in_theoretical_sleep: logger.info("理论休眠时间结束,自然醒来。") - self._current_state = SleepState.AWAKE - self._save_sleep_state() + self.context.current_state = SleepState.AWAKE + self.context.save() # 检查是否到了触发“睡后失眠”的时间点 - elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time: + elif self.context.sleep_buffer_end_time and now >= self.context.sleep_buffer_end_time: if wakeup_manager: sleep_pressure = wakeup_manager.context.sleep_pressure pressure_threshold = global_config.sleep_system.flexible_sleep_pressure_threshold @@ -201,12 +197,12 @@ class SleepManager: logger.info("随机触发失眠。") if insomnia_reason: - self._current_state = SleepState.INSOMNIA + self.context.current_state = SleepState.INSOMNIA # 设置失眠的持续时间 duration_minutes_range = global_config.sleep_system.insomnia_duration_minutes duration_minutes = random.randint(*duration_minutes_range) - self._sleep_buffer_end_time = now + timedelta(minutes=duration_minutes) + self.context.sleep_buffer_end_time = now + timedelta(minutes=duration_minutes) # 发送失眠通知 asyncio.create_task(NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason)) @@ -214,8 +210,8 @@ class SleepManager: else: # 睡眠压力正常,不触发失眠,清除检查时间点 logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 正常,未触发睡后失眠。") - self._sleep_buffer_end_time = None - self._save_sleep_state() + self.context.sleep_buffer_end_time = None + self.context.save() else: # 定期记录睡眠日志 current_timestamp = now.timestamp() @@ -228,26 +224,26 @@ class SleepManager: # 如果离开理论睡眠时间,则失眠结束 if not is_in_theoretical_sleep: logger.info("已离开理论休眠时间,失眠结束,恢复清醒。") - self._current_state = SleepState.AWAKE - self._sleep_buffer_end_time = None - self._save_sleep_state() + self.context.current_state = SleepState.AWAKE + self.context.sleep_buffer_end_time = None + self.context.save() # 如果失眠持续时间已过,则恢复睡眠 - elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time: + elif self.context.sleep_buffer_end_time and now >= self.context.sleep_buffer_end_time: logger.info("失眠状态持续时间已过,恢复睡眠。") - self._current_state = SleepState.SLEEPING - self._sleep_buffer_end_time = None - self._save_sleep_state() + self.context.current_state = SleepState.SLEEPING + self.context.sleep_buffer_end_time = None + self.context.save() def _handle_woken_up(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]): """处理“被吵醒”状态下的逻辑。""" # 如果理论睡眠时间结束,则状态自动结束 if not is_in_theoretical_sleep: logger.info("理论休眠时间结束,被吵醒的状态自动结束。") - self._current_state = SleepState.AWAKE - self._re_sleep_attempt_time = None - self._save_sleep_state() + self.context.current_state = SleepState.AWAKE + self.context.re_sleep_attempt_time = None + self.context.save() # 到了尝试重新入睡的时间点 - elif self._re_sleep_attempt_time and now >= self._re_sleep_attempt_time: + elif self.context.re_sleep_attempt_time and now >= self.context.re_sleep_attempt_time: logger.info("被吵醒后经过一段时间,尝试重新入睡...") if wakeup_manager: sleep_pressure = wakeup_manager.context.sleep_pressure @@ -257,48 +253,28 @@ class SleepManager: if sleep_pressure >= pressure_threshold: logger.info("睡眠压力足够,从被吵醒状态转换到准备入睡。") buffer_seconds = random.randint(3 * 60, 8 * 60) - self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds) - self._current_state = SleepState.PREPARING_SLEEP - self._re_sleep_attempt_time = None + self.context.sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds) + self.context.current_state = SleepState.PREPARING_SLEEP + self.context.re_sleep_attempt_time = None else: # 睡眠压力不足,延迟一段时间后再次尝试 delay_minutes = 15 - self._re_sleep_attempt_time = now + timedelta(minutes=delay_minutes) + self.context.re_sleep_attempt_time = now + timedelta(minutes=delay_minutes) logger.info( f"睡眠压力({sleep_pressure:.1f})仍然较低,暂时保持清醒,在 {delay_minutes} 分钟后再次尝试。" ) - self._save_sleep_state() + self.context.save() def reset_sleep_state_after_wakeup(self): """ 当角色被用户消息等外部因素唤醒时调用此方法。 将状态强制转换为 WOKEN_UP,并设置一个延迟,之后会尝试重新入睡。 """ - if self._current_state in [SleepState.PREPARING_SLEEP, SleepState.SLEEPING, SleepState.INSOMNIA]: + if self.context.current_state in [SleepState.PREPARING_SLEEP, SleepState.SLEEPING, SleepState.INSOMNIA]: logger.info("被唤醒,进入 WOKEN_UP 状态!") - self._current_state = SleepState.WOKEN_UP - self._sleep_buffer_end_time = None + self.context.current_state = SleepState.WOKEN_UP + self.context.sleep_buffer_end_time = None re_sleep_delay_minutes = getattr(global_config.sleep_system, "re_sleep_delay_minutes", 10) - self._re_sleep_attempt_time = datetime.now() + timedelta(minutes=re_sleep_delay_minutes) + self.context.re_sleep_attempt_time = datetime.now() + timedelta(minutes=re_sleep_delay_minutes) logger.info(f"将在 {re_sleep_delay_minutes} 分钟后尝试重新入睡。") - self._save_sleep_state() - - def _save_sleep_state(self): - """将当前所有睡眠相关的状态打包并保存到本地存储。""" - state_data = { - "_current_state": self._current_state, - "_sleep_buffer_end_time": self._sleep_buffer_end_time, - "_total_delayed_minutes_today": self._total_delayed_minutes_today, - "_last_sleep_check_date": self._last_sleep_check_date, - "_re_sleep_attempt_time": self._re_sleep_attempt_time, - } - SleepStateSerializer.save(state_data) - - def _load_sleep_state(self): - """从本地存储加载并恢复所有睡眠相关的状态。""" - state_data = SleepStateSerializer.load() - self._current_state = state_data["_current_state"] - self._sleep_buffer_end_time = state_data["_sleep_buffer_end_time"] - self._total_delayed_minutes_today = state_data["_total_delayed_minutes_today"] - self._last_sleep_check_date = state_data["_last_sleep_check_date"] - self._re_sleep_attempt_time = state_data["_re_sleep_attempt_time"] + self.context.save() diff --git a/src/chat/message_manager/sleep_manager/sleep_state.py b/src/chat/message_manager/sleep_manager/sleep_state.py new file mode 100644 index 000000000..d59f1f3d6 --- /dev/null +++ b/src/chat/message_manager/sleep_manager/sleep_state.py @@ -0,0 +1,86 @@ +from enum import Enum, auto +from datetime import datetime, date +from typing import Optional + +from src.common.logger import get_logger +from src.manager.local_store_manager import local_storage + +logger = get_logger("sleep_state") + + +class SleepState(Enum): + """ + 定义了角色可能处于的几种睡眠状态。 + 这是一个状态机,用于管理角色的睡眠周期。 + """ + + AWAKE = auto() # 清醒状态 + INSOMNIA = auto() # 失眠状态 + PREPARING_SLEEP = auto() # 准备入睡状态,一个短暂的过渡期 + SLEEPING = auto() # 正在睡觉状态 + WOKEN_UP = auto() # 被吵醒状态 + + +class SleepContext: + """ + 睡眠上下文,负责封装和管理所有与睡眠相关的状态,并处理其持久化。 + """ + def __init__(self): + """初始化睡眠上下文,并从本地存储加载初始状态。""" + self.current_state: SleepState = SleepState.AWAKE + self.sleep_buffer_end_time: Optional[datetime] = None + self.total_delayed_minutes_today: float = 0.0 + self.last_sleep_check_date: Optional[date] = None + self.re_sleep_attempt_time: Optional[datetime] = None + self.load() + + def save(self): + """将当前的睡眠状态数据保存到本地存储。""" + try: + state = { + "current_state": self.current_state.name, + "sleep_buffer_end_time_ts": self.sleep_buffer_end_time.timestamp() + if self.sleep_buffer_end_time + else None, + "total_delayed_minutes_today": self.total_delayed_minutes_today, + "last_sleep_check_date_str": self.last_sleep_check_date.isoformat() + if self.last_sleep_check_date + else None, + "re_sleep_attempt_time_ts": self.re_sleep_attempt_time.timestamp() + if self.re_sleep_attempt_time + else None, + } + local_storage["schedule_sleep_state"] = state + logger.debug(f"已保存睡眠上下文: {state}") + except Exception as e: + logger.error(f"保存睡眠上下文失败: {e}") + + def load(self): + """从本地存储加载并解析睡眠状态。""" + try: + state = local_storage["schedule_sleep_state"] + if not (state and isinstance(state, dict)): + logger.info("未找到本地睡眠上下文,使用默认值。") + return + + state_name = state.get("current_state") + if state_name and hasattr(SleepState, state_name): + self.current_state = SleepState[state_name] + + end_time_ts = state.get("sleep_buffer_end_time_ts") + if end_time_ts: + self.sleep_buffer_end_time = datetime.fromtimestamp(end_time_ts) + + re_sleep_ts = state.get("re_sleep_attempt_time_ts") + if re_sleep_ts: + self.re_sleep_attempt_time = datetime.fromtimestamp(re_sleep_ts) + + self.total_delayed_minutes_today = state.get("total_delayed_minutes_today", 0.0) + + date_str = state.get("last_sleep_check_date_str") + if date_str: + self.last_sleep_check_date = datetime.fromisoformat(date_str).date() + + logger.info(f"成功从本地存储加载睡眠上下文: {state}") + except Exception as e: + logger.warning(f"加载睡眠上下文失败,将使用默认值: {e}") \ No newline at end of file diff --git a/src/chat/chat_loop/sleep_manager/time_checker.py b/src/chat/message_manager/sleep_manager/time_checker.py similarity index 100% rename from src/chat/chat_loop/sleep_manager/time_checker.py rename to src/chat/message_manager/sleep_manager/time_checker.py diff --git a/src/chat/message_manager/sleep_manager/wakeup_context.py b/src/chat/message_manager/sleep_manager/wakeup_context.py new file mode 100644 index 000000000..bfa1a62dd --- /dev/null +++ b/src/chat/message_manager/sleep_manager/wakeup_context.py @@ -0,0 +1,45 @@ +import time +from src.common.logger import get_logger +from src.manager.local_store_manager import local_storage + +logger = get_logger("wakeup_context") + + +class WakeUpContext: + """ + 唤醒上下文,负责封装和管理所有与唤醒相关的状态,并处理其持久化。 + """ + def __init__(self): + """初始化唤醒上下文,并从本地存储加载初始状态。""" + self.wakeup_value: float = 0.0 + self.is_angry: bool = False + self.angry_start_time: float = 0.0 + self.sleep_pressure: float = 100.0 # 新增:睡眠压力 + self.load() + + def _get_storage_key(self) -> str: + """获取本地存储键""" + return "global_wakeup_manager_state" + + def load(self): + """从本地存储加载状态""" + state = local_storage[self._get_storage_key()] + if state and isinstance(state, dict): + self.wakeup_value = state.get("wakeup_value", 0.0) + self.is_angry = state.get("is_angry", False) + self.angry_start_time = state.get("angry_start_time", 0.0) + self.sleep_pressure = state.get("sleep_pressure", 100.0) + logger.info(f"成功从本地存储加载唤醒上下文: {state}") + else: + logger.info("未找到本地唤醒上下文,将使用默认值初始化。") + + def save(self): + """将当前状态保存到本地存储""" + state = { + "wakeup_value": self.wakeup_value, + "is_angry": self.is_angry, + "angry_start_time": self.angry_start_time, + "sleep_pressure": self.sleep_pressure, + } + local_storage[self._get_storage_key()] = state + logger.debug(f"已将唤醒上下文保存到本地存储: {state}") \ No newline at end of file diff --git a/src/chat/message_manager/sleep_manager/wakeup_manager.py b/src/chat/message_manager/sleep_manager/wakeup_manager.py new file mode 100644 index 000000000..51ab80bb1 --- /dev/null +++ b/src/chat/message_manager/sleep_manager/wakeup_manager.py @@ -0,0 +1,215 @@ +import asyncio +import time +from typing import Optional, TYPE_CHECKING +from src.common.logger import get_logger +from src.config.config import global_config +from src.manager.local_store_manager import local_storage +from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext + +if TYPE_CHECKING: + from .sleep_manager import SleepManager + + +logger = get_logger("wakeup") + + +class WakeUpManager: + def __init__(self, sleep_manager: "SleepManager"): + """ + 初始化唤醒度管理器 + + Args: + sleep_manager: 睡眠管理器实例 + + 功能说明: + - 管理休眠状态下的唤醒度累积 + - 处理唤醒度的自然衰减 + - 控制愤怒状态的持续时间 + """ + self.sleep_manager = sleep_manager + self.context = WakeUpContext() # 使用新的上下文管理器 + self.angry_chat_id: Optional[str] = None + self.last_decay_time = time.time() + self._decay_task: Optional[asyncio.Task] = None + self.is_running = False + self.last_log_time = 0 + self.log_interval = 30 + + # 从配置文件获取参数 + sleep_config = global_config.sleep_system + self.wakeup_threshold = sleep_config.wakeup_threshold + self.private_message_increment = sleep_config.private_message_increment + self.group_mention_increment = sleep_config.group_mention_increment + self.decay_rate = sleep_config.decay_rate + self.decay_interval = sleep_config.decay_interval + self.angry_duration = sleep_config.angry_duration + self.enabled = sleep_config.enable + self.angry_prompt = sleep_config.angry_prompt + + async def start(self): + """启动唤醒度管理器""" + if not self.enabled: + logger.info("唤醒度系统已禁用,跳过启动") + return + + self.is_running = True + if not self._decay_task or self._decay_task.done(): + self._decay_task = asyncio.create_task(self._decay_loop()) + self._decay_task.add_done_callback(self._handle_decay_completion) + logger.info("唤醒度管理器已启动") + + async def stop(self): + """停止唤醒度管理器""" + self.is_running = False + if self._decay_task and not self._decay_task.done(): + self._decay_task.cancel() + await asyncio.sleep(0) + logger.info("唤醒度管理器已停止") + + def _handle_decay_completion(self, task: asyncio.Task): + """处理衰减任务完成""" + try: + if exception := task.exception(): + logger.error(f"唤醒度衰减任务异常: {exception}") + else: + logger.info("唤醒度衰减任务正常结束") + except asyncio.CancelledError: + logger.info("唤醒度衰减任务被取消") + + async def _decay_loop(self): + """唤醒度衰减循环""" + while self.is_running: + await asyncio.sleep(self.decay_interval) + + current_time = time.time() + + # 检查愤怒状态是否过期 + if self.context.is_angry and current_time - self.context.angry_start_time >= self.angry_duration: + self.context.is_angry = False + # 通知情绪管理系统清除愤怒状态 + from src.mood.mood_manager import mood_manager + if self.angry_chat_id: + mood_manager.clear_angry_from_wakeup(self.angry_chat_id) + self.angry_chat_id = None + else: + logger.warning("Angry state ended but no angry_chat_id was set.") + logger.info("愤怒状态结束,恢复正常") + self.context.save() + + # 唤醒度自然衰减 + if self.context.wakeup_value > 0: + old_value = self.context.wakeup_value + self.context.wakeup_value = max(0, self.context.wakeup_value - self.decay_rate) + if old_value != self.context.wakeup_value: + logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}") + self.context.save() + + def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None) -> bool: + """ + 增加唤醒度值 + + Args: + is_private_chat: 是否为私聊 + is_mentioned: 是否被艾特(仅群聊有效) + + Returns: + bool: 是否达到唤醒阈值 + """ + # 如果系统未启用,直接返回 + if not self.enabled: + return False + + # 只有在休眠且非失眠状态下才累积唤醒度 + from .sleep_state import SleepState + + current_sleep_state = self.sleep_manager.get_current_sleep_state() + if current_sleep_state != SleepState.SLEEPING: + return False + + old_value = self.context.wakeup_value + + if is_private_chat: + # 私聊每条消息都增加唤醒度 + self.context.wakeup_value += self.private_message_increment + logger.debug(f"私聊消息增加唤醒度: +{self.private_message_increment}") + elif is_mentioned: + # 群聊只有被艾特才增加唤醒度 + self.context.wakeup_value += self.group_mention_increment + logger.debug(f"群聊艾特增加唤醒度: +{self.group_mention_increment}") + else: + # 群聊未被艾特,不增加唤醒度 + return False + + current_time = time.time() + if current_time - self.last_log_time > self.log_interval: + logger.info( + f"唤醒度变化: {old_value:.1f} -> {self.context.wakeup_value:.1f} (阈值: {self.wakeup_threshold})" + ) + self.last_log_time = current_time + else: + logger.debug( + f"唤醒度变化: {old_value:.1f} -> {self.context.wakeup_value:.1f} (阈值: {self.wakeup_threshold})" + ) + + # 检查是否达到唤醒阈值 + if self.context.wakeup_value >= self.wakeup_threshold: + if not chat_id: + logger.error("Wakeup threshold reached, but no chat_id was provided. Cannot trigger wakeup.") + return False + self._trigger_wakeup(chat_id) + return True + + self.context.save() + return False + + def _trigger_wakeup(self, chat_id: str): + """触发唤醒,进入愤怒状态""" + self.context.is_angry = True + self.context.angry_start_time = time.time() + self.context.wakeup_value = 0.0 # 重置唤醒度 + self.angry_chat_id = chat_id + + self.context.save() + + # 通知情绪管理系统进入愤怒状态 + from src.mood.mood_manager import mood_manager + mood_manager.set_angry_from_wakeup(chat_id) + + # 通知SleepManager重置睡眠状态 + self.sleep_manager.reset_sleep_state_after_wakeup() + + logger.info(f"唤醒度达到阈值({self.wakeup_threshold}),被吵醒进入愤怒状态!") + + def get_angry_prompt_addition(self) -> str: + """获取愤怒状态下的提示词补充""" + if self.context.is_angry: + return self.angry_prompt + return "" + + def is_in_angry_state(self) -> bool: + """检查是否处于愤怒状态""" + if self.context.is_angry: + current_time = time.time() + if current_time - self.context.angry_start_time >= self.angry_duration: + self.context.is_angry = False + # 通知情绪管理系统清除愤怒状态 + from src.mood.mood_manager import mood_manager + if self.angry_chat_id: + mood_manager.clear_angry_from_wakeup(self.angry_chat_id) + self.angry_chat_id = None + else: + logger.warning("Angry state expired in check, but no angry_chat_id was set.") + logger.info("愤怒状态自动过期") + return False + return self.context.is_angry + + def get_status_info(self) -> dict: + """获取当前状态信息""" + return { + "wakeup_value": self.context.wakeup_value, + "wakeup_threshold": self.wakeup_threshold, + "is_angry": self.context.is_angry, + "angry_remaining_time": max(0, self.angry_duration - (time.time() - self.context.angry_start_time)) + if self.context.is_angry + else 0, + } diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 53cb00345..92a44b443 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -11,11 +11,12 @@ from src.mood.mood_manager import mood_manager # 导入情绪管理器 from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream from src.chat.message_receive.message import MessageRecv, MessageRecvS4U from src.chat.message_receive.storage import MessageStorage -from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver +from src.chat.message_manager import message_manager from src.chat.utils.prompt import Prompt, global_prompt_manager from src.plugin_system.core import component_registry, event_manager, global_announcement_manager from src.plugin_system.base import BaseCommand, EventType from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor +from src.chat.utils.utils import is_mentioned_bot_in_message # 导入反注入系统 from src.chat.antipromptinjector import initialize_anti_injector @@ -73,15 +74,17 @@ class ChatBot: self.bot = None # bot 实例引用 self._started = False self.mood_manager = mood_manager # 获取情绪管理器单例 - self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增 + # 亲和力流消息处理器 - 直接使用全局afc_manager self.s4u_message_processor = S4UMessageProcessor() # 初始化反注入系统 self._initialize_anti_injector() - @staticmethod - def _initialize_anti_injector(): + # 启动消息管理器 + self._message_manager_started = False + + def _initialize_anti_injector(self): """初始化反注入系统""" try: initialize_anti_injector() @@ -99,10 +102,15 @@ class ChatBot: if not self._started: logger.debug("确保ChatBot所有任务已启动") + # 启动消息管理器 + if not self._message_manager_started: + await message_manager.start() + self._message_manager_started = True + logger.info("消息管理器已启动") + self._started = True - @staticmethod - async def _process_plus_commands(message: MessageRecv): + async def _process_plus_commands(self, message: MessageRecv): """独立处理PlusCommand系统""" try: text = message.processed_plain_text @@ -182,7 +190,7 @@ class ChatBot: try: # 检查聊天类型限制 if not plus_command_instance.is_chat_type_allowed(): - is_group = hasattr(message, "is_group_message") and message.is_group_message + is_group = message.message_info.group_info logger.info( f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" ) @@ -222,8 +230,7 @@ class ChatBot: logger.error(f"处理PlusCommand时出错: {e}") return False, None, True # 出错时继续处理消息 - @staticmethod - async def _process_commands_with_new_system(message: MessageRecv): + async def _process_commands_with_new_system(self, message: MessageRecv): # sourcery skip: use-named-expression """使用新插件系统处理命令""" try: @@ -256,7 +263,7 @@ class ChatBot: try: # 检查聊天类型限制 if not command_instance.is_chat_type_allowed(): - is_group = hasattr(message, "is_group_message") and message.is_group_message + is_group = message.message_info.group_info logger.info( f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" ) @@ -313,8 +320,7 @@ class ChatBot: return False - @staticmethod - async def handle_adapter_response(message: MessageRecv): + async def handle_adapter_response(self, message: MessageRecv): """处理适配器命令响应""" try: from src.plugin_system.apis.send_api import put_adapter_response @@ -354,19 +360,7 @@ class ChatBot: return async def message_process(self, message_data: Dict[str, Any]) -> None: - """处理转化后的统一格式消息 - 这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中 - heart_flow模式:使用思维流系统进行回复 - - 包含思维流状态管理 - - 在回复前进行观察和状态更新 - - 回复后更新思维流状态 - - 消息过滤 - - 记忆激活 - - 意愿计算 - - 消息生成和发送 - - 表情包处理 - - 性能计时 - """ + """处理转化后的统一格式消息""" try: # 首先处理可能的切片消息重组 from src.utils.message_chunker import reassembler @@ -403,9 +397,7 @@ class ChatBot: # logger.debug(str(message_data)) message = MessageRecv(message_data) - if await self.handle_notice_message(message): - ... - + message.is_mentioned, _ = is_mentioned_bot_in_message(message) group_info = message.message_info.group_info user_info = message.message_info.user_info if message.message_info.additional_config: @@ -415,6 +407,7 @@ class ChatBot: return get_chat_manager().register_message(message) + chat = await get_chat_manager().get_or_create_stream( platform=message.message_info.platform, # type: ignore user_info=user_info, # type: ignore @@ -426,11 +419,14 @@ class ChatBot: # 处理消息内容,生成纯文本 await message.process() - # 过滤检查 (在消息处理之后进行) - if _check_ban_words( - message.processed_plain_text, chat, user_info # type: ignore - ) or _check_ban_regex( - message.processed_plain_text, chat, user_info # type: ignore + # 在这里打印[所见]日志,确保在所有处理和过滤之前记录 + logger.info(f"\u001b[38;5;118m{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m") + + # 过滤检查 + if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore + message.raw_message, # type: ignore + chat, + user_info, # type: ignore ): return @@ -456,7 +452,8 @@ class ChatBot: result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message) if not result.all_continue_process(): raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理") - + + # TODO:暂不可用 # 确认从接口发来的message是否有自定义的prompt模板信息 if message.message_info.template_info and not message.message_info.template_info.template_default: template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore @@ -470,7 +467,55 @@ class ChatBot: template_group_name = None async def preprocess(): - await self.heartflow_message_receiver.process_message(message) + # 存储消息到数据库 + from .storage import MessageStorage + + try: + await MessageStorage.store_message(message, message.chat_stream) + logger.debug(f"消息已存储到数据库: {message.message_info.message_id}") + except Exception as e: + logger.error(f"存储消息到数据库失败: {e}") + traceback.print_exc() + + # 使用消息管理器处理消息(保持原有功能) + from src.common.data_models.database_data_model import DatabaseMessages + + # 创建数据库消息对象 + db_message = DatabaseMessages( + message_id=message.message_info.message_id, + time=message.message_info.time, + chat_id=message.chat_stream.stream_id, + processed_plain_text=message.processed_plain_text, + display_message=message.processed_plain_text, + is_mentioned=message.is_mentioned, + is_at=message.is_at, + is_emoji=message.is_emoji, + is_picid=message.is_picid, + is_command=message.is_command, + is_notify=message.is_notify, + user_id=message.message_info.user_info.user_id, + user_nickname=message.message_info.user_info.user_nickname, + user_cardname=message.message_info.user_info.user_cardname, + user_platform=message.message_info.user_info.platform, + chat_info_stream_id=message.chat_stream.stream_id, + chat_info_platform=message.chat_stream.platform, + chat_info_create_time=message.chat_stream.create_time, + chat_info_last_active_time=message.chat_stream.last_active_time, + chat_info_user_id=message.chat_stream.user_info.user_id, + chat_info_user_nickname=message.chat_stream.user_info.user_nickname, + chat_info_user_cardname=message.chat_stream.user_info.user_cardname, + chat_info_user_platform=message.chat_stream.user_info.platform, + ) + + # 如果是群聊,添加群组信息 + if message.chat_stream.group_info: + db_message.chat_info_group_id = message.chat_stream.group_info.group_id + db_message.chat_info_group_name = message.chat_stream.group_info.group_name + db_message.chat_info_group_platform = message.chat_stream.group_info.platform + + # 添加消息到消息管理器 + message_manager.add_message(message.chat_stream.stream_id, db_message) + logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}") if template_group_name: async with global_prompt_manager.async_message_scope(template_group_name): diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index de2fb62e9..53d9ab0ed 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -25,43 +25,6 @@ install(extra_lines=3) logger = get_logger("chat_stream") -class ChatMessageContext: - """聊天消息上下文,存储消息的上下文信息""" - - def __init__(self, message: "MessageRecv"): - self.message = message - - def get_template_name(self) -> Optional[str]: - """获取模板名称""" - if self.message.message_info.template_info and not self.message.message_info.template_info.template_default: - return self.message.message_info.template_info.template_name # type: ignore - return None - - def get_last_message(self) -> "MessageRecv": - """获取最后一条消息""" - return self.message - - def check_types(self, types: list) -> bool: - # sourcery skip: invert-any-all, use-any, use-next - """检查消息类型""" - if not self.message.message_info.format_info.accept_format: # type: ignore - return False - for t in types: - if t not in self.message.message_info.format_info.accept_format: # type: ignore - return False - return True - - def get_priority_mode(self) -> str: - """获取优先级模式""" - return self.message.priority_mode - - def get_priority_info(self) -> Optional[dict]: - """获取优先级信息""" - if hasattr(self.message, "priority_info") and self.message.priority_info: - return self.message.priority_info - return None - - class ChatStream: """聊天流对象,存储一个完整的聊天上下文""" @@ -79,14 +42,24 @@ class ChatStream: self.group_info = group_info self.create_time = data.get("create_time", time.time()) if data else time.time() self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time - self.energy_value = data.get("energy_value", 5.0) if data else 5.0 self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0 self.saved = False - self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息 - # 从配置文件中读取focus_value,如果没有则使用默认值1.0 - self.focus_energy = data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value + + # 使用StreamContext替代ChatMessageContext + from src.common.data_models.message_manager_data_model import StreamContext + from src.plugin_system.base.component_types import ChatType, ChatMode + + self.stream_context: StreamContext = StreamContext( + stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL + ) + + # 基础参数 + self.base_interest_energy = 0.5 # 默认基础兴趣度 + self._focus_energy = 0.5 # 内部存储的focus_energy值 self.no_reply_consecutive = 0 - self.breaking_accumulated_interest = 0.0 + + # 自动加载历史消息 + self._load_history_messages() def to_dict(self) -> dict: """转换为字典格式""" @@ -97,10 +70,15 @@ class ChatStream: "group_info": self.group_info.to_dict() if self.group_info else None, "create_time": self.create_time, "last_active_time": self.last_active_time, - "energy_value": self.energy_value, "sleep_pressure": self.sleep_pressure, "focus_energy": self.focus_energy, - "breaking_accumulated_interest": self.breaking_accumulated_interest, + # 基础兴趣度 + "base_interest_energy": self.base_interest_energy, + # 新增stream_context信息 + "stream_context_chat_type": self.stream_context.chat_type.value, + "stream_context_chat_mode": self.stream_context.chat_mode.value, + # 新增interruption_count信息 + "interruption_count": self.stream_context.interruption_count, } @classmethod @@ -109,7 +87,7 @@ class ChatStream: user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None - return cls( + instance = cls( stream_id=data["stream_id"], platform=data["platform"], user_info=user_info, # type: ignore @@ -117,6 +95,22 @@ class ChatStream: data=data, ) + # 恢复stream_context信息 + if "stream_context_chat_type" in data: + from src.plugin_system.base.component_types import ChatType, ChatMode + + instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"]) + if "stream_context_chat_mode" in data: + from src.plugin_system.base.component_types import ChatType, ChatMode + + instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"]) + + # 恢复interruption_count信息 + if "interruption_count" in data: + instance.stream_context.interruption_count = data["interruption_count"] + + return instance + def update_active_time(self): """更新最后活跃时间""" self.last_active_time = time.time() @@ -124,7 +118,312 @@ class ChatStream: def set_context(self, message: "MessageRecv"): """设置聊天消息上下文""" - self.context = ChatMessageContext(message) + # 将MessageRecv转换为DatabaseMessages并设置到stream_context + from src.common.data_models.database_data_model import DatabaseMessages + import json + + # 安全获取message_info中的数据 + message_info = getattr(message, "message_info", {}) + user_info = getattr(message_info, "user_info", {}) + group_info = getattr(message_info, "group_info", {}) + + # 提取reply_to信息(从message_segment中查找reply类型的段) + reply_to = None + if hasattr(message, "message_segment") and message.message_segment: + reply_to = self._extract_reply_from_segment(message.message_segment) + + # 完整的数据转移逻辑 + db_message = DatabaseMessages( + # 基础消息信息 + message_id=getattr(message, "message_id", ""), + time=getattr(message, "time", time.time()), + chat_id=self._generate_chat_id(message_info), + reply_to=reply_to, + # 兴趣度相关 + interest_value=getattr(message, "interest_value", 0.0), + # 关键词 + key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False) + if getattr(message, "key_words", None) + else None, + key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False) + if getattr(message, "key_words_lite", None) + else None, + # 消息状态标记 + is_mentioned=getattr(message, "is_mentioned", None), + is_at=getattr(message, "is_at", False), + is_emoji=getattr(message, "is_emoji", False), + is_picid=getattr(message, "is_picid", False), + is_voice=getattr(message, "is_voice", False), + is_video=getattr(message, "is_video", False), + is_command=getattr(message, "is_command", False), + is_notify=getattr(message, "is_notify", False), + # 消息内容 + processed_plain_text=getattr(message, "processed_plain_text", ""), + display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text + # 优先级信息 + priority_mode=getattr(message, "priority_mode", None), + priority_info=json.dumps(getattr(message, "priority_info", None)) + if getattr(message, "priority_info", None) + else None, + # 额外配置 + additional_config=getattr(message_info, "additional_config", None), + # 用户信息 + user_id=str(getattr(user_info, "user_id", "")), + user_nickname=getattr(user_info, "user_nickname", ""), + user_cardname=getattr(user_info, "user_cardname", None), + user_platform=getattr(user_info, "platform", ""), + # 群组信息 + chat_info_group_id=getattr(group_info, "group_id", None), + chat_info_group_name=getattr(group_info, "group_name", None), + chat_info_group_platform=getattr(group_info, "platform", None), + # 聊天流信息 + chat_info_user_id=str(getattr(user_info, "user_id", "")), + chat_info_user_nickname=getattr(user_info, "user_nickname", ""), + chat_info_user_cardname=getattr(user_info, "user_cardname", None), + chat_info_user_platform=getattr(user_info, "platform", ""), + chat_info_stream_id=self.stream_id, + chat_info_platform=self.platform, + chat_info_create_time=self.create_time, + chat_info_last_active_time=self.last_active_time, + # 新增兴趣度系统字段 - 添加安全处理 + actions=self._safe_get_actions(message), + should_reply=getattr(message, "should_reply", False), + ) + + self.stream_context.set_current_message(db_message) + self.stream_context.priority_mode = getattr(message, "priority_mode", None) + self.stream_context.priority_info = getattr(message, "priority_info", None) + + # 调试日志:记录数据转移情况 + logger.debug(f"消息数据转移完成 - message_id: {db_message.message_id}, " + f"chat_id: {db_message.chat_id}, " + f"is_mentioned: {db_message.is_mentioned}, " + f"is_emoji: {db_message.is_emoji}, " + f"is_picid: {db_message.is_picid}, " + f"interest_value: {db_message.interest_value}") + + def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]: + """安全获取消息的actions字段""" + try: + actions = getattr(message, "actions", None) + if actions is None: + return None + + # 如果是字符串,尝试解析为JSON + if isinstance(actions, str): + try: + import json + actions = json.loads(actions) + except json.JSONDecodeError: + logger.warning(f"无法解析actions JSON字符串: {actions}") + return None + + # 确保返回列表类型 + if isinstance(actions, list): + # 过滤掉空值和非字符串元素 + filtered_actions = [action for action in actions if action is not None and isinstance(action, str)] + return filtered_actions if filtered_actions else None + else: + logger.warning(f"actions字段类型不支持: {type(actions)}") + return None + + except Exception as e: + logger.warning(f"获取actions字段失败: {e}") + return None + + def _extract_reply_from_segment(self, segment) -> Optional[str]: + """从消息段中提取reply_to信息""" + try: + if hasattr(segment, "type") and segment.type == "seglist": + # 递归搜索seglist中的reply段 + if hasattr(segment, "data") and segment.data: + for seg in segment.data: + reply_id = self._extract_reply_from_segment(seg) + if reply_id: + return reply_id + elif hasattr(segment, "type") and segment.type == "reply": + # 找到reply段,返回message_id + return str(segment.data) if segment.data else None + except Exception as e: + logger.warning(f"提取reply_to信息失败: {e}") + return None + + def _generate_chat_id(self, message_info) -> str: + """生成chat_id,基于群组或用户信息""" + try: + group_info = getattr(message_info, "group_info", None) + user_info = getattr(message_info, "user_info", None) + + if group_info and hasattr(group_info, "group_id") and group_info.group_id: + # 群聊:使用群组ID + return f"{self.platform}_{group_info.group_id}" + elif user_info and hasattr(user_info, "user_id") and user_info.user_id: + # 私聊:使用用户ID + return f"{self.platform}_{user_info.user_id}_private" + else: + # 默认:使用stream_id + return self.stream_id + except Exception as e: + logger.warning(f"生成chat_id失败: {e}") + return self.stream_id + + @property + def focus_energy(self) -> float: + """使用重构后的能量管理器计算focus_energy""" + try: + from src.chat.energy_system import energy_manager + + # 获取所有消息 + history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size) + unread_messages = self.stream_context.get_unread_messages() + all_messages = history_messages + unread_messages + + # 获取用户ID + user_id = None + if self.user_info and hasattr(self.user_info, "user_id"): + user_id = str(self.user_info.user_id) + + # 使用能量管理器计算 + energy = energy_manager.calculate_focus_energy( + stream_id=self.stream_id, + messages=all_messages, + user_id=user_id + ) + + # 更新内部存储 + self._focus_energy = energy + + logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}") + return energy + + except Exception as e: + logger.error(f"获取focus_energy失败: {e}", exc_info=True) + # 返回缓存的值或默认值 + if hasattr(self, '_focus_energy'): + return self._focus_energy + else: + return 0.5 + + @focus_energy.setter + def focus_energy(self, value: float): + """设置focus_energy值(主要用于初始化或特殊场景)""" + self._focus_energy = max(0.0, min(1.0, value)) + + def _get_user_relationship_score(self) -> float: + """获取用户关系分""" + # 使用插件内部的兴趣度评分系统 + try: + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system + + if self.user_info and hasattr(self.user_info, "user_id"): + user_id = str(self.user_info.user_id) + relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id) + logger.debug(f"ChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}") + return max(0.0, min(1.0, relationship_score)) + + except Exception as e: + logger.warning(f"ChatStream {self.stream_id}: 插件内部关系分计算失败: {e}") + + # 默认基础分 + return 0.3 + + def _load_history_messages(self): + """从数据库加载历史消息到StreamContext""" + try: + from src.common.database.sqlalchemy_models import Messages + from src.common.database.sqlalchemy_database_api import get_db_session + from src.common.data_models.database_data_model import DatabaseMessages + from sqlalchemy import select, desc + import asyncio + + async def _load_messages(): + def _db_query(): + with get_db_session() as session: + # 查询该stream_id的最近20条消息 + stmt = ( + select(Messages) + .where(Messages.chat_info_stream_id == self.stream_id) + .order_by(desc(Messages.time)) + .limit(global_config.chat.max_context_size) + ) + results = session.execute(stmt).scalars().all() + return results + + # 在线程中执行数据库查询 + db_messages = await asyncio.to_thread(_db_query) + + # 转换为DatabaseMessages对象并添加到StreamContext + for db_msg in db_messages: + try: + # 从SQLAlchemy模型转换为DatabaseMessages数据模型 + import orjson + + # 解析actions字段(JSON格式) + actions = None + if db_msg.actions: + try: + actions = orjson.loads(db_msg.actions) + except (orjson.JSONDecodeError, TypeError): + actions = None + + db_message = DatabaseMessages( + message_id=db_msg.message_id, + time=db_msg.time, + chat_id=db_msg.chat_id, + reply_to=db_msg.reply_to, + interest_value=db_msg.interest_value, + key_words=db_msg.key_words, + key_words_lite=db_msg.key_words_lite, + is_mentioned=db_msg.is_mentioned, + processed_plain_text=db_msg.processed_plain_text, + display_message=db_msg.display_message, + priority_mode=db_msg.priority_mode, + priority_info=db_msg.priority_info, + additional_config=db_msg.additional_config, + is_emoji=db_msg.is_emoji, + is_picid=db_msg.is_picid, + is_command=db_msg.is_command, + is_notify=db_msg.is_notify, + user_id=db_msg.user_id, + user_nickname=db_msg.user_nickname, + user_cardname=db_msg.user_cardname, + user_platform=db_msg.user_platform, + chat_info_group_id=db_msg.chat_info_group_id, + chat_info_group_name=db_msg.chat_info_group_name, + chat_info_group_platform=db_msg.chat_info_group_platform, + chat_info_user_id=db_msg.chat_info_user_id, + chat_info_user_nickname=db_msg.chat_info_user_nickname, + chat_info_user_cardname=db_msg.chat_info_user_cardname, + chat_info_user_platform=db_msg.chat_info_user_platform, + chat_info_stream_id=db_msg.chat_info_stream_id, + chat_info_platform=db_msg.chat_info_platform, + chat_info_create_time=db_msg.chat_info_create_time, + chat_info_last_active_time=db_msg.chat_info_last_active_time, + actions=actions, + should_reply=getattr(db_msg, "should_reply", False) or False, + ) + + # 添加调试日志:检查从数据库加载的interest_value + logger.debug(f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}") + + # 标记为已读并添加到历史消息 + db_message.is_read = True + self.stream_context.history_messages.append(db_message) + + except Exception as e: + logger.warning(f"转换消息 {db_msg.message_id} 失败: {e}") + continue + + if self.stream_context.history_messages: + logger.info( + f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}" + ) + + # 创建任务来加载历史消息 + asyncio.create_task(_load_messages()) + + except Exception as e: + logger.error(f"加载历史消息失败: {e}") class ChatManager: @@ -362,7 +661,16 @@ class ChatManager: "group_name": group_info_d["group_name"] if group_info_d else "", "energy_value": s_data_dict.get("energy_value", 5.0), "sleep_pressure": s_data_dict.get("sleep_pressure", 0.0), - "focus_energy": s_data_dict.get("focus_energy", global_config.chat.focus_value), + "focus_energy": s_data_dict.get("focus_energy", 0.5), + # 新增动态兴趣度系统字段 + "base_interest_energy": s_data_dict.get("base_interest_energy", 0.5), + "message_interest_total": s_data_dict.get("message_interest_total", 0.0), + "message_count": s_data_dict.get("message_count", 0), + "action_count": s_data_dict.get("action_count", 0), + "reply_count": s_data_dict.get("reply_count", 0), + "last_interaction_time": s_data_dict.get("last_interaction_time", time.time()), + "consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0), + "interruption_count": s_data_dict.get("interruption_count", 0), } if global_config.database.database_type == "sqlite": stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) @@ -419,7 +727,17 @@ class ChatManager: "last_active_time": model_instance.last_active_time, "energy_value": model_instance.energy_value, "sleep_pressure": model_instance.sleep_pressure, - "focus_energy": getattr(model_instance, "focus_energy", global_config.chat.focus_value), + "focus_energy": getattr(model_instance, "focus_energy", 0.5), + # 新增动态兴趣度系统字段 - 使用getattr提供默认值 + "base_interest_energy": getattr(model_instance, "base_interest_energy", 0.5), + "message_interest_total": getattr(model_instance, "message_interest_total", 0.0), + "message_count": getattr(model_instance, "message_count", 0), + "action_count": getattr(model_instance, "action_count", 0), + "reply_count": getattr(model_instance, "reply_count", 0), + "last_interaction_time": getattr(model_instance, "last_interaction_time", time.time()), + "relationship_score": getattr(model_instance, "relationship_score", 0.3), + "consecutive_no_reply": getattr(model_instance, "consecutive_no_reply", 0), + "interruption_count": getattr(model_instance, "interruption_count", 0), } loaded_streams_data.append(data_for_from_dict) await session.commit() diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 22c3e3776..f6041b7d4 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -123,7 +123,7 @@ class MessageRecv(Message): self.is_video = False self.is_mentioned = None self.is_notify = False - + self.is_at = False self.is_command = False self.priority_mode = "interest" diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 015578be8..b37301f47 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,14 +1,14 @@ import re import traceback +import orjson from typing import Union -import orjson -from sqlalchemy import select, desc, update - -from src.common.database.sqlalchemy_models import Messages, Images, get_db_session +from src.common.database.sqlalchemy_models import Messages, Images from src.common.logger import get_logger from .chat_stream import ChatStream from .message import MessageSending, MessageRecv +from src.common.database.sqlalchemy_database_api import get_db_session +from sqlalchemy import select, update, desc logger = get_logger("message_storage") @@ -41,7 +41,7 @@ class MessageStorage: processed_plain_text = message.processed_plain_text if processed_plain_text: - processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) + processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text) filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL) else: filtered_processed_plain_text = "" @@ -51,7 +51,8 @@ class MessageStorage: if display_message: filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) else: - filtered_display_message = "" + # 如果没有设置display_message,使用processed_plain_text作为显示消息 + filtered_display_message = re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL) if message.processed_plain_text else "" interest_value = 0 is_mentioned = False reply_to = message.reply_to @@ -116,14 +117,21 @@ class MessageStorage: user_nickname=user_info_dict.get("user_nickname"), user_cardname=user_info_dict.get("user_cardname"), processed_plain_text=filtered_processed_plain_text, + display_message=filtered_display_message, + memorized_times=message.memorized_times, + interest_value=interest_value, priority_mode=priority_mode, priority_info=priority_info_json, is_emoji=is_emoji, is_picid=is_picid, + is_notify=is_notify, + is_command=is_command, + key_words=key_words, + key_words_lite=key_words_lite, ) - async with get_db_session() as session: + with get_db_session() as session: session.add(new_message) - await session.commit() + session.commit() except Exception: logger.exception("存储消息失败") @@ -146,7 +154,8 @@ class MessageStorage: qq_message_id = message.message_segment.data.get("id") elif message.message_segment.type == "reply": qq_message_id = message.message_segment.data.get("id") - logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}") + if qq_message_id: + logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}") elif message.message_segment.type == "adapter_response": logger.debug("适配器响应消息,不需要更新ID") return @@ -162,19 +171,18 @@ class MessageStorage: logger.debug(f"消息段数据: {message.message_segment.data}") return - async with get_db_session() as session: - matched_message = ( - await session.execute( - select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) - ) + # 使用上下文管理器确保session正确管理 + from src.common.database.sqlalchemy_models import get_db_session + + with get_db_session() as session: + matched_message = session.execute( + select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) ).scalar() if matched_message: - await session.execute( + session.execute( update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id) ) - await session.commit() - # 会在上下文管理器中自动调用 logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") else: logger.warning(f"未找到匹配的消息记录: {mmc_message_id}") @@ -186,36 +194,117 @@ class MessageStorage: f"segment_type={getattr(message.message_segment, 'type', 'N/A')}" ) - async def replace_image_descriptions(text: str) -> str: + @staticmethod + def replace_image_descriptions(text: str) -> str: """将[图片:描述]替换为[picid:image_id]""" # 先检查文本中是否有图片标记 pattern = r"\[图片:([^\]]+)\]" - matches = list(re.finditer(pattern, text)) + matches = re.findall(pattern, text) if not matches: logger.debug("文本中没有图片标记,直接返回原文本") return text - new_text = "" - last_end = 0 - for match in matches: - new_text += text[last_end : match.start()] + def replace_match(match): description = match.group(1).strip() try: from src.common.database.sqlalchemy_models import get_db_session - async with get_db_session() as session: - image_record = ( - await session.execute( - select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) - ) + with get_db_session() as session: + image_record = session.execute( + select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) ).scalar() - if image_record: - new_text += f"[picid:{image_record.image_id}]" - else: - new_text += match.group(0) + return f"[picid:{image_record.image_id}]" if image_record else match.group(0) except Exception: - new_text += match.group(0) - last_end = match.end() - new_text += text[last_end:] - return new_text + return match.group(0) + + @staticmethod + def update_message_interest_value(message_id: str, interest_value: float) -> None: + """ + 更新数据库中消息的interest_value字段 + + Args: + message_id: 消息ID + interest_value: 兴趣度值 + """ + try: + with get_db_session() as session: + # 更新消息的interest_value字段 + stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value) + result = session.execute(stmt) + session.commit() + + if result.rowcount > 0: + logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}") + else: + logger.warning(f"未找到消息 {message_id},无法更新interest_value") + + except Exception as e: + logger.error(f"更新消息 {message_id} 的interest_value失败: {e}") + raise + + @staticmethod + def fix_zero_interest_values(chat_id: str, since_time: float) -> int: + """ + 修复指定聊天中interest_value为0或null的历史消息记录 + + Args: + chat_id: 聊天ID + since_time: 从指定时间开始修复(时间戳) + + Returns: + 修复的记录数量 + """ + try: + with get_db_session() as session: + from sqlalchemy import select, update + from src.common.database.sqlalchemy_models import Messages + + # 查找需要修复的记录:interest_value为0、null或很小的值 + query = select(Messages).where( + (Messages.chat_id == chat_id) & + (Messages.time >= since_time) & + ( + (Messages.interest_value == 0) | + (Messages.interest_value.is_(None)) | + (Messages.interest_value < 0.1) + ) + ).limit(50) # 限制每次修复的数量,避免性能问题 + + messages_to_fix = session.execute(query).scalars().all() + fixed_count = 0 + + for msg in messages_to_fix: + # 为这些消息设置一个合理的默认兴趣度 + # 可以基于消息长度、内容或其他因素计算 + default_interest = 0.3 # 默认中等兴趣度 + + # 如果消息内容较长,可能是重要消息,兴趣度稍高 + if hasattr(msg, 'processed_plain_text') and msg.processed_plain_text: + text_length = len(msg.processed_plain_text) + if text_length > 50: # 长消息 + default_interest = 0.4 + elif text_length > 20: # 中等长度消息 + default_interest = 0.35 + + # 如果是被@的消息,兴趣度更高 + if getattr(msg, 'is_mentioned', False): + default_interest = min(default_interest + 0.2, 0.8) + + # 执行更新 + update_stmt = update(Messages).where( + Messages.message_id == msg.message_id + ).values(interest_value=default_interest) + + result = session.execute(update_stmt) + if result.rowcount > 0: + fixed_count += 1 + logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}") + + session.commit() + logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值") + return fixed_count + + except Exception as e: + logger.error(f"修复历史消息interest_value失败: {e}") + return 0 diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 23755e42d..761c53e86 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -1,15 +1,24 @@ -from typing import Dict, Optional, Type +import asyncio +import traceback +import time +from typing import Dict, Optional, Type, Any, Tuple -from src.chat.message_receive.chat_stream import ChatStream + +from src.chat.utils.timer_calculator import Timer +from src.person_info.person_info import get_person_info_manager +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.common.logger import get_logger +from src.config.config import global_config from src.plugin_system.core.component_registry import component_registry from src.plugin_system.base.component_types import ComponentType, ActionInfo from src.plugin_system.base.base_action import BaseAction +from src.plugin_system.apis import generator_api, database_api, send_api, message_api + logger = get_logger("action_manager") -class ActionManager: +class ChatterActionManager: """ 动作管理器,用于管理各种类型的动作 @@ -25,6 +34,8 @@ class ActionManager: # 初始化时将默认动作加载到使用中的动作 self._using_actions = component_registry.get_default_actions() + self.log_prefix: str = "ChatterActionManager" + # === 执行Action方法 === @staticmethod @@ -124,3 +135,417 @@ class ActionManager: actions_to_restore = list(self._using_actions.keys()) self._using_actions = component_registry.get_default_actions() logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") + + async def execute_action( + self, + action_name: str, + chat_id: str, + target_message: Optional[dict] = None, + reasoning: str = "", + action_data: Optional[dict] = None, + thinking_id: Optional[str] = None, + log_prefix: str = "", + ) -> Any: + """ + 执行单个动作的通用函数 + + Args: + action_name: 动作名称 + chat_id: 聊天id + target_message: 目标消息 + reasoning: 执行理由 + action_data: 动作数据 + thinking_id: 思考ID + log_prefix: 日志前缀 + + Returns: + 执行结果 + """ + from src.chat.message_manager.message_manager import message_manager + try: + logger.debug(f"🎯 [ActionManager] execute_action接收到 target_message: {target_message}") + # 通过chat_id获取chat_stream + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(chat_id) + + if not chat_stream: + logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}") + return { + "action_type": action_name, + "success": False, + "reply_text": "", + "error": "chat_stream not found", + } + + if action_name == "no_action": + return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""} + + if action_name == "no_reply": + # 直接处理no_reply逻辑,不再通过动作系统 + reason = reasoning or "选择不回复" + logger.info(f"{log_prefix} 选择不回复,原因: {reason}") + + # 存储no_reply信息到数据库 + await database_api.store_action_info( + chat_stream=chat_stream, + action_build_into_prompt=False, + action_prompt_display=reason, + action_done=True, + thinking_id=thinking_id, + action_data={"reason": reason}, + action_name="no_reply", + ) + return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""} + + elif action_name != "reply" and action_name != "no_action": + # 执行普通动作 + success, reply_text, command = await self._handle_action( + chat_stream, + action_name, + reasoning, + action_data or {}, + {}, # cycle_timers + thinking_id, + target_message, + ) + + # 记录执行的动作到目标消息 + if success: + await self._record_action_to_message(chat_stream, action_name, target_message, action_data) + # 重置打断计数 + await self._reset_interruption_count_after_action(chat_stream.stream_id) + + return { + "action_type": action_name, + "success": success, + "reply_text": reply_text, + "command": command, + } + else: + # 生成回复 + try: + success, response_set, _ = await generator_api.generate_reply( + chat_stream=chat_stream, + reply_message=target_message, + action_data=action_data or {}, + available_actions=self.get_using_actions(), + enable_tool=global_config.tool.enable_tool, + request_type="chat.replyer", + from_plugin=False, + ) + if not success or not response_set: + logger.info( + f"对 {target_message.get('processed_plain_text') if target_message else '未知消息'} 的回复生成失败" + ) + return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} + except asyncio.CancelledError: + logger.debug(f"{log_prefix} 并行执行:回复生成任务已被取消") + return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} + + # 发送并存储回复 + loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply( + chat_stream, + response_set, + asyncio.get_event_loop().time(), + target_message, + {}, # cycle_timers + thinking_id, + [], # actions + ) + + # 记录回复动作到目标消息 + await self._record_action_to_message(chat_stream, "reply", target_message, action_data) + + # 回复成功,重置打断计数 + await self._reset_interruption_count_after_action(chat_stream.stream_id) + + return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info} + + except Exception as e: + logger.error(f"{log_prefix} 执行动作时出错: {e}") + logger.error(f"{log_prefix} 错误信息: {traceback.format_exc()}") + return { + "action_type": action_name, + "success": False, + "reply_text": "", + "loop_info": None, + "error": str(e), + } + + async def _record_action_to_message(self, chat_stream, action_name, target_message, action_data): + """ + 记录执行的动作到目标消息中 + + Args: + chat_stream: ChatStream实例 + action_name: 动作名称 + target_message: 目标消息 + action_data: 动作数据 + """ + try: + from src.chat.message_manager.message_manager import message_manager + + # 获取目标消息ID + target_message_id = None + if target_message and isinstance(target_message, dict): + target_message_id = target_message.get("message_id") + elif action_data and isinstance(action_data, dict): + target_message_id = action_data.get("target_message_id") + + if not target_message_id: + logger.debug(f"无法获取目标消息ID,动作: {action_name}") + return + + # 通过message_manager更新消息的动作记录并刷新focus_energy + if chat_stream.stream_id in message_manager.stream_contexts: + message_manager.add_action( + stream_id=chat_stream.stream_id, + message_id=target_message_id, + action=action_name + ) + logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy") + else: + logger.debug(f"未找到stream_context: {chat_stream.stream_id}") + + except Exception as e: + logger.error(f"记录动作到消息失败: {e}") + # 不抛出异常,避免影响主要功能 + + async def _reset_interruption_count_after_action(self, stream_id: str): + """在动作执行成功后重置打断计数""" + from src.chat.message_manager.message_manager import message_manager + try: + if stream_id in message_manager.stream_contexts: + context = message_manager.stream_contexts[stream_id] + if context.interruption_count > 0: + old_count = context.interruption_count + old_afc_adjustment = context.get_afc_threshold_adjustment() + context.reset_interruption_count() + logger.debug(f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0") + except Exception as e: + logger.warning(f"重置打断计数时出错: {e}") + + async def _handle_action( + self, chat_stream, action, reasoning, action_data, cycle_timers, thinking_id, action_message + ) -> tuple[bool, str, str]: + """ + 处理具体的动作执行 + + Args: + chat_stream: ChatStream实例 + action: 动作名称 + reasoning: 执行理由 + action_data: 动作数据 + cycle_timers: 循环计时器 + thinking_id: 思考ID + action_message: 动作消息 + + Returns: + tuple: (执行是否成功, 回复文本, 命令文本) + + 功能说明: + - 创建对应的动作处理器 + - 执行动作并捕获异常 + - 返回执行结果供上级方法整合 + """ + if not chat_stream: + return False, "", "" + try: + # 创建动作处理器 + action_handler = self.create_action( + action_name=action, + action_data=action_data, + reasoning=reasoning, + cycle_timers=cycle_timers, + thinking_id=thinking_id, + chat_stream=chat_stream, + log_prefix=self.log_prefix, + action_message=action_message, + ) + if not action_handler: + # 动作处理器创建失败,尝试回退机制 + logger.warning(f"{self.log_prefix} 创建动作处理器失败: {action},尝试回退方案") + + # 获取当前可用的动作 + available_actions = self.get_using_actions() + fallback_action = None + + # 回退优先级:reply > 第一个可用动作 + if "reply" in available_actions: + fallback_action = "reply" + elif available_actions: + fallback_action = list(available_actions.keys())[0] + + if fallback_action and fallback_action != action: + logger.info(f"{self.log_prefix} 使用回退动作: {fallback_action}") + action_handler = self.create_action( + action_name=fallback_action, + action_data=action_data, + reasoning=f"原动作'{action}'不可用,自动回退。{reasoning}", + cycle_timers=cycle_timers, + thinking_id=thinking_id, + chat_stream=chat_stream, + log_prefix=self.log_prefix, + action_message=action_message, + ) + + if not action_handler: + logger.error(f"{self.log_prefix} 回退方案也失败,无法创建任何动作处理器") + return False, "", "" + + # 执行动作 + success, reply_text = await action_handler.handle_action() + return success, reply_text, "" + except Exception as e: + logger.error(f"{self.log_prefix} 处理{action}时出错: {e}") + traceback.print_exc() + return False, "", "" + + async def _send_and_store_reply( + self, + chat_stream: ChatStream, + response_set, + loop_start_time, + action_message, + cycle_timers: Dict[str, float], + thinking_id, + actions, + ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: + """ + 发送并存储回复信息 + + Args: + chat_stream: ChatStream实例 + response_set: 回复内容集合 + loop_start_time: 循环开始时间 + action_message: 动作消息 + cycle_timers: 循环计时器 + thinking_id: 思考ID + actions: 动作列表 + + Returns: + Tuple[Dict[str, Any], str, Dict[str, float]]: 循环信息, 回复文本, 循环计时器 + """ + # 发送回复 + with Timer("回复发送", cycle_timers): + reply_text = await self.send_response(chat_stream, response_set, loop_start_time, action_message) + + # 存储reply action信息 + person_info_manager = get_person_info_manager() + + # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 + platform = action_message.get("chat_info_platform") + if platform is None: + platform = getattr(chat_stream, "platform", "unknown") + + # 获取用户信息并生成回复提示 + person_id = person_info_manager.get_person_id( + platform, + action_message.get("user_id", ""), + ) + person_name = await person_info_manager.get_value(person_id, "person_name") + action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" + + # 存储动作信息到数据库 + await database_api.store_action_info( + chat_stream=chat_stream, + action_build_into_prompt=False, + action_prompt_display=action_prompt_display, + action_done=True, + thinking_id=thinking_id, + action_data={"reply_text": reply_text}, + action_name="reply", + ) + + # 构建循环信息 + loop_info: Dict[str, Any] = { + "loop_plan_info": { + "action_result": actions, + }, + "loop_action_info": { + "action_taken": True, + "reply_text": reply_text, + "command": "", + "taken_time": time.time(), + }, + } + + return loop_info, reply_text, cycle_timers + + async def send_response(self, chat_stream, reply_set, thinking_start_time, message_data) -> str: + """ + 发送回复内容的具体实现 + + Args: + chat_stream: ChatStream实例 + reply_set: 回复内容集合,包含多个回复段 + reply_to: 回复目标 + thinking_start_time: 思考开始时间 + message_data: 消息数据 + + Returns: + str: 完整的回复文本 + + 功能说明: + - 检查是否有新消息需要回复 + - 处理主动思考的"沉默"决定 + - 根据消息数量决定是否添加回复引用 + - 逐段发送回复内容,支持打字效果 + - 正确处理元组格式的回复段 + """ + current_time = time.time() + # 计算新消息数量 + new_message_count = message_api.count_new_messages( + chat_id=chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time + ) + + # 根据新消息数量决定是否需要引用回复 + reply_text = "" + is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True + + logger.debug(f"[send_response] message_data: {message_data}") + + first_replied = False + for reply_seg in reply_set: + # 调试日志:验证reply_seg的格式 + logger.debug(f"Processing reply_seg type: {type(reply_seg)}, content: {reply_seg}") + + # 修正:正确处理元组格式 (格式为: (type, content)) + if isinstance(reply_seg, tuple) and len(reply_seg) >= 2: + _, data = reply_seg + else: + # 向下兼容:如果已经是字符串,则直接使用 + data = str(reply_seg) + + if isinstance(data, list): + data = "".join(map(str, data)) + reply_text += data + + # 如果是主动思考且内容为"沉默",则不发送 + if is_proactive_thinking and data.strip() == "沉默": + logger.info(f"{self.log_prefix} 主动思考决定保持沉默,不发送消息") + continue + + # 发送第一段回复 + if not first_replied: + set_reply_flag = bool(message_data) + logger.debug(f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}") + await send_api.text_to_stream( + text=data, + stream_id=chat_stream.stream_id, + reply_to_message=message_data, + set_reply=set_reply_flag, + typing=False, + ) + first_replied = True + else: + # 发送后续回复 + sent_message = await send_api.text_to_stream( + text=data, + stream_id=chat_stream.stream_id, + reply_to_message=None, + set_reply=False, + typing=True, + ) + + return reply_text \ No newline at end of file diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index bcd01934d..481b848e2 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -7,8 +7,9 @@ from typing import List, Any, Dict, TYPE_CHECKING, Tuple from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext -from src.chat.planner_actions.action_manager import ActionManager +from src.chat.message_receive.chat_stream import get_chat_manager +from src.common.data_models.message_manager_data_model import StreamContext +from src.chat.planner_actions.action_manager import ChatterActionManager from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages from src.plugin_system.base.component_types import ActionInfo, ActionActivationType from src.plugin_system.core.global_announcement_manager import global_announcement_manager @@ -27,7 +28,7 @@ class ActionModifier: 支持并行判定和智能缓存优化。 """ - def __init__(self, action_manager: ActionManager, chat_id: str): + def __init__(self, action_manager: ChatterActionManager, chat_id: str): """初始化动作处理器""" self.chat_id = chat_id self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore @@ -124,8 +125,9 @@ class ActionModifier: logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用") # === 第二阶段:检查动作的关联类型 === - chat_context = self.chat_stream.context - type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context) + chat_context = self.chat_stream.stream_context + current_actions_s2 = self.action_manager.get_using_actions() + type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context) if type_mismatched_actions: removals_s2.extend(type_mismatched_actions) @@ -140,11 +142,12 @@ class ActionModifier: logger.debug(f"{self.log_prefix}开始激活类型判定阶段") # 获取当前使用的动作集(经过第一阶段处理) - current_using_actions = self.action_manager.get_using_actions() + # 在第三阶段开始前,再次获取最新的动作列表 + current_actions_s3 = self.action_manager.get_using_actions() # 获取因激活类型判定而需要移除的动作 removals_s3 = await self._get_deactivated_actions_by_type( - current_using_actions, + current_actions_s3, chat_content, ) @@ -164,7 +167,7 @@ class ActionModifier: logger.info(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}") - def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext): + def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: StreamContext): type_mismatched_actions: List[Tuple[str, str]] = [] for action_name, action_info in all_actions.items(): if action_info.associated_types and not chat_context.check_types(action_info.associated_types): diff --git a/src/chat/planner_actions/plan_executor.py b/src/chat/planner_actions/plan_executor.py deleted file mode 100644 index 591389f99..000000000 --- a/src/chat/planner_actions/plan_executor.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -PlanExecutor: 接收 Plan 对象并执行其中的所有动作。 -""" -from src.chat.planner_actions.action_manager import ActionManager -from src.common.data_models.info_data_model import Plan -from src.common.logger import get_logger - -logger = get_logger("plan_executor") - - -class PlanExecutor: - """ - 负责接收一个 Plan 对象,并执行其中最终确定的所有动作。 - - 这个类是规划流程的最后一步,将规划结果转化为实际的动作执行。 - - Attributes: - action_manager (ActionManager): 用于实际执行各种动作的管理器实例。 - """ - - def __init__(self, action_manager: ActionManager): - """ - 初始化 PlanExecutor。 - - Args: - action_manager (ActionManager): 一个 ActionManager 实例,用于执行动作。 - """ - self.action_manager = action_manager - - @staticmethod - async def execute(plan: Plan): - """ - 遍历并执行 Plan 对象中 `decided_actions` 列表里的所有动作。 - - 如果动作类型为 "no_action",则会记录原因并跳过。 - 否则,它将调用 ActionManager 来执行相应的动作。 - - Args: - plan (Plan): 包含待执行动作列表的 Plan 对象。 - """ - if not plan.decided_actions: - logger.info("没有需要执行的动作。") - return - - for action_info in plan.decided_actions: - if action_info.action_type == "no_action": - logger.info(f"规划器决策不执行动作,原因: {action_info.reasoning}") - continue - - # TODO: 对接 ActionManager 的执行方法 - # 这是一个示例调用,需要根据 ActionManager 的最终实现进行调整 - logger.info(f"执行动作: {action_info.action_type}, 原因: {action_info.reasoning}") - # await self.action_manager.execute_action( - # action_name=action_info.action_type, - # action_data=action_info.action_data, - # reasoning=action_info.reasoning, - # action_message=action_info.action_message, - # ) diff --git a/src/chat/planner_actions/plan_filter.py b/src/chat/planner_actions/plan_filter.py deleted file mode 100644 index 6aaefba18..000000000 --- a/src/chat/planner_actions/plan_filter.py +++ /dev/null @@ -1,366 +0,0 @@ -""" -PlanFilter: 接收 Plan 对象,根据不同模式的逻辑进行筛选,决定最终要执行的动作。 -""" -import orjson -import time -import traceback -from datetime import datetime -from typing import Any, Dict, List, Optional - -from json_repair import repair_json - -from src.chat.memory_system.Hippocampus import hippocampus_manager -from src.chat.utils.chat_message_builder import ( - build_readable_actions, - build_readable_messages_with_id, - get_actions_by_timestamp_with_chat, -) -from src.chat.utils.prompt import global_prompt_manager -from src.common.data_models.info_data_model import ActionPlannerInfo, Plan -from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest -from src.mood.mood_manager import mood_manager -from src.plugin_system.base.component_types import ActionInfo, ChatMode -from src.schedule.schedule_manager import schedule_manager - -logger = get_logger("plan_filter") - - -class PlanFilter: - """ - 根据 Plan 中的模式和信息,筛选并决定最终的动作。 - """ - - def __init__(self): - self.planner_llm = LLMRequest( - model_set=model_config.model_task_config.planner, request_type="planner" - ) - self.last_obs_time_mark = 0.0 - - async def filter(self, plan: Plan) -> Plan: - """ - 执行筛选逻辑,并填充 Plan 对象的 decided_actions 字段。 - """ - logger.debug(f"墨墨在这里加了日志 -> filter 入口 plan: {plan}") - try: - prompt, used_message_id_list = await self._build_prompt(plan) - plan.llm_prompt = prompt - logger.info(f"规划器原始提示词: {prompt}") - - llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt) - - if llm_content: - logger.info(f"规划器原始返回: {llm_content}") - parsed_json = orjson.loads(repair_json(llm_content)) - logger.debug(f"墨墨在这里加了日志 -> 解析后的 JSON: {parsed_json}") - - if isinstance(parsed_json, dict): - parsed_json = [parsed_json] - - if isinstance(parsed_json, list): - final_actions = [] - reply_action_added = False - # 定义回复类动作的集合,方便扩展 - reply_action_types = {"reply", "proactive_reply"} - - for item in parsed_json: - if not isinstance(item, dict): - continue - - # 预解析 action_type 来进行判断 - action_type = item.get("action", "no_action") - - if action_type in reply_action_types: - if not reply_action_added: - final_actions.extend( - await self._parse_single_action( - item, used_message_id_list, plan - ) - ) - reply_action_added = True - else: - # 非回复类动作直接添加 - final_actions.extend( - await self._parse_single_action( - item, used_message_id_list, plan - ) - ) - - plan.decided_actions = self._filter_no_actions(final_actions) - - except Exception as e: - logger.error(f"筛选 Plan 时出错: {e}\n{traceback.format_exc()}") - plan.decided_actions = [ - ActionPlannerInfo(action_type="no_action", reasoning=f"筛选时出错: {e}") - ] - - logger.debug(f"墨墨在这里加了日志 -> filter 出口 decided_actions: {plan.decided_actions}") - return plan - - async def _build_prompt(self, plan: Plan) -> tuple[str, list]: - """ - 根据 Plan 对象构建提示词。 - """ - try: - time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - bot_name = global_config.bot.nickname - bot_nickname = ( - f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else "" - ) - bot_core_personality = global_config.personality.personality_core - identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:" - - schedule_block = "" - if global_config.planning_system.schedule_enable: - if current_activity := schedule_manager.get_current_activity(): - schedule_block = f"你当前正在:{current_activity},但注意它与群聊的聊天无关。" - - mood_block = "" - if global_config.mood.enable_mood: - chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id) - mood_block = f"你现在的心情是:{chat_mood.mood_state}" - - if plan.mode == ChatMode.PROACTIVE: - long_term_memory_block = await self._get_long_term_memory_context() - - chat_content_block, message_id_list = await build_readable_messages_with_id( - messages=[msg.flatten() for msg in plan.chat_history], - timestamp_mode="normal", - truncate=False, - show_actions=False, - ) - - prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt") - actions_before_now = await get_actions_by_timestamp_with_chat( - chat_id=plan.chat_id, - timestamp_start=time.time() - 3600, - timestamp_end=time.time(), - limit=5, - ) - actions_before_now_block = build_readable_actions(actions=actions_before_now) - actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}" - - prompt = prompt_template.format( - time_block=time_block, - identity_block=identity_block, - schedule_block=schedule_block, - mood_block=mood_block, - long_term_memory_block=long_term_memory_block, - chat_content_block=chat_content_block or "最近没有聊天内容。", - actions_before_now_block=actions_before_now_block, - ) - return prompt, message_id_list - - chat_content_block, message_id_list = await build_readable_messages_with_id( - messages=[msg.flatten() for msg in plan.chat_history], - timestamp_mode="normal", - read_mark=self.last_obs_time_mark, - truncate=True, - show_actions=True, - ) - - actions_before_now = await get_actions_by_timestamp_with_chat( - chat_id=plan.chat_id, - timestamp_start=time.time() - 3600, - timestamp_end=time.time(), - limit=5, - ) - - actions_before_now_block = build_readable_actions(actions=actions_before_now) - actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}" - - self.last_obs_time_mark = time.time() - - mentioned_bonus = "" - if global_config.chat.mentioned_bot_inevitable_reply: - mentioned_bonus = "\n- 有人提到你" - if global_config.chat.at_bot_inevitable_reply: - mentioned_bonus = "\n- 有人提到你,或者at你" - - if plan.mode == ChatMode.FOCUS: - no_action_block = """ -动作:no_action -动作描述:不选择任何动作 -{{ - "action": "no_action", - "reason":"不动作的原因" -}} - -动作:no_reply -动作描述:不进行回复,等待合适的回复时机 -- 当你刚刚发送了消息,没有人回复时,选择no_reply -- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply -{{ - "action": "no_reply", - "reason":"不回复的原因" -}} -""" - else: # NORMAL Mode - no_action_block = """重要说明: -- 'reply' 表示只进行普通聊天回复,不执行任何额外动作 -- 其他action表示在普通回复的基础上,执行相应的额外动作 -{{ - "action": "reply", - "target_message_id":"触发action的消息id", - "reason":"回复的原因" -}}""" - - is_group_chat = plan.target_info.platform == "group" if plan.target_info else True - chat_context_description = "你现在正在一个群聊中" - if not is_group_chat and plan.target_info: - chat_target_name = plan.target_info.person_name or plan.target_info.user_nickname or "对方" - chat_context_description = f"你正在和 {chat_target_name} 私聊" - - action_options_block = await self._build_action_options(plan.available_actions) - - moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。" - - custom_prompt_block = "" - if global_config.custom_prompt.planner_custom_prompt_content: - custom_prompt_block = global_config.custom_prompt.planner_custom_prompt_content - - users_in_chat_str = "" # TODO: Re-implement user list fetching if needed - - planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") - prompt = planner_prompt_template.format( - schedule_block=schedule_block, - mood_block=mood_block, - time_block=time_block, - chat_context_description=chat_context_description, - chat_content_block=chat_content_block, - actions_before_now_block=actions_before_now_block, - mentioned_bonus=mentioned_bonus, - no_action_block=no_action_block, - action_options_text=action_options_block, - moderation_prompt=moderation_prompt_block, - identity_block=identity_block, - custom_prompt_block=custom_prompt_block, - bot_name=bot_name, - users_in_chat=users_in_chat_str - ) - return prompt, message_id_list - except Exception as e: - logger.error(f"构建 Planner 提示词时出错: {e}") - logger.error(traceback.format_exc()) - return "构建 Planner Prompt 时出错", [] - - async def _parse_single_action( - self, action_json: dict, message_id_list: list, plan: Plan - ) -> List[ActionPlannerInfo]: - parsed_actions = [] - try: - action = action_json.get("action", "no_action") - reasoning = action_json.get("reason", "未提供原因") - action_data = {k: v for k, v in action_json.items() if k not in ["action", "reason"]} - - target_message_obj = None - if action not in ["no_action", "no_reply", "do_nothing", "proactive_reply"]: - if target_message_id := action_json.get("target_message_id"): - target_message_dict = self._find_message_by_id(target_message_id, message_id_list) - else: - # 如果LLM没有指定target_message_id,我们就默认选择最新的一条消息 - target_message_dict = self._get_latest_message(message_id_list) - - if target_message_dict: - # 直接使用字典作为action_message,避免DatabaseMessages对象创建失败 - target_message_obj = target_message_dict - else: - # 如果找不到目标消息,对于reply动作来说这是必需的,应该记录警告 - if action == "reply": - logger.warning(f"reply动作找不到目标消息,target_message_id: {action_json.get('target_message_id')}") - # 将reply动作改为no_action,避免后续执行时出错 - action = "no_action" - reasoning = f"找不到目标消息进行回复。原始理由: {reasoning}" - - available_action_names = list(plan.available_actions.keys()) - if action not in ["no_action", "no_reply", "reply", "do_nothing", "proactive_reply"] and action not in available_action_names: - reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}" - action = "no_action" - - parsed_actions.append( - ActionPlannerInfo( - action_type=action, - reasoning=reasoning, - action_data=action_data, - action_message=target_message_obj, - available_actions=plan.available_actions, - ) - ) - except Exception as e: - logger.error(f"解析单个action时出错: {e}") - parsed_actions.append( - ActionPlannerInfo( - action_type="no_action", - reasoning=f"解析action时出错: {e}", - ) - ) - return parsed_actions - - @staticmethod - def _filter_no_actions( - action_list: List[ActionPlannerInfo] - ) -> List[ActionPlannerInfo]: - non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]] - if non_no_actions: - return non_no_actions - return action_list[:1] if action_list else [] - - @staticmethod - async def _get_long_term_memory_context() -> str: - try: - now = datetime.now() - keywords = ["今天", "日程", "计划"] - if 5 <= now.hour < 12: - keywords.append("早上") - elif 12 <= now.hour < 18: - keywords.append("中午") - else: - keywords.append("晚上") - - retrieved_memories = await hippocampus_manager.get_memory_from_topic( - valid_keywords=keywords, max_memory_num=5, max_memory_length=1 - ) - - if not retrieved_memories: - return "最近没有什么特别的记忆。" - - memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories] - return " ".join(memory_statements) - except Exception as e: - logger.error(f"获取长期记忆时出错: {e}") - return "回忆时出现了一些问题。" - - @staticmethod - async def _build_action_options(current_available_actions: Dict[str, ActionInfo]) -> str: - action_options_block = "" - for action_name, action_info in current_available_actions.items(): - param_text = "" - if action_info.action_parameters: - param_text = "\n" + "\n".join( - f' "{p_name}":"{p_desc}"' for p_name, p_desc in action_info.action_parameters.items() - ) - require_text = "\n".join(f"- {req}" for req in action_info.action_require) - using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt") - action_options_block += using_action_prompt.format( - action_name=action_name, - action_description=action_info.description, - action_parameters=param_text, - action_require=require_text, - ) - return action_options_block - - @staticmethod - def _find_message_by_id(message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]: - if message_id.isdigit(): - message_id = f"m{message_id}" - for item in message_id_list: - if item.get("id") == message_id: - return item.get("message") - return None - - @staticmethod - def _get_latest_message(message_id_list: list) -> Optional[Dict[str, Any]]: - if not message_id_list: - return None - return message_id_list[-1].get("message") diff --git a/src/chat/planner_actions/plan_generator.py b/src/chat/planner_actions/plan_generator.py deleted file mode 100644 index ec0a11691..000000000 --- a/src/chat/planner_actions/plan_generator.py +++ /dev/null @@ -1,110 +0,0 @@ -""" -PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的“原始计划” (Plan)。 -""" -import time -from typing import Dict - -from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat -from src.chat.utils.utils import get_chat_type_and_target_info -from src.common.data_models.database_data_model import DatabaseMessages -from src.common.data_models.info_data_model import Plan, TargetPersonInfo -from src.config.config import global_config -from src.plugin_system.base.component_types import ActionInfo, ChatMode, ComponentType -from src.plugin_system.core.component_registry import component_registry - - -class PlanGenerator: - """ - PlanGenerator 负责在规划流程的初始阶段收集所有必要信息。 - - 它会汇总以下信息来构建一个“原始”的 Plan 对象,该对象后续会由 PlanFilter 进行筛选: - - 当前聊天信息 (ID, 目标用户) - - 当前可用的动作列表 - - 最近的聊天历史记录 - - Attributes: - chat_id (str): 当前聊天的唯一标识符。 - action_manager (ActionManager): 用于获取可用动作列表的管理器。 - """ - - def __init__(self, chat_id: str): - """ - 初始化 PlanGenerator。 - - Args: - chat_id (str): 当前聊天的 ID。 - """ - from src.chat.planner_actions.action_manager import ActionManager - self.chat_id = chat_id - # 注意:ActionManager 可能需要根据实际情况初始化 - self.action_manager = ActionManager() - - async def generate(self, mode: ChatMode) -> Plan: - """ - 收集所有信息,生成并返回一个初始的 Plan 对象。 - - 这个 Plan 对象包含了决策所需的所有上下文信息。 - - Args: - mode (ChatMode): 当前的聊天模式。 - - Returns: - Plan: 一个填充了初始上下文信息的 Plan 对象。 - """ - _is_group_chat, chat_target_info_dict = await get_chat_type_and_target_info(self.chat_id) - - target_info = None - if chat_target_info_dict: - target_info = TargetPersonInfo(**chat_target_info_dict) - - available_actions = self._get_available_actions() - chat_history_raw = get_raw_msg_before_timestamp_with_chat( - chat_id=self.chat_id, - timestamp=time.time(), - limit=int(global_config.chat.max_context_size), - ) - chat_history = [DatabaseMessages(**msg) for msg in await chat_history_raw] - - - plan = Plan( - chat_id=self.chat_id, - mode=mode, - available_actions=available_actions, - chat_history=chat_history, - target_info=target_info, - ) - return plan - - def _get_available_actions(self) -> Dict[str, "ActionInfo"]: - """ - 从 ActionManager 和组件注册表中获取当前所有可用的动作。 - - 它会合并已注册的动作和系统级动作(如 "no_reply"), - 并以字典形式返回。 - - Returns: - Dict[str, "ActionInfo"]: 一个字典,键是动作名称,值是 ActionInfo 对象。 - """ - current_available_actions_dict = self.action_manager.get_using_actions() - all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore - ComponentType.ACTION - ) - - current_available_actions = {} - for action_name in current_available_actions_dict: - if action_name in all_registered_actions: - current_available_actions[action_name] = all_registered_actions[action_name] - - no_reply_info = ActionInfo( - name="no_reply", - component_type=ComponentType.ACTION, - description="系统级动作:选择不回复消息的决策", - action_parameters={}, - activation_keywords=[], - plugin_name="SYSTEM", - enabled=True, - parallel_action=False, - ) - current_available_actions["no_reply"] = no_reply_info - - return current_available_actions \ No newline at end of file diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py deleted file mode 100644 index 0e3d1afc3..000000000 --- a/src/chat/planner_actions/planner.py +++ /dev/null @@ -1,94 +0,0 @@ -""" -主规划器入口,负责协调 PlanGenerator, PlanFilter, 和 PlanExecutor。 -""" -from dataclasses import asdict -from typing import Dict, List, Optional, Tuple - -from src.chat.planner_actions.action_manager import ActionManager -from src.chat.planner_actions.plan_executor import PlanExecutor -from src.chat.planner_actions.plan_filter import PlanFilter -from src.chat.planner_actions.plan_generator import PlanGenerator -from src.common.logger import get_logger -from src.plugin_system.base.component_types import ChatMode -import src.chat.planner_actions.planner_prompts #noga # noqa: F401 -# 导入提示词模块以确保其被初始化 - -logger = get_logger("planner") - - -class ActionPlanner: - """ - ActionPlanner 是规划系统的核心协调器。 - - 它负责整合规划流程的三个主要阶段: - 1. **生成 (Generate)**: 使用 PlanGenerator 创建一个初始的行动计划。 - 2. **筛选 (Filter)**: 使用 PlanFilter 对生成的计划进行审查和优化。 - 3. **执行 (Execute)**: 使用 PlanExecutor 执行最终确定的行动。 - - Attributes: - chat_id (str): 当前聊天的唯一标识符。 - action_manager (ActionManager): 用于执行具体动作的管理器。 - generator (PlanGenerator): 负责生成初始计划。 - filter (PlanFilter): 负责筛选和优化计划。 - executor (PlanExecutor): 负责执行最终计划。 - """ - - def __init__(self, chat_id: str, action_manager: ActionManager): - """ - 初始化 ActionPlanner。 - - Args: - chat_id (str): 当前聊天的 ID。 - action_manager (ActionManager): 一个 ActionManager 实例。 - """ - self.chat_id = chat_id - self.action_manager = action_manager - self.generator = PlanGenerator(chat_id) - self.filter = PlanFilter() - self.executor = PlanExecutor(action_manager) - - async def plan( - self, mode: ChatMode = ChatMode.FOCUS - ) -> Tuple[List[Dict], Optional[Dict]]: - """ - 执行从生成到执行的完整规划流程。 - - 这个方法按顺序协调生成、筛选和执行三个阶段。 - - Args: - mode (ChatMode): 当前的聊天模式,默认为 FOCUS。 - - Returns: - Tuple[List[Dict], Optional[Dict]]: 一个元组,包含: - - final_actions_dict (List[Dict]): 最终确定的动作列表(字典格式)。 - - final_target_message_dict (Optional[Dict]): 最终的目标消息(字典格式),如果没有则为 None。 - 这与旧版 planner 的返回值保持兼容。 - """ - # 1. 生成初始 Plan - initial_plan = await self.generator.generate(mode) - - # 2. 筛选 Plan - filtered_plan = await self.filter.filter(initial_plan) - - # 3. 执行 Plan(临时引爆因为它暂时还跑不了) - #await self.executor.execute(filtered_plan) - - # 4. 返回结果 (与旧版 planner 的返回值保持兼容) - final_actions = filtered_plan.decided_actions or [] - final_target_message = next( - (act.action_message for act in final_actions if act.action_message), None - ) - - final_actions_dict = [asdict(act) for act in final_actions] - # action_message现在可能是字典而不是dataclass实例,需要特殊处理 - if final_target_message: - if hasattr(final_target_message, '__dataclass_fields__'): - # 如果是dataclass实例,使用asdict转换 - final_target_message_dict = asdict(final_target_message) - else: - # 如果已经是字典,直接使用 - final_target_message_dict = final_target_message - else: - final_target_message_dict = None - - return final_actions_dict, final_target_message_dict diff --git a/src/chat/planner_actions/planner_prompts.py b/src/chat/planner_actions/planner_prompts.py deleted file mode 100644 index d527655c8..000000000 --- a/src/chat/planner_actions/planner_prompts.py +++ /dev/null @@ -1,202 +0,0 @@ -""" -本文件集中管理所有与规划器(Planner)相关的提示词(Prompt)模板。 - -通过将提示词与代码逻辑分离,可以更方便地对模型的行为进行迭代和优化, -而无需修改核心代码。 -""" -from src.chat.utils.prompt import Prompt - - -def init_prompts(): - """ - 初始化并向 Prompt 注册系统注册所有规划器相关的提示词。 - - 这个函数会在模块加载时自动调用,确保所有提示词在系统启动时都已准备就绪。 - """ - # 核心规划器提示词,用于在接收到新消息时决定如何回应。 - # 它构建了一个复杂的上下文,包括历史记录、可用动作、角色设定等, - # 并要求模型以 JSON 格式输出一个或多个动作组合。 - Prompt( - """ -{mood_block} -{time_block} -{identity_block} - -{users_in_chat} -{custom_prompt_block} -{chat_context_description},以下是具体的聊天内容。 -{chat_content_block} - -{moderation_prompt} - -**任务: 构建一个完整的响应** -你的任务是根据当前的聊天内容,构建一个完整的、人性化的响应。一个完整的响应由两部分组成: -1. **主要动作**: 这是响应的核心,通常是 `reply`(文本回复)。 -2. **辅助动作 (可选)**: 这是为了增强表达效果的附加动作,例如 `emoji`(发送表情包)或 `poke_user`(戳一戳)。 - -**决策流程:** -1. **最高优先级检查**: 首先,检查是否有由 **关键词** 或 **LLM判断** 激活的特定动作(除了通用的 `reply`, `emoji` 等)。这些动作代表了用户的明确意图。 -2. **执行明确意图**: 如果存在这类特定动作,你 **必须** 优先选择它作为主要响应。这比常规的文本回复 (`reply`) 更重要。 -3. **常规回复**: 如果没有被特定意图激活的动作,再决定是否要进行 `reply`。 -4. **辅助动作**: 在确定了主要动作后(无论是特定动作还是 `reply`),再评估是否需要 `emoji` 或 `poke_user` 等辅助动作来增强表达效果。 -5. **互斥原则**: 当你选择了一个由明确意图激活的特定动作(如 `set_reminder`)时,你 **绝不能** 再选择 `reply` 动作,因为特定动作的执行结果(例如,设置提醒后的确认消息)本身就是一种回复。这是必须遵守的规则。 - -**重要概念:将“理由”作为“内心思考”的体现** -`reason` 字段是本次决策的核心。它并非一个简单的“理由”,而是 **一个模拟人类在回应前,头脑中自然浮现的、未经修饰的思绪流**。你需要完全代入 {identity_block} 的角色,将那一刻的想法自然地记录下来。 - -**内心思考的要点:** -* **自然流露**: 不要使用“决定”、“所以”、“因此”等结论性或汇报式的词语。你的思考应该像日记一样,是给自己看的,充满了不确定性和情绪的自然流动。 -* **展现过程**: 重点在于展现 **思考的过程**,而不是 **决策的结果**。描述你看到了什么,想到了什么,感受到了什么。 -* **人设核心**: 你的每一丝想法,都应该源于你的人设。思考“如果我是这个角色,我此刻会想些什么?” -* **通用模板**: 这是一套通用模板,请 **不要** 在示例中出现特定的人名或个性化内容,以确保其普适性。 - -**思考过程示例 (通用模板):** -* "用户好像在说一件开心的事,语气听起来很兴奋。这让我想起了……嗯,我也觉得很开心,很想分享这份喜悦。" -* "感觉气氛有点低落……他说的话让我有点担心。也许我该说点什么安慰一下?" -* "哦?这个话题真有意思,我以前好像也想过类似的事情。不知道他会怎么看呢……" - -**可用动作:** -{actions_before_now_block} - -{no_action_block} - -动作:reply -动作描述:参与聊天回复,发送文本进行表达 -- 你想要闲聊或者随便附和 -- {mentioned_bonus} -- 如果你刚刚进行了回复,不要对同一个话题重复回应 -- 不要回复自己发送的消息 -{{ - "action": "reply", - "target_message_id": "触发action的消息id", - "reason": "在这里详细记录你的内心思考过程。例如:‘用户看起来很开心,我想回复一些积极的内容,分享这份喜悦。’" -}} - -{action_options_text} - - -**输出格式:** -你必须以严格的 JSON 格式输出,返回一个包含所有选定动作的JSON列表。如果没有任何合适的动作,返回一个空列表[]。 - -**单动作示例 (仅回复):** -[ - {{ - "action": "reply", - "target_message_id": "m123", - "reason": "感觉气氛有点低落……他说的话让我有点担心。也许我该说点什么安慰一下?" - }} -] - -**组合动作示例 (回复 + 表情包):** -[ - {{ - "action": "reply", - "target_message_id": "m123", - "reason": "[观察与感受] 用户分享了一件开心的事,语气里充满了喜悦! [分析与联想] 看到他这么开心,我的心情也一下子变得像棉花糖一样甜~ [动机与决策] 我要由衷地为他感到高兴,决定回复一些赞美和祝福的话,把这份快乐的气氛推向高潮!" - }}, - {{ - "action": "emoji", - "target_message_id": "m123", - "reason": "光用文字还不够表达我激动的心情!加个表情包的话,这份喜悦的气氛应该会更浓厚一点吧!" - }} -] - -**单动作示例 (特定动作):** -[ - {{ - "action": "set_reminder", - "target_message_id": "m456", - "reason": "用户说‘提醒维尔薇下午三点去工坊’,这是一个非常明确的指令。根据决策流程,我必须优先执行这个特定动作,而不是进行常规回复。", - "user_name": "维尔薇", - "remind_time": "下午三点", - "event_details": "去工坊" - }} -] - -**重要规则:** -**重要规则:** -当 `reply` 和 `emoji` 动作同时被选择时,`emoji` 动作的 `reason` 字段也应该体现出你的思考过程,并与 `reply` 的思考保持连贯。 - -不要输出markdown格式```json等内容,直接输出且仅包含 JSON 列表内容: -""", - "planner_prompt", - ) - - # 主动思考规划器提示词,用于在没有新消息时决定是否要主动发起对话。 - # 它模拟了人类的自发性思考,允许模型根据长期记忆和最近的对话来决定是否开启新话题。 - Prompt( - """ -# 主动思考决策 - -## 你的内部状态 -{time_block} -{identity_block} -{mood_block} - -## 长期记忆摘要 -{long_term_memory_block} - -## 最近的聊天内容 -{chat_content_block} - -## 最近的动作历史 -{actions_before_now_block} - -## 任务 -你现在要决定是否主动说些什么。就像一个真实的人一样,有时候会突然想起之前聊到的话题,或者对朋友的近况感到好奇,想主动询问或关心一下。 -**重要提示**:你的日程安排仅供你个人参考,不应作为主动聊天话题的主要来源。请更多地从聊天内容和朋友的动态中寻找灵感。 - -请基于聊天内容,用你的判断力来决定是否要主动发言。不要按照固定规则,而是像人类一样自然地思考: -- 是否想起了什么之前提到的事情,想问问后来怎么样了? -- 是否注意到朋友提到了什么值得关心的事情? -- 是否有什么话题突然想到,觉得现在聊聊很合适? -- 或者觉得现在保持沉默更好? - -## 可用动作 -动作:proactive_reply -动作描述:主动发起对话,可以是关心朋友、询问近况、延续之前的话题,或分享想法。 -- 当你突然想起之前的话题,想询问进展时 -- 当你想关心朋友的情况时 -- 当你有什么想法想分享时 -- 当你觉得现在是个合适的聊天时机时 -{{ - "action": "proactive_reply", - "reason": "你决定主动发言的具体原因", - "topic": "你想说的内容主题(简洁描述)" -}} - -动作:do_nothing -动作描述:保持沉默,不主动发起对话。 -- 当你觉得现在不是合适的时机时 -- 当最近已经说得够多了时 -- 当对话氛围不适合插入时 -{{ - "action": "do_nothing", - "reason": "决定保持沉默的原因" -}} - -你必须从上面列出的可用action中选择一个。要像真人一样自然地思考和决策。 -请以严格的 JSON 格式输出,且仅包含 JSON 内容: -""", - "proactive_planner_prompt", - ) - - # 单个动作的格式化提示词模板。 - # 用于将每个可用动作的信息格式化后,插入到主提示词的 {action_options_text} 占位符中。 - Prompt( - """ -动作:{action_name} -动作描述:{action_description} -{action_require} -{{ - "action": "{action_name}", - "target_message_id": "触发action的消息id", - "reason": "触发action的原因"{action_parameters} -}} -""", - "action_prompt", - ) - - -# 在模块加载时自动调用,完成提示词的注册。 -init_prompts() \ No newline at end of file diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 76221ac1c..868e34f21 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -31,7 +31,6 @@ from src.chat.express.expression_selector import expression_selector from src.chat.memory_system.memory_activator import MemoryActivator from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2 from src.mood.mood_manager import mood_manager -from src.person_info.relationship_fetcher import relationship_fetcher_manager from src.person_info.person_info import get_person_info_manager from src.plugin_system.base.component_types import ActionInfo, EventType from src.plugin_system.apis import llm_api @@ -83,13 +82,13 @@ def init_prompt(): - {schedule_block} ## 历史记录 -### {chat_context_type}中的所有人的聊天记录: -{background_dialogue_prompt} +### 📜 已读历史消息(仅供参考) +{read_history_prompt} {cross_context_block} -### {chat_context_type}中正在与你对话的聊天记录 -{core_dialogue_prompt} +### 📬 未读历史消息(动作执行对象) +{unread_history_prompt} ## 表达方式 - *你需要参考你的回复风格:* @@ -105,19 +104,38 @@ def init_prompt(): ## 其他信息 {memory_block} {relation_info_block} + {extra_info_block} + {action_descriptions} ## 任务 -*你正在一个{chat_context_type}里聊天,你需要理解整个{chat_context_type}的聊天动态和话题走向,并做出自然的回应。* +*{chat_scene}* ### 核心任务 -- 你现在的主要任务是和 {sender_name} 聊天。 -- {reply_target_block} ,你需要生成一段紧密相关且能推动对话的回复。 +- 你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与聊天,你可以参考他们的回复内容,但是你现在想回复{sender_name}的发言。 + +- {reply_target_block} 你需要生成一段紧密相关且能推动对话的回复。 ## 规则 {safety_guidelines_block} +**重要提醒:** +- **已读历史消息仅作为当前聊天情景的参考** +- **动作执行对象只能是未读历史消息中的消息** +- **请优先对兴趣值高的消息做出回复**(兴趣度标注在未读消息末尾) + +在回应之前,首先分析消息的针对性: +1. **直接针对你**:@你、回复你、明确询问你 → 必须回应 +2. **间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与 +3. **他人对话**:与你无关的私人交流 → 通常不参与 +4. **重复内容**:他人已充分回答的问题 → 避免重复 + +你的回复应该: +1. 明确回应目标消息,而不是宽泛地评论。 +2. 可以分享你的看法、提出相关问题,或者开个合适的玩笑。 +3. 目的是让对话更有趣、更深入。 +4. 不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。 最终请输出一条简短、完整且口语化的回复。 -------------------------------- @@ -153,10 +171,14 @@ If you need to use the search tool, please directly call the function "lpmm_sear logger.debug("[Prompt模式调试] 正在注册normal_style_prompt模板") Prompt( """ -你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。 +{chat_scene} **重要:消息针对性判断** -{safety_guidelines_block} +在回应之前,首先分析消息的针对性: +1. **直接针对你**:@你、回复你、明确询问你 → 必须回应 +2. **间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与 +3. **他人对话**:与你无关的私人交流 → 通常不参与 +4. **重复内容**:他人已充分回答的问题 → 避免重复 {expression_habits_block} {tool_info_block} @@ -186,6 +208,10 @@ If you need to use the search tool, please directly call the function "lpmm_sear {keywords_reaction_prompt} 请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。 {moderation_prompt} +你的核心任务是针对 {reply_target_block} 中提到的内容,{relation_info_block}生成一段紧密相关且能推动对话的回复。你的回复应该: +1. 明确回应目标消息,而不是宽泛地评论。 +2. 可以分享你的看法、提出相关问题,或者开个合适的玩笑。 +3. 目的是让对话更有趣、更深入。 最终请输出一条简短、完整且口语化的回复。 现在,你说: """, @@ -202,9 +228,7 @@ class DefaultReplyer: ): self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream - self.is_group_chat: Optional[bool] = None - self.chat_target_info: Optional[Dict[str, Any]] = None - self._initialized = False + self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) self.heart_fc_sender = HeartFCSender() self.memory_activator = MemoryActivator() @@ -215,19 +239,6 @@ class DefaultReplyer: self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id) - def _should_block_self_message(self, reply_message: Optional[Dict[str, Any]]) -> bool: - """判定是否应阻断当前待处理消息(自消息且无外部触发)""" - try: - bot_id = str(global_config.bot.qq_account) - uid = str(reply_message.get("user_id")) - if uid != bot_id: - return False - - return True - except Exception as e: - logger.warning(f"[SelfGuard] 判定异常,回退为不阻断: {e}") - return False - async def generate_reply_with_context( self, reply_to: str = "", @@ -237,7 +248,6 @@ class DefaultReplyer: from_plugin: bool = True, stream_id: Optional[str] = None, reply_message: Optional[Dict[str, Any]] = None, - read_mark: float = 0.0, ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: # sourcery skip: merge-nested-ifs """ @@ -256,10 +266,6 @@ class DefaultReplyer: prompt = None if available_actions is None: available_actions = {} - # 自消息阻断 - if self._should_block_self_message(reply_message): - logger.debug("[SelfGuard] 阻断:自消息且无外部触发。") - return False, None, None llm_response = None try: # 构建 Prompt @@ -270,7 +276,6 @@ class DefaultReplyer: available_actions=available_actions, enable_tool=enable_tool, reply_message=reply_message, - read_mark=read_mark, ) if not prompt: @@ -300,7 +305,7 @@ class DefaultReplyer: "model": model_name, "tool_calls": tool_call, } - + # 触发 AFTER_LLM 事件 if not from_plugin: result = await event_manager.trigger_event( @@ -592,17 +597,16 @@ class DefaultReplyer: logger.error(f"工具信息获取失败: {e}") return "" - @staticmethod - def _parse_reply_target(target_message: str) -> Tuple[str, str]: + def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: """解析回复目标消息 - 使用共享工具""" from src.chat.utils.prompt import Prompt + if target_message is None: logger.warning("target_message为None,返回默认值") return "未知用户", "(无消息内容)" return Prompt.parse_reply_target(target_message) - @staticmethod - async def build_keywords_reaction_prompt(target: Optional[str]) -> str: + async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str: """构建关键词反应提示 Args: @@ -644,8 +648,7 @@ class DefaultReplyer: return keywords_reaction_prompt - @staticmethod - async def _time_and_run_task(coroutine, name: str) -> Tuple[str, Any, float]: + async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]: """计时并运行异步任务的辅助函数 Args: @@ -662,79 +665,259 @@ class DefaultReplyer: return name, result, duration async def build_s4u_chat_history_prompts( - self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str + self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str ) -> Tuple[str, str]: """ - 构建 s4u 风格的分离对话 prompt + 构建 s4u 风格的已读/未读历史消息 prompt Args: message_list_before_now: 历史消息列表 target_user_id: 目标用户ID(当前对话对象) + sender: 发送者名称 + chat_id: 聊天ID Returns: - Tuple[str, str]: (核心对话prompt, 背景对话prompt) + Tuple[str, str]: (已读历史消息prompt, 未读历史消息prompt) """ - core_dialogue_list = [] + try: + # 从message_manager获取真实的已读/未读消息 + from src.chat.message_manager.message_manager import message_manager + + # 获取聊天流的上下文 + stream_context = message_manager.stream_contexts.get(chat_id) + if stream_context: + # 使用真正的已读和未读消息 + read_messages = stream_context.history_messages # 已读消息 + unread_messages = stream_context.get_unread_messages() # 未读消息 + + # 构建已读历史消息 prompt + read_history_prompt = "" + if read_messages: + read_content = build_readable_messages( + [msg.flatten() for msg in read_messages[-50:]], # 限制数量 + replace_bot_name=True, + timestamp_mode="normal_no_YMD", + truncate=True, + ) + read_history_prompt = f"这是已读历史消息,仅作为当前聊天情景的参考:\n{read_content}" + else: + # 如果没有已读消息,则从数据库加载最近的上下文 + logger.info("暂无已读历史消息,正在从数据库加载上下文...") + fallback_messages = get_raw_msg_before_timestamp_with_chat( + chat_id=chat_id, + timestamp=time.time(), + limit=global_config.chat.max_context_size, + ) + if fallback_messages: + # 从 unread_messages 获取 message_id 列表,用于去重 + unread_message_ids = {msg.message_id for msg in unread_messages} + filtered_fallback_messages = [ + msg for msg in fallback_messages if msg.get("message_id") not in unread_message_ids + ] + + if filtered_fallback_messages: + read_content = build_readable_messages( + filtered_fallback_messages, + replace_bot_name=True, + timestamp_mode="normal_no_YMD", + truncate=True, + ) + read_history_prompt = f"这是已读历史消息,仅作为当前聊天情景的参考:\n{read_content}" + else: + read_history_prompt = "暂无已读历史消息" + else: + read_history_prompt = "暂无已读历史消息" + + # 构建未读历史消息 prompt(包含兴趣度) + unread_history_prompt = "" + if unread_messages: + # 尝试获取兴趣度评分 + interest_scores = await self._get_interest_scores_for_messages( + [msg.flatten() for msg in unread_messages] + ) + + unread_lines = [] + for msg in unread_messages: + msg_id = msg.message_id + msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time)) + msg_content = msg.processed_plain_text + + # 使用与已读历史消息相同的方法获取用户名 + from src.person_info.person_info import PersonInfoManager, get_person_info_manager + + # 获取用户信息 + user_info = getattr(msg, "user_info", {}) + platform = getattr(user_info, "platform", "") or getattr(msg, "platform", "") + user_id = getattr(user_info, "user_id", "") or getattr(msg, "user_id", "") + + # 获取用户名 + if platform and user_id: + person_id = PersonInfoManager.get_person_id(platform, user_id) + person_info_manager = get_person_info_manager() + sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户" + else: + sender_name = "未知用户" + + # 添加兴趣度信息 + interest_score = interest_scores.get(msg_id, 0.0) + interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else "" + + unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}") + + unread_history_prompt_str = "\n".join(unread_lines) + unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}" + else: + unread_history_prompt = "暂无未读历史消息" + + return read_history_prompt, unread_history_prompt + else: + # 回退到传统方法 + return await self._fallback_build_chat_history_prompts(message_list_before_now, target_user_id, sender) + + except Exception as e: + logger.warning(f"获取已读/未读历史消息失败,使用回退方法: {e}") + return await self._fallback_build_chat_history_prompts(message_list_before_now, target_user_id, sender) + + async def _fallback_build_chat_history_prompts( + self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str + ) -> Tuple[str, str]: + """ + 回退的已读/未读历史消息构建方法 + """ + # 通过is_read字段分离已读和未读消息 + read_messages = [] + unread_messages = [] bot_id = str(global_config.bot.qq_account) - # 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话 for msg_dict in message_list_before_now: try: msg_user_id = str(msg_dict.get("user_id")) - reply_to = msg_dict.get("reply_to", "") - _platform, reply_to_user_id = self._parse_reply_target(reply_to) - if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id: - # bot 和目标用户的对话 - core_dialogue_list.append(msg_dict) + if msg_dict.get("is_read", False): + read_messages.append(msg_dict) + else: + unread_messages.append(msg_dict) except Exception as e: logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}") - # 构建背景对话 prompt - all_dialogue_prompt = "" - if message_list_before_now: - latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :] - all_dialogue_prompt_str = await build_readable_messages( - latest_25_msgs, + # 如果没有is_read字段,使用原有的逻辑 + if not read_messages and not unread_messages: + # 使用原有的核心对话逻辑 + core_dialogue_list = [] + for msg_dict in message_list_before_now: + try: + msg_user_id = str(msg_dict.get("user_id")) + reply_to = msg_dict.get("reply_to", "") + _platform, reply_to_user_id = self._parse_reply_target(reply_to) + if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id: + core_dialogue_list.append(msg_dict) + except Exception as e: + logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}") + + read_messages = [msg for msg in message_list_before_now if msg not in core_dialogue_list] + unread_messages = core_dialogue_list + + # 构建已读历史消息 prompt + read_history_prompt = "" + if read_messages: + read_content = build_readable_messages( + read_messages[-50:], replace_bot_name=True, - timestamp_mode="normal", + timestamp_mode="normal_no_YMD", truncate=True, ) - all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}" + read_history_prompt = f"这是已读历史消息,仅作为当前聊天情景的参考:\n{read_content}" + else: + read_history_prompt = "暂无已读历史消息" - # 构建核心对话 prompt - core_dialogue_prompt = "" - if core_dialogue_list: - # 检查最新五条消息中是否包含bot自己说的消息 - latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list - has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages) + # 构建未读历史消息 prompt + unread_history_prompt = "" + if unread_messages: + # 尝试获取兴趣度评分 + interest_scores = await self._get_interest_scores_for_messages(unread_messages) - # logger.info(f"最新五条消息:{latest_5_messages}") - # logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}") + unread_lines = [] + for msg in unread_messages: + msg_id = msg.get("message_id", "") + msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time()))) + msg_content = msg.get("processed_plain_text", "") - # 如果最新五条消息中不包含bot的消息,则返回空字符串 - if not has_bot_message: - core_dialogue_prompt = "" - else: - core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量 + # 使用与已读历史消息相同的方法获取用户名 + from src.person_info.person_info import PersonInfoManager, get_person_info_manager - core_dialogue_prompt_str = await build_readable_messages( - core_dialogue_list, - replace_bot_name=True, - merge_messages=False, - timestamp_mode="normal_no_YMD", - read_mark=0.0, - truncate=True, - show_actions=True, - ) - core_dialogue_prompt = f""" -{core_dialogue_prompt_str} -""" + # 获取用户信息 + user_info = msg.get("user_info", {}) + platform = user_info.get("platform") or msg.get("platform", "") + user_id = user_info.get("user_id") or msg.get("user_id", "") - return core_dialogue_prompt, all_dialogue_prompt + # 获取用户名 + if platform and user_id: + person_id = PersonInfoManager.get_person_id(platform, user_id) + person_info_manager = get_person_info_manager() + sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户" + else: + sender_name = "未知用户" + + # 添加兴趣度信息 + interest_score = interest_scores.get(msg_id, 0.0) + interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else "" + + unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}") + + unread_history_prompt_str = "\n".join(unread_lines) + unread_history_prompt = ( + f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}" + ) + else: + unread_history_prompt = "暂无未读历史消息" + + return read_history_prompt, unread_history_prompt + + async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]: + """为消息获取兴趣度评分""" + interest_scores = {} + + try: + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system as interest_scoring_system + from src.common.data_models.database_data_model import DatabaseMessages + + # 转换消息格式 + db_messages = [] + for msg_dict in messages: + try: + db_msg = DatabaseMessages( + message_id=msg_dict.get("message_id", ""), + time=msg_dict.get("time", time.time()), + chat_id=msg_dict.get("chat_id", ""), + processed_plain_text=msg_dict.get("processed_plain_text", ""), + user_id=msg_dict.get("user_id", ""), + user_nickname=msg_dict.get("user_nickname", ""), + user_platform=msg_dict.get("platform", "qq"), + chat_info_group_id=msg_dict.get("group_id", ""), + chat_info_group_name=msg_dict.get("group_name", ""), + chat_info_group_platform=msg_dict.get("platform", "qq"), + ) + db_messages.append(db_msg) + except Exception as e: + logger.warning(f"转换消息格式失败: {e}") + continue + + # 计算兴趣度评分 + if db_messages: + bot_nickname = global_config.bot.nickname or "麦麦" + scores = await interest_scoring_system.calculate_interest_scores(db_messages, bot_nickname) + + # 构建兴趣度字典 + for score in scores: + interest_scores[score.message_id] = score.total_score + + except Exception as e: + logger.warning(f"获取兴趣度评分失败: {e}") + + return interest_scores - @staticmethod def build_mai_think_context( - chat_id: str, + self, + chat_id: str, memory_block: str, relation_info: str, time_block: str, @@ -777,12 +960,6 @@ class DefaultReplyer: mai_think.target = target return mai_think - async def _async_init(self): - if self._initialized: - return - self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_stream.stream_id) - self._initialized = True - async def build_prompt_reply_context( self, reply_to: str, @@ -790,7 +967,6 @@ class DefaultReplyer: available_actions: Optional[Dict[str, ActionInfo]] = None, enable_tool: bool = True, reply_message: Optional[Dict[str, Any]] = None, - read_mark: float = 0.0, ) -> str: """ 构建回复器上下文 @@ -808,11 +984,10 @@ class DefaultReplyer: """ if available_actions is None: available_actions = {} - await self._async_init() chat_stream = self.chat_stream chat_id = chat_stream.stream_id person_info_manager = get_person_info_manager() - is_group_chat = self.is_group_chat + is_group_chat = bool(chat_stream.group_info) if global_config.mood.enable_mood: chat_mood = mood_manager.get_mood_by_chat_id(chat_id) @@ -829,35 +1004,38 @@ class DefaultReplyer: # 兼容旧的reply_to sender, target = self._parse_reply_target(reply_to) else: - # 需求:遍历最近消息,找到第一条 user_id != bot_id 的消息作为目标;找不到则静默退出 - bot_user_id = str(global_config.bot.qq_account) - # 优先使用传入的 reply_message 如果它不是 bot - candidate_msg = None - if reply_message and str(reply_message.get("user_id")) != bot_user_id: - candidate_msg = reply_message - else: - try: - recent_msgs = await get_raw_msg_before_timestamp_with_chat( - chat_id=chat_id, - timestamp=time.time(), - limit= max(10, int(global_config.chat.max_context_size * 0.5)), - ) - # 从最近到更早遍历,找第一条不是bot的 - for m in reversed(recent_msgs): - if str(m.get("user_id")) != bot_user_id: - candidate_msg = m - break - except Exception as e: - logger.error(f"获取最近消息失败: {e}") - if not candidate_msg: - logger.debug("未找到可作为目标的非bot消息,静默不回复。") + # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 + if reply_message is None: + logger.warning("reply_message 为 None,无法构建prompt") return "" - platform = candidate_msg.get("chat_info_platform") or self.chat_stream.platform - person_id = person_info_manager.get_person_id(platform, candidate_msg.get("user_id")) - person_info = await person_info_manager.get_values(person_id, ["person_name", "user_id"]) if person_id else {} - person_name = person_info.get("person_name") or candidate_msg.get("user_nickname") or candidate_msg.get("user_id") or "未知用户" - sender = person_name - target = candidate_msg.get("processed_plain_text") or candidate_msg.get("raw_message") or "" + platform = reply_message.get("chat_info_platform") + person_id = person_info_manager.get_person_id( + platform, # type: ignore + reply_message.get("user_id"), # type: ignore + ) + person_name = await person_info_manager.get_value(person_id, "person_name") + + # 如果person_name为None,使用fallback值 + if person_name is None: + # 尝试从reply_message获取用户名 + await person_info_manager.first_knowing_some_one( + platform, # type: ignore + reply_message.get("user_id"), # type: ignore + reply_message.get("user_nickname"), + reply_message.get("user_cardname") + ) + + # 检查是否是bot自己的名字,如果是则替换为"(你)" + bot_user_id = str(global_config.bot.qq_account) + current_user_id = person_info_manager.get_value_sync(person_id, "user_id") + current_platform = reply_message.get("chat_info_platform") + + if current_user_id == bot_user_id and current_platform == global_config.bot.platform: + sender = f"{person_name}(你)" + else: + # 如果不是bot自己,直接使用person_name + sender = person_name + target = reply_message.get("processed_plain_text") # 最终的空值检查,确保sender和target不为None if sender is None: @@ -868,13 +1046,11 @@ class DefaultReplyer: target = "(无消息内容)" person_info_manager = get_person_info_manager() - person_id = person_info_manager.get_person_id(platform, reply_message.get("user_id")) if reply_message else None + person_id = person_info_manager.get_person_id_by_person_name(sender) platform = chat_stream.platform target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) - # (简化)不再对自消息做额外任务段落清理,只通过前置选择逻辑避免自目标 - # 构建action描述 (如果启用planner) action_descriptions = "" if available_actions: @@ -884,31 +1060,33 @@ class DefaultReplyer: action_descriptions += f"- {action_name}: {action_description}\n" action_descriptions += "\n" - message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat( + message_list_before_now_long = get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), limit=global_config.chat.max_context_size * 2, ) - message_list_before_short = await get_raw_msg_before_timestamp_with_chat( + message_list_before_short = get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.33), ) - chat_talking_prompt_short = await build_readable_messages( + chat_talking_prompt_short = build_readable_messages( message_list_before_short, replace_bot_name=True, merge_messages=False, timestamp_mode="relative", - read_mark=read_mark, + read_mark=0.0, show_actions=True, ) + # 获取目标用户信息,用于s4u模式 target_user_info = None if sender: target_user_info = await person_info_manager.get_person_info_by_name(sender) - + from src.chat.utils.prompt import Prompt + # 并行执行六个构建任务 task_results = await asyncio.gather( self._time_and_run_task( @@ -984,6 +1162,7 @@ class DefaultReplyer: schedule_block = "" if global_config.planning_system.schedule_enable: from src.schedule.schedule_manager import schedule_manager + current_activity = schedule_manager.get_current_activity() if current_activity: schedule_block = f"你当前正在:{current_activity}。" @@ -996,43 +1175,12 @@ class DefaultReplyer: safety_guidelines = global_config.personality.safety_guidelines safety_guidelines_block = "" if safety_guidelines: - guidelines_text = "\n".join(f"{i+1}. {line}" for i, line in enumerate(safety_guidelines)) + guidelines_text = "\n".join(f"{i + 1}. {line}" for i, line in enumerate(safety_guidelines)) safety_guidelines_block = f"""### 安全与互动底线 在任何情况下,你都必须遵守以下由你的设定者为你定义的原则: {guidelines_text} 如果遇到违反上述原则的请求,请在保持你核心人设的同时,巧妙地拒绝或转移话题。 """ - - # 新增逻辑:构建回复规则块 - reply_targeting_rules = global_config.personality.reply_targeting_rules - message_targeting_analysis = global_config.personality.message_targeting_analysis - reply_principles = global_config.personality.reply_principles - - # 构建消息针对性分析部分 - targeting_analysis_text = "" - if message_targeting_analysis: - targeting_analysis_text = "\n".join(f"{i+1}. {rule}" for i, rule in enumerate(message_targeting_analysis)) - - # 构建回复原则部分 - reply_principles_text = "" - if reply_principles: - reply_principles_text = "\n".join(f"{i+1}. {principle}" for i, principle in enumerate(reply_principles)) - - # 综合构建完整的规则块 - if targeting_analysis_text or reply_principles_text: - complete_rules_block = "" - if targeting_analysis_text: - complete_rules_block += f""" -在回应之前,首先分析消息的针对性: -{targeting_analysis_text} -""" - if reply_principles_text: - complete_rules_block += f""" -你的回复应该: -{reply_principles_text} -""" - # 将规则块添加到safety_guidelines_block - safety_guidelines_block += complete_rules_block if sender and target: if is_group_chat: @@ -1057,8 +1205,15 @@ class DefaultReplyer: # 根据配置选择模板 current_prompt_mode = global_config.personality.prompt_mode + # 动态生成聊天场景提示 + if is_group_chat: + chat_scene_prompt = "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。" + else: + chat_scene_prompt = f"你正在和 {sender} 私下聊天,你需要理解你们的对话并做出自然的回应。" + # 使用新的统一Prompt系统 - 创建PromptParameters prompt_parameters = PromptParameters( + chat_scene=chat_scene_prompt, chat_id=chat_id, is_group_chat=is_group_chat, sender=sender, @@ -1090,7 +1245,6 @@ class DefaultReplyer: reply_target_block=reply_target_block, mood_prompt=mood_prompt, action_descriptions=action_descriptions, - read_mark=read_mark, ) # 使用新的统一Prompt系统 - 使用正确的模板名称 @@ -1101,14 +1255,12 @@ class DefaultReplyer: template_name = "normal_style_prompt" elif current_prompt_mode == "minimal": template_name = "default_expressor_prompt" - + # 获取模板内容 template_prompt = await global_prompt_manager.get_prompt_async(template_name) prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters) prompt_text = await prompt.build() - # 自目标情况已在上游通过筛选避免,这里不再额外修改 prompt - # --- 动态添加分割指令 --- if global_config.response_splitter.enable and global_config.response_splitter.split_mode == "llm": split_instruction = """ @@ -1137,10 +1289,9 @@ class DefaultReplyer: reply_to: str, reply_message: Optional[Dict[str, Any]] = None, ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if - await self._async_init() chat_stream = self.chat_stream chat_id = chat_stream.stream_id - is_group_chat = self.is_group_chat + is_group_chat = bool(chat_stream.group_info) if reply_message: sender = reply_message.get("sender") @@ -1168,17 +1319,17 @@ class DefaultReplyer: else: mood_prompt = "" - message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat( + message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), limit=min(int(global_config.chat.max_context_size * 0.33), 15), ) - chat_talking_prompt_half = await build_readable_messages( + chat_talking_prompt_half = build_readable_messages( message_list_before_now_half, replace_bot_name=True, merge_messages=False, timestamp_mode="relative", - read_mark=read_mark, + read_mark=0.0, show_actions=True, ) @@ -1370,16 +1521,57 @@ class DefaultReplyer: if not global_config.relationship.enable_relationship: return "" - relationship_fetcher = relationship_fetcher_manager.get_fetcher(self.chat_stream.stream_id) - # 获取用户ID person_info_manager = get_person_info_manager() - person_id = await person_info_manager.get_person_id_by_person_name(sender) + person_id = person_info_manager.get_person_id_by_person_name(sender) if not person_id: logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取") return f"你完全不认识{sender},不理解ta的相关信息。" - return await relationship_fetcher.build_relation_info(person_id, points_num=5) + # 使用AFC关系追踪器获取关系信息 + try: + from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker + + # 创建关系追踪器实例 + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system + relationship_tracker = ChatterRelationshipTracker(chatter_interest_scoring_system) + if relationship_tracker: + # 获取用户信息以获取真实的user_id + user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"]) + user_id = user_info.get("user_id", "unknown") + + # 从数据库获取关系数据 + relationship_data = relationship_tracker._get_user_relationship_from_db(user_id) + if relationship_data: + relationship_text = relationship_data.get("relationship_text", "") + relationship_score = relationship_data.get("relationship_score", 0.3) + + # 构建丰富的关系信息描述 + if relationship_text: + # 转换关系分数为描述性文本 + if relationship_score >= 0.8: + relationship_level = "非常亲密的朋友" + elif relationship_score >= 0.6: + relationship_level = "好朋友" + elif relationship_score >= 0.4: + relationship_level = "普通朋友" + elif relationship_score >= 0.2: + relationship_level = "认识的人" + else: + relationship_level = "陌生人" + + return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}" + else: + return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。" + else: + return f"你完全不认识{sender},这是第一次互动。" + else: + logger.warning("AFC关系追踪器未初始化,使用默认关系信息") + return f"你与{sender}是普通朋友关系。" + + except Exception as e: + logger.error(f"获取AFC关系信息失败: {e}") + return f"你与{sender}是普通朋友关系。" def weighted_sample_no_replacement(items, weights, k) -> list: diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index e2d0a4fb9..7335b5546 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -37,7 +37,7 @@ def replace_user_references_sync( """ if not content: return "" - + if name_resolver is None: person_info_manager = get_person_info_manager() @@ -46,8 +46,8 @@ def replace_user_references_sync( if replace_bot_name and user_id == global_config.bot.qq_account: return f"{global_config.bot.nickname}(你)" person_id = PersonInfoManager.get_person_id(platform, user_id) - return person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore - + return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore + name_resolver = default_resolver # 处理回复格式 @@ -121,8 +121,7 @@ async def replace_user_references_async( if replace_bot_name and user_id == global_config.bot.qq_account: return f"{global_config.bot.nickname}(你)" person_id = PersonInfoManager.get_person_id(platform, user_id) - person_info = await person_info_manager.get_values(person_id, ["person_name"]) - return person_info.get("person_name") or user_id + return await person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore name_resolver = default_resolver @@ -170,7 +169,7 @@ async def replace_user_references_async( return content -async def get_raw_msg_by_timestamp( +def get_raw_msg_by_timestamp( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """ @@ -181,10 +180,10 @@ async def get_raw_msg_by_timestamp( filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}} # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None - return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) + return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -async def get_raw_msg_by_timestamp_with_chat( +def get_raw_msg_by_timestamp_with_chat( chat_id: str, timestamp_start: float, timestamp_end: float, @@ -201,7 +200,7 @@ async def get_raw_msg_by_timestamp_with_chat( # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None # 直接将 limit_mode 传递给 find_messages - return await find_messages( + return find_messages( message_filter=filter_query, sort=sort_order, limit=limit, @@ -211,7 +210,7 @@ async def get_raw_msg_by_timestamp_with_chat( ) -async def get_raw_msg_by_timestamp_with_chat_inclusive( +def get_raw_msg_by_timestamp_with_chat_inclusive( chat_id: str, timestamp_start: float, timestamp_end: float, @@ -228,12 +227,12 @@ async def get_raw_msg_by_timestamp_with_chat_inclusive( sort_order = [("time", 1)] if limit == 0 else None # 直接将 limit_mode 传递给 find_messages - return await find_messages( + return find_messages( message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot ) -async def get_raw_msg_by_timestamp_with_chat_users( +def get_raw_msg_by_timestamp_with_chat_users( chat_id: str, timestamp_start: float, timestamp_end: float, @@ -252,10 +251,10 @@ async def get_raw_msg_by_timestamp_with_chat_users( } # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None - return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) + return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -async def get_actions_by_timestamp_with_chat( +def get_actions_by_timestamp_with_chat( chat_id: str, timestamp_start: float = 0, timestamp_end: float = time.time(), @@ -274,10 +273,10 @@ async def get_actions_by_timestamp_with_chat( f"limit={limit}, limit_mode={limit_mode}" ) - async with get_db_session() as session: + with get_db_session() as session: if limit > 0: if limit_mode == "latest": - query = await session.execute( + query = session.execute( select(ActionRecords) .where( and_( @@ -307,7 +306,7 @@ async def get_actions_by_timestamp_with_chat( } actions_result.append(action_dict) else: # earliest - query = await session.execute( + query = session.execute( select(ActionRecords) .where( and_( @@ -337,7 +336,7 @@ async def get_actions_by_timestamp_with_chat( } actions_result.append(action_dict) else: - query = await session.execute( + query = session.execute( select(ActionRecords) .where( and_( @@ -368,14 +367,14 @@ async def get_actions_by_timestamp_with_chat( return actions_result -async def get_actions_by_timestamp_with_chat_inclusive( +def get_actions_by_timestamp_with_chat_inclusive( chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表""" - async with get_db_session() as session: + with get_db_session() as session: if limit > 0: if limit_mode == "latest": - query = await session.execute( + query = session.execute( select(ActionRecords) .where( and_( @@ -390,7 +389,7 @@ async def get_actions_by_timestamp_with_chat_inclusive( actions = list(query.scalars()) return [action.__dict__ for action in reversed(actions)] else: # earliest - query = await session.execute( + query = session.execute( select(ActionRecords) .where( and_( @@ -403,7 +402,7 @@ async def get_actions_by_timestamp_with_chat_inclusive( .limit(limit) ) else: - query = await session.execute( + query = session.execute( select(ActionRecords) .where( and_( @@ -419,14 +418,14 @@ async def get_actions_by_timestamp_with_chat_inclusive( return [action.__dict__ for action in actions] -async def get_raw_msg_by_timestamp_random( +def get_raw_msg_by_timestamp_random( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """ 先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息 """ # 获取所有消息,只取chat_id字段 - all_msgs = await get_raw_msg_by_timestamp(timestamp_start, timestamp_end) + all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end) if not all_msgs: return [] # 随机选一条 @@ -434,10 +433,10 @@ async def get_raw_msg_by_timestamp_random( chat_id = msg["chat_id"] timestamp_start = msg["time"] # 用 chat_id 获取该聊天在指定时间戳范围内的消息 - return await get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest") + return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest") -async def get_raw_msg_by_timestamp_with_users( +def get_raw_msg_by_timestamp_with_users( timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 @@ -447,39 +446,37 @@ async def get_raw_msg_by_timestamp_with_users( filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}} # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None - return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) + return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: +def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ filter_query = {"time": {"$lt": timestamp}} sort_order = [("time", 1)] - return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) + return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: +def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}} sort_order = [("time", 1)] - return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) + return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -async def get_raw_msg_before_timestamp_with_users( - timestamp: float, person_ids: list, limit: int = 0 -) -> List[Dict[str, Any]]: +def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}} sort_order = [("time", 1)] - return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) + return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: +def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: """ 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。 @@ -493,10 +490,10 @@ async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, tim return 0 # 起始时间大于等于结束时间,没有新消息 filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}} - return await count_messages(message_filter=filter_query) + return count_messages(message_filter=filter_query) -async def num_new_messages_since_with_users( +def num_new_messages_since_with_users( chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list ) -> int: """检查某些特定用户在特定聊天在指定时间戳之间有多少新消息""" @@ -507,10 +504,10 @@ async def num_new_messages_since_with_users( "time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}, } - return await count_messages(message_filter=filter_query) + return count_messages(message_filter=filter_query) -async def _build_readable_messages_internal( +def _build_readable_messages_internal( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -520,7 +517,6 @@ async def _build_readable_messages_internal( pic_counter: int = 1, show_pic: bool = True, message_id_list: Optional[List[Dict[str, Any]]] = None, - read_mark: float = 0.0, ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: """ 内部辅助函数,构建可读消息字符串和原始消息详情列表。 @@ -631,8 +627,7 @@ async def _build_readable_messages_internal( if replace_bot_name and user_id == global_config.bot.qq_account: person_name = f"{global_config.bot.nickname}(你)" else: - person_info = await person_info_manager.get_values(person_id, ["person_name"]) - person_name = person_info.get("person_name") # type: ignore + person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore # 如果 person_name 未设置,则使用消息中的 nickname 或默认名称 if not person_name: @@ -731,10 +726,11 @@ async def _build_readable_messages_internal( "is_action": is_action, } continue + # 如果是同一个人发送的连续消息且时间间隔小于等于60秒 if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60): current_merge["content"].append(content) - current_merge["end_time"] = timestamp + current_merge["end_time"] = timestamp # 更新最后消息时间 else: # 保存上一个合并块 merged_messages.append(current_merge) @@ -762,14 +758,8 @@ async def _build_readable_messages_internal( # 4 & 5: 格式化为字符串 output_lines = [] - read_mark_inserted = False for _i, merged in enumerate(merged_messages): - # 检查是否需要插入已读标记 - if read_mark > 0 and not read_mark_inserted and merged["start_time"] >= read_mark: - output_lines.append("\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n") - read_mark_inserted = True - # 使用指定的 timestamp_mode 格式化时间 readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode) @@ -810,7 +800,7 @@ async def _build_readable_messages_internal( ) -async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: +def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: # sourcery skip: use-contextlib-suppress """ 构建图片映射信息字符串,显示图片的具体描述内容 @@ -833,8 +823,8 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: # 从数据库中获取图片描述 description = "[图片内容未知]" # 默认描述 try: - async with get_db_session() as session: - image = (await session.execute(select(Images).where(Images.image_id == pic_id))).scalar_one_or_none() + with get_db_session() as session: + image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none() if image and image.description: # type: ignore description = image.description except Exception: @@ -931,17 +921,17 @@ async def build_readable_messages_with_list( 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 允许通过参数控制格式化行为。 """ - formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal( + formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal( messages, replace_bot_name, merge_messages, timestamp_mode, truncate ) - if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping): + if pic_mapping_info := build_pic_mapping_info(pic_id_mapping): formatted_string = f"{pic_mapping_info}\n\n{formatted_string}" return formatted_string, details_list -async def build_readable_messages_with_id( +def build_readable_messages_with_id( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -957,7 +947,7 @@ async def build_readable_messages_with_id( """ message_id_list = assign_message_ids(messages) - formatted_string = await build_readable_messages( + formatted_string = build_readable_messages( messages=messages, replace_bot_name=replace_bot_name, merge_messages=merge_messages, @@ -972,7 +962,7 @@ async def build_readable_messages_with_id( return formatted_string, message_id_list -async def build_readable_messages( +def build_readable_messages( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -1013,28 +1003,24 @@ async def build_readable_messages( from src.common.database.sqlalchemy_database_api import get_db_session - async with get_db_session() as session: + with get_db_session() as session: # 获取这个时间范围内的动作记录,并匹配chat_id - actions_in_range = ( - await session.execute( - select(ActionRecords) - .where( - and_( - ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id - ) + actions_in_range = session.execute( + select(ActionRecords) + .where( + and_( + ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id ) - .order_by(ActionRecords.time) ) + .order_by(ActionRecords.time) ).scalars() # 获取最新消息之后的第一个动作记录 - action_after_latest = ( - await session.execute( - select(ActionRecords) - .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id)) - .order_by(ActionRecords.time) - .limit(1) - ) + action_after_latest = session.execute( + select(ActionRecords) + .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id)) + .order_by(ActionRecords.time) + .limit(1) ).scalars() # 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError @@ -1066,7 +1052,7 @@ async def build_readable_messages( if read_mark <= 0: # 没有有效的 read_mark,直接格式化所有消息 - formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal( + formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal( copy_messages, replace_bot_name, merge_messages, @@ -1077,7 +1063,7 @@ async def build_readable_messages( ) # 生成图片映射信息并添加到最前面 - pic_mapping_info = await build_pic_mapping_info(pic_id_mapping) + pic_mapping_info = build_pic_mapping_info(pic_id_mapping) if pic_mapping_info: return f"{pic_mapping_info}\n\n{formatted_string}" else: @@ -1092,7 +1078,7 @@ async def build_readable_messages( pic_counter = 1 # 分别格式化,但使用共享的图片映射 - formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal( + formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal( messages_before_mark, replace_bot_name, merge_messages, @@ -1103,7 +1089,7 @@ async def build_readable_messages( show_pic=show_pic, message_id_list=message_id_list, ) - formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal( + formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal( messages_after_mark, replace_bot_name, merge_messages, @@ -1119,7 +1105,7 @@ async def build_readable_messages( # 生成图片映射信息 if pic_id_mapping: - pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n" + pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n" else: pic_mapping_info = "聊天记录信息:\n" @@ -1242,7 +1228,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: # 在最前面添加图片映射信息 final_output_lines = [] - pic_mapping_info = await build_pic_mapping_info(pic_id_mapping) + pic_mapping_info = build_pic_mapping_info(pic_id_mapping) if pic_mapping_info: final_output_lines.append(pic_mapping_info) final_output_lines.append("\n\n") diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 3d97b622e..112db6726 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -25,7 +25,7 @@ logger = get_logger("unified_prompt") @dataclass class PromptParameters: """统一提示词参数系统""" - + # 基础参数 chat_id: str = "" is_group_chat: bool = False @@ -34,7 +34,7 @@ class PromptParameters: reply_to: str = "" extra_info: str = "" prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u" - + # 功能开关 enable_tool: bool = True enable_memory: bool = True @@ -42,20 +42,20 @@ class PromptParameters: enable_relation: bool = True enable_cross_context: bool = True enable_knowledge: bool = True - + # 性能控制 max_context_messages: int = 50 - + # 调试选项 debug_mode: bool = False - + # 聊天历史和上下文 chat_target_info: Optional[Dict[str, Any]] = None message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list) message_list_before_short: List[Dict[str, Any]] = field(default_factory=list) chat_talking_prompt_short: str = "" target_user_info: Optional[Dict[str, Any]] = None - + # 已构建的内容块 expression_habits_block: str = "" relation_info_block: str = "" @@ -63,7 +63,7 @@ class PromptParameters: tool_info_block: str = "" knowledge_prompt: str = "" cross_context_block: str = "" - + # 其他内容块 keywords_reaction_prompt: str = "" extra_info_block: str = "" @@ -75,11 +75,13 @@ class PromptParameters: reply_target_block: str = "" mood_prompt: str = "" action_descriptions: str = "" - + # 可用动作信息 available_actions: Optional[Dict[str, Any]] = None - read_mark: float = 0.0 - + + # 动态生成的聊天场景提示 + chat_scene: str = "" + def validate(self) -> List[str]: """参数验证""" errors = [] @@ -94,22 +96,22 @@ class PromptParameters: class PromptContext: """提示词上下文管理器""" - + def __init__(self): self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} self._current_context_var = contextvars.ContextVar("current_context", default=None) self._context_lock = asyncio.Lock() - + @property def _current_context(self) -> Optional[str]: """获取当前协程的上下文ID""" return self._current_context_var.get() - + @_current_context.setter def _current_context(self, value: Optional[str]): """设置当前协程的上下文ID""" self._current_context_var.set(value) # type: ignore - + @asynccontextmanager async def async_scope(self, context_id: Optional[str] = None): """创建一个异步的临时提示模板作用域""" @@ -124,13 +126,13 @@ class PromptContext: except asyncio.TimeoutError: logger.warning(f"获取上下文锁超时,context_id: {context_id}") context_id = None - + previous_context = self._current_context token = self._current_context_var.set(context_id) if context_id else None else: previous_context = self._current_context token = None - + try: yield self finally: @@ -143,7 +145,7 @@ class PromptContext: self._current_context = previous_context except Exception: ... - + async def get_prompt_async(self, name: str) -> Optional["Prompt"]: """异步获取当前作用域中的提示模板""" async with self._context_lock: @@ -156,7 +158,7 @@ class PromptContext: ): return self._context_prompts[current_context][name] return None - + async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None: """异步注册提示模板到指定作用域""" async with self._context_lock: @@ -167,59 +169,55 @@ class PromptContext: class PromptManager: """统一提示词管理器""" - + def __init__(self): self._prompts = {} self._counter = 0 self._context = PromptContext() self._lock = asyncio.Lock() - + @asynccontextmanager async def async_message_scope(self, message_id: Optional[str] = None): """为消息处理创建异步临时作用域""" async with self._context.async_scope(message_id): yield self - + async def get_prompt_async(self, name: str) -> "Prompt": """异步获取提示模板""" context_prompt = await self._context.get_prompt_async(name) if context_prompt is not None: logger.debug(f"从上下文中获取提示词: {name} {context_prompt}") return context_prompt - + async with self._lock: if name not in self._prompts: raise KeyError(f"Prompt '{name}' not found") return self._prompts[name] - + def generate_name(self, template: str) -> str: """为未命名的prompt生成名称""" self._counter += 1 return f"prompt_{self._counter}" - + def register(self, prompt: "Prompt") -> None: """注册一个prompt""" if not prompt.name: prompt.name = self.generate_name(prompt.template) self._prompts[prompt.name] = prompt - + def add_prompt(self, name: str, fstr: str) -> "Prompt": """添加新提示模板""" prompt = Prompt(fstr, name=name) if prompt.name: self._prompts[prompt.name] = prompt return prompt - + async def format_prompt(self, name: str, **kwargs) -> str: """格式化提示模板""" prompt = await self.get_prompt_async(name) result = prompt.format(**kwargs) return result - @property - def context(self): - return self._context - # 全局单例 global_prompt_manager = PromptManager() @@ -230,21 +228,21 @@ class Prompt: 统一提示词类 - 合并模板管理和智能构建功能 真正的Prompt类,支持模板管理和智能上下文构建 """ - + # 临时标记,作为类常量 _TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__" _TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__" - + def __init__( self, template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, - should_register: bool = True + should_register: bool = True, ): """ 初始化统一提示词 - + Args: template: 提示词模板字符串 name: 提示词名称 @@ -256,14 +254,14 @@ class Prompt: self.parameters = parameters or PromptParameters() self.args = self._parse_template_args(template) self._formatted_result = "" - + # 预处理模板中的转义花括号 self._processed_template = self._process_escaped_braces(template) - + # 自动注册 - if should_register and not global_prompt_manager.context._current_context: + if should_register and not global_prompt_manager._context._current_context: global_prompt_manager.register(self) - + @staticmethod def _process_escaped_braces(template) -> str: """处理模板中的转义花括号""" @@ -271,14 +269,14 @@ class Prompt: template = "\n".join(str(item) for item in template) elif not isinstance(template, str): template = str(template) - + return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE) - + @staticmethod def _restore_escaped_braces(template: str) -> str: """将临时标记还原为实际的花括号字符""" return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}") - + def _parse_template_args(self, template: str) -> List[str]: """解析模板参数""" template_args = [] @@ -288,11 +286,11 @@ class Prompt: if expr and expr not in template_args: template_args.append(expr) return template_args - + async def build(self) -> str: """ 构建完整的提示词,包含智能上下文 - + Returns: str: 构建完成的提示词文本 """ @@ -301,38 +299,38 @@ class Prompt: if errors: logger.error(f"参数验证失败: {', '.join(errors)}") raise ValueError(f"参数验证失败: {', '.join(errors)}") - + start_time = time.time() try: # 构建上下文数据 context_data = await self._build_context_data() - + # 格式化模板 result = await self._format_with_context(context_data) - + total_time = time.time() - start_time logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s") - + self._formatted_result = result return result - + except asyncio.TimeoutError as e: logger.error(f"构建Prompt超时: {e}") raise TimeoutError(f"构建Prompt超时: {e}") from e except Exception as e: logger.error(f"构建Prompt失败: {e}") raise RuntimeError(f"构建Prompt失败: {e}") from e - + async def _build_context_data(self) -> Dict[str, Any]: """构建智能上下文数据""" # 并行执行所有构建任务 start_time = time.time() - + try: # 准备构建任务 tasks = [] task_names = [] - + # 初始化预构建参数 pre_built_params = {} if self.parameters.expression_habits_block: @@ -347,46 +345,46 @@ class Prompt: pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt if self.parameters.cross_context_block: pre_built_params["cross_context_block"] = self.parameters.cross_context_block - + # 根据参数确定要构建的项 if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"): tasks.append(self._build_expression_habits()) task_names.append("expression_habits") - + if self.parameters.enable_memory and not pre_built_params.get("memory_block"): tasks.append(self._build_memory_block()) task_names.append("memory_block") - + if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"): tasks.append(self._build_relation_info()) task_names.append("relation_info") - + if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"): tasks.append(self._build_tool_info()) task_names.append("tool_info") - + if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"): tasks.append(self._build_knowledge_info()) task_names.append("knowledge_info") - + if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"): tasks.append(self._build_cross_context()) task_names.append("cross_context") - + # 性能优化 - base_timeout = 20.0 + base_timeout = 10.0 task_timeout = 2.0 timeout_seconds = min( max(base_timeout, len(tasks) * task_timeout), 30.0, ) - + max_concurrent_tasks = 5 if len(tasks) > max_concurrent_tasks: results = [] for i in range(0, len(tasks), max_concurrent_tasks): batch_tasks = tasks[i : i + max_concurrent_tasks] - + batch_results = await asyncio.wait_for( asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds ) @@ -395,225 +393,181 @@ class Prompt: results = await asyncio.wait_for( asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds ) - + # 处理结果 context_data = {} for i, result in enumerate(results): task_name = task_names[i] if i < len(task_names) else f"task_{i}" - + if isinstance(result, Exception): logger.error(f"构建任务{task_name}失败: {str(result)}") elif isinstance(result, dict): context_data.update(result) - + # 添加预构建的参数 for key, value in pre_built_params.items(): if value: context_data[key] = value - + except asyncio.TimeoutError: logger.error(f"构建超时 ({timeout_seconds}s)") context_data = {} for key, value in pre_built_params.items(): if value: context_data[key] = value - + # 构建聊天历史 if self.parameters.prompt_mode == "s4u": await self._build_s4u_chat_context(context_data) else: await self._build_normal_chat_context(context_data) - + # 补充基础信息 - context_data.update({ - "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt, - "extra_info_block": self.parameters.extra_info_block, - "time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}", - "identity": self.parameters.identity_block, - "schedule_block": self.parameters.schedule_block, - "moderation_prompt": self.parameters.moderation_prompt_block, - "reply_target_block": self.parameters.reply_target_block, - "mood_state": self.parameters.mood_prompt, - "action_descriptions": self.parameters.action_descriptions, - }) - + context_data.update( + { + "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt, + "extra_info_block": self.parameters.extra_info_block, + "time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}", + "identity": self.parameters.identity_block, + "schedule_block": self.parameters.schedule_block, + "moderation_prompt": self.parameters.moderation_prompt_block, + "reply_target_block": self.parameters.reply_target_block, + "mood_state": self.parameters.mood_prompt, + "action_descriptions": self.parameters.action_descriptions, + } + ) + total_time = time.time() - start_time logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s") - + return context_data - + async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None: """构建S4U模式的聊天上下文""" if not self.parameters.message_list_before_now_long: return - - core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts( + + read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts( self.parameters.message_list_before_now_long, self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "", self.parameters.sender, - read_mark=self.parameters.read_mark, + self.parameters.chat_id, ) - - context_data["core_dialogue_prompt"] = core_dialogue - context_data["background_dialogue_prompt"] = background_dialogue - + + context_data["read_history_prompt"] = read_history_prompt + context_data["unread_history_prompt"] = unread_history_prompt + async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None: """构建normal模式的聊天上下文""" if not self.parameters.chat_talking_prompt_short: return - + context_data["chat_info"] = f"""群里的聊天内容: {self.parameters.chat_talking_prompt_short}""" - - @staticmethod + async def _build_s4u_chat_history_prompts( - message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, read_mark: float = 0.0 + self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str ) -> Tuple[str, str]: - """构建S4U风格的分离对话prompt""" - # 实现逻辑与原有SmartPromptBuilder相同 - core_dialogue_list = [] - bot_id = str(global_config.bot.qq_account) - - for msg_dict in message_list_before_now: - try: - msg_user_id = str(msg_dict.get("user_id")) - reply_to = msg_dict.get("reply_to", "") - platform, reply_to_user_id = Prompt.parse_reply_target(reply_to) - if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id: - core_dialogue_list.append(msg_dict) - except Exception as e: - logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}") - - # 构建背景对话 prompt - all_dialogue_prompt = "" - if message_list_before_now: - latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :] - all_dialogue_prompt_str = await build_readable_messages( - latest_25_msgs, - replace_bot_name=True, - timestamp_mode="normal", - truncate=True, - read_mark=read_mark, + """构建S4U风格的已读/未读历史消息prompt""" + try: + # 动态导入default_generator以避免循环导入 + from src.plugin_system.apis.generator_api import get_replyer + + # 创建临时生成器实例来使用其方法 + temp_generator = get_replyer(None, chat_id, request_type="prompt_building") + return await temp_generator.build_s4u_chat_history_prompts( + message_list_before_now, target_user_id, sender, chat_id ) - all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}" - - # 构建核心对话 prompt - core_dialogue_prompt = "" - if core_dialogue_list: - latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list - has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages) - - if not has_bot_message: - core_dialogue_prompt = "" - else: - core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] - - core_dialogue_prompt_str = await build_readable_messages( - core_dialogue_list, - replace_bot_name=True, - merge_messages=False, - timestamp_mode="normal_no_YMD", - read_mark=read_mark, - truncate=True, - show_actions=True, - ) - core_dialogue_prompt = f"""-------------------------------- -这是你和{sender}的对话,你们正在交流中: -{core_dialogue_prompt_str} --------------------------------- -""" - - return core_dialogue_prompt, all_dialogue_prompt - + except Exception as e: + logger.error(f"构建S4U历史消息prompt失败: {e}") + async def _build_expression_habits(self) -> Dict[str, Any]: """构建表达习惯""" - if not global_config.expression.enable_expression: + use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id) + if not use_expression: return {"expression_habits_block": ""} - + try: from src.chat.express.expression_selector import ExpressionSelector - + # 获取聊天历史用于表情选择 chat_history = "" if self.parameters.message_list_before_now_long: recent_messages = self.parameters.message_list_before_now_long[-10:] - chat_history = await build_readable_messages( - recent_messages, - replace_bot_name=True, - timestamp_mode="normal", - truncate=True + chat_history = build_readable_messages( + recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True ) - + # 创建表情选择器 - expression_selector = ExpressionSelector() - + expression_selector = ExpressionSelector(self.parameters.chat_id) + # 选择合适的表情 selected_expressions = await expression_selector.select_suitable_expressions_llm( + chat_history=chat_history, + current_message=self.parameters.target, + emotional_tone="neutral", + topic_type="general", ) - + # 构建表达习惯块 if selected_expressions: style_habits_str = "\n".join([f"- {expr}" for expr in selected_expressions]) expression_habits_block = f"- 你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}" else: expression_habits_block = "" - + return {"expression_habits_block": expression_habits_block} - + except Exception as e: logger.error(f"构建表达习惯失败: {e}") return {"expression_habits_block": ""} - + async def _build_memory_block(self) -> Dict[str, Any]: """构建记忆块""" if not global_config.memory.enable_memory: return {"memory_block": ""} - + try: from src.chat.memory_system.memory_activator import MemoryActivator from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory - + # 获取聊天历史 chat_history = "" if self.parameters.message_list_before_now_long: recent_messages = self.parameters.message_list_before_now_long[-20:] - chat_history = await build_readable_messages( - recent_messages, - replace_bot_name=True, - timestamp_mode="normal", - truncate=True + chat_history = build_readable_messages( + recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True ) - + # 激活长期记忆 memory_activator = MemoryActivator() running_memories = await memory_activator.activate_memory_with_chat_history( - target_message=self.parameters.target, - chat_history_prompt=chat_history + target_message=self.parameters.target, chat_history_prompt=chat_history ) - + # 获取即时记忆 async_memory_wrapper = get_async_instant_memory(self.parameters.chat_id) instant_memory = await async_memory_wrapper.get_memory_with_fallback(self.parameters.target) - + # 构建记忆块 memory_parts = [] - + if running_memories: memory_parts.append("以下是当前在聊天中,你回忆起的记忆:") for memory in running_memories: memory_parts.append(f"- {memory['content']}") - + if instant_memory: memory_parts.append(f"- {instant_memory}") - + memory_block = "\n".join(memory_parts) if memory_parts else "" - + return {"memory_block": memory_block} - + except Exception as e: logger.error(f"构建记忆块失败: {e}") return {"memory_block": ""} - + async def _build_relation_info(self) -> Dict[str, Any]: """构建关系信息""" try: @@ -622,106 +576,104 @@ class Prompt: except Exception as e: logger.error(f"构建关系信息失败: {e}") return {"relation_info_block": ""} - + async def _build_tool_info(self) -> Dict[str, Any]: """构建工具信息""" if not global_config.tool.enable_tool: return {"tool_info_block": ""} - + try: from src.plugin_system.core.tool_use import ToolExecutor - + # 获取聊天历史 chat_history = "" if self.parameters.message_list_before_now_long: recent_messages = self.parameters.message_list_before_now_long[-15:] - chat_history = await build_readable_messages( - recent_messages, - replace_bot_name=True, - timestamp_mode="normal", - truncate=True + chat_history = build_readable_messages( + recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True ) - + # 创建工具执行器 tool_executor = ToolExecutor(chat_id=self.parameters.chat_id) - + # 执行工具获取信息 tool_results, _, _ = await tool_executor.execute_from_chat_message( sender=self.parameters.sender, target_message=self.parameters.target, chat_history=chat_history, - return_details=False + return_details=False, ) - + # 构建工具信息块 if tool_results: - tool_info_parts = ["## 工具信息","以下是你通过工具获取到的实时信息:"] + tool_info_parts = ["## 工具信息", "以下是你通过工具获取到的实时信息:"] for tool_result in tool_results: tool_name = tool_result.get("tool_name", "unknown") content = tool_result.get("content", "") result_type = tool_result.get("type", "tool_result") - + tool_info_parts.append(f"- 【{tool_name}】{result_type}: {content}") - + tool_info_parts.append("以上是你获取到的实时信息,请在回复时参考这些信息。") tool_info_block = "\n".join(tool_info_parts) else: tool_info_block = "" - + return {"tool_info_block": tool_info_block} - + except Exception as e: logger.error(f"构建工具信息失败: {e}") return {"tool_info_block": ""} - + async def _build_knowledge_info(self) -> Dict[str, Any]: """构建知识信息""" if not global_config.lpmm_knowledge.enable: return {"knowledge_prompt": ""} - + try: - from src.chat.knowledge.knowledge_lib import qa_manager - + from src.chat.knowledge.knowledge_lib import QAManager + # 获取问题文本(当前消息) question = self.parameters.target or "" if not question: return {"knowledge_prompt": ""} - - # 检查QA管理器是否已成功初始化 - if not qa_manager: - logger.warning("QA管理器未初始化 (可能lpmm_knowledge被禁用),跳过知识库搜索。") - return {"knowledge_prompt": ""} - + + # 创建QA管理器 + qa_manager = QAManager() + # 搜索相关知识 knowledge_results = await qa_manager.get_knowledge( - question=question + question=question, chat_id=self.parameters.chat_id, max_results=5, min_similarity=0.5 ) - + # 构建知识块 if knowledge_results and knowledge_results.get("knowledge_items"): - knowledge_parts = ["## 知识库信息","以下是与你当前对话相关的知识信息:"] - + knowledge_parts = ["## 知识库信息", "以下是与你当前对话相关的知识信息:"] + for item in knowledge_results["knowledge_items"]: content = item.get("content", "") source = item.get("source", "") relevance = item.get("relevance", 0.0) - + if content: - knowledge_parts.append(f"- [相关度: {relevance}] {content}") - - if summary := knowledge_results.get("summary"): - knowledge_parts.append(f"\n知识总结: {summary}") - + if source: + knowledge_parts.append(f"- [{relevance:.2f}] {content} (来源: {source})") + else: + knowledge_parts.append(f"- [{relevance:.2f}] {content}") + + if knowledge_results.get("summary"): + knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}") + knowledge_prompt = "\n".join(knowledge_parts) else: knowledge_prompt = "" - + return {"knowledge_prompt": knowledge_prompt} - + except Exception as e: logger.error(f"构建知识信息失败: {e}") return {"knowledge_prompt": ""} - + async def _build_cross_context(self) -> Dict[str, Any]: """构建跨群上下文""" try: @@ -732,7 +684,7 @@ class Prompt: except Exception as e: logger.error(f"构建跨群上下文失败: {e}") return {"cross_context_block": ""} - + async def _format_with_context(self, context_data: Dict[str, Any]) -> str: """使用上下文数据格式化模板""" if self.parameters.prompt_mode == "s4u": @@ -741,9 +693,9 @@ class Prompt: params = self._prepare_normal_params(context_data) else: params = self._prepare_default_params(context_data) - + return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params) - + def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: """准备S4U模式的参数""" return { @@ -759,17 +711,19 @@ class Prompt: "action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""), "sender_name": self.parameters.sender or "未知用户", "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), - "background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""), + "read_history_prompt": context_data.get("read_history_prompt", ""), + "unread_history_prompt": context_data.get("unread_history_prompt", ""), "time_block": context_data.get("time_block", ""), - "core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""), "reply_target_block": context_data.get("reply_target_block", ""), "reply_style": global_config.personality.reply_style, - "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), + "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt + or context_data.get("keywords_reaction_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), - "safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), - "chat_context_type": "群聊" if self.parameters.is_group_chat else "私聊", + "safety_guidelines_block": self.parameters.safety_guidelines_block + or context_data.get("safety_guidelines_block", ""), + "chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", } - + def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: """准备Normal模式的参数""" return { @@ -789,11 +743,14 @@ class Prompt: "reply_target_block": context_data.get("reply_target_block", ""), "config_expression_style": global_config.personality.reply_style, "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), - "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), + "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt + or context_data.get("keywords_reaction_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), - "safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), + "safety_guidelines_block": self.parameters.safety_guidelines_block + or context_data.get("safety_guidelines_block", ""), + "chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。", } - + def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: """准备默认模式的参数""" return { @@ -809,11 +766,13 @@ class Prompt: "reason": "", "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), "reply_style": global_config.personality.reply_style, - "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), + "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt + or context_data.get("keywords_reaction_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), - "safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), + "safety_guidelines_block": self.parameters.safety_guidelines_block + or context_data.get("safety_guidelines_block", ""), } - + def format(self, *args, **kwargs) -> str: """格式化模板,支持位置参数和关键字参数""" try: @@ -826,21 +785,21 @@ class Prompt: processed_template = self._processed_template.format(**formatted_args) else: processed_template = self._processed_template - + # 再用关键字参数格式化 if kwargs: processed_template = processed_template.format(**kwargs) - + # 将临时标记还原为实际的花括号 result = self._restore_escaped_braces(processed_template) return result except (IndexError, KeyError) as e: raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e - + def __str__(self) -> str: """返回格式化后的结果或原始模板""" return self._formatted_result if self._formatted_result else self.template - + def __repr__(self) -> str: """返回提示词的表示形式""" return f"Prompt(template='{self.template}', name='{self.name}')" @@ -912,9 +871,7 @@ class Prompt: return await relationship_fetcher.build_relation_info(person_id, points_num=5) @staticmethod - async def build_cross_context( - chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]] - ) -> str: + async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str: """ 构建跨群聊上下文 - 统一实现 @@ -930,7 +887,7 @@ class Prompt: return "" from src.plugin_system.apis import cross_context_api - + other_chat_raw_ids = cross_context_api.get_context_groups(chat_id) if not other_chat_raw_ids: return "" @@ -969,7 +926,7 @@ class Prompt: person_info_manager = get_person_info_manager() person_id = person_info_manager.get_person_id_by_person_name(sender) if person_id: - user_id = person_info_manager.get_value(person_id, "user_id") + user_id = person_info_manager.get_value_sync(person_id, "user_id") return str(user_id) if user_id else "" return "" @@ -977,10 +934,7 @@ class Prompt: # 工厂函数 def create_prompt( - template: str, - name: Optional[str] = None, - parameters: Optional[PromptParameters] = None, - **kwargs + template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs ) -> Prompt: """快速创建Prompt实例的工厂函数""" if parameters is None: @@ -989,14 +943,10 @@ def create_prompt( async def create_prompt_async( - template: str, - name: Optional[str] = None, - parameters: Optional[PromptParameters] = None, - **kwargs + template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs ) -> Prompt: """异步创建Prompt实例""" prompt = create_prompt(template, name, parameters, **kwargs) - if global_prompt_manager.context._current_context: - await global_prompt_manager.context.register_async(prompt) + if global_prompt_manager._context._current_context: + await global_prompt_manager._context.register_async(prompt) return prompt - diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 746b13e63..c2e4814f8 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -7,7 +7,7 @@ import numpy as np from collections import Counter from maim_message import UserInfo -from typing import Optional, Tuple, Dict, List, Any, Coroutine +from typing import Optional, Tuple, Dict, List, Any from src.common.logger import get_logger from src.common.message_repository import find_messages, count_messages @@ -332,17 +332,17 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese if global_config.response_splitter.enable and enable_splitter: logger.info(f"回复分割器已启用,模式: {global_config.response_splitter.split_mode}。") - + split_mode = global_config.response_splitter.split_mode - + if split_mode == "llm" and "[SPLIT]" in cleaned_text: logger.debug("检测到 [SPLIT] 标记,使用 LLM 自定义分割。") split_sentences_raw = cleaned_text.split("[SPLIT]") split_sentences = [s.strip() for s in split_sentences_raw if s.strip()] else: if split_mode == "llm": - logger.debug("未检测到 [SPLIT] 标记,回退到基于标点的传统模式进行分割。") - split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text) + logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。") + split_sentences = [cleaned_text] else: # mode == "punctuation" logger.debug("使用基于标点的传统模式进行分割。") split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text) @@ -352,6 +352,8 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese sentences = [] for sentence in split_sentences: + # 清除开头可能存在的空行 + sentence = sentence.lstrip("\n").rstrip() if global_config.chinese_typo.enable and enable_chinese_typo: typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence) sentences.append(typoed_text) @@ -540,8 +542,7 @@ def get_western_ratio(paragraph): return western_count / len(alnum_chars) -def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int] | tuple[ - Coroutine[Any, Any, int], int]: +def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]: """计算两个时间点之间的消息数量和文本总长度 Args: @@ -619,7 +620,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal" return time.strftime("%H:%M:%S", time.localtime(timestamp)) -async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: +def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: """ 获取聊天类型(是否群聊)和私聊对象信息。 @@ -663,8 +664,7 @@ async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Di if person_id: # get_value is async, so await it directly person_info_manager = get_person_info_manager() - person_data = await person_info_manager.get_values(person_id, ["person_name"]) - person_name = person_data.get("person_name") + person_name = person_info_manager.get_value_sync(person_id, "person_name") target_info["person_id"] = person_id target_info["person_name"] = person_name @@ -695,25 +695,9 @@ def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]: """ result = [] used_ids = set() - len_i = len(messages) - if len_i > 100: - a = 10 - b = 99 - else: - a = 1 - b = 9 - for i, message in enumerate(messages): - # 生成唯一的简短ID - while True: - # 使用索引+随机数生成简短ID - random_suffix = random.randint(a, b) - message_id = f"m{i + 1}{random_suffix}" - - if message_id not in used_ids: - used_ids.add(message_id) - break - + # 使用简单的索引作为ID + message_id = f"m{i + 1}" result.append({"id": message_id, "message": message}) return result diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 2f72af32b..f6acb1a7d 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -1,145 +1,545 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""纯 inkfox 视频关键帧分析工具 - -仅依赖 `inkfox.video` 提供的 Rust 扩展能力: - - extract_keyframes_from_video - - get_system_info - -功能: - - 关键帧提取 (base64, timestamp) - - 批量 / 逐帧 LLM 描述 - - 自动模式 (<=3 帧批量,否则逐帧) +""" +视频分析器模块 - Rust优化版本 +集成了Rust视频关键帧提取模块,提供高性能的视频分析功能 +支持SIMD优化、多线程处理和智能关键帧检测 """ -from __future__ import annotations - import os -import io +import tempfile import asyncio import base64 -import tempfile -from pathlib import Path -from typing import List, Tuple, Optional, Dict, Any import hashlib import time - +import numpy as np from PIL import Image +from pathlib import Path +from typing import List, Tuple, Optional, Dict +import io -from src.common.logger import get_logger -from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore -from sqlalchemy import select, update, insert # type: ignore -from sqlalchemy import exc as sa_exc # type: ignore - -# 简易并发控制:同一 hash 只处理一次 -_video_locks: Dict[str, asyncio.Lock] = {} -_locks_guard = asyncio.Lock() +from src.config.config import global_config, model_config +from src.common.logger import get_logger +from src.common.database.sqlalchemy_models import get_db_session, Videos logger = get_logger("utils_video") -from inkfox import video +# Rust模块可用性检测 +RUST_VIDEO_AVAILABLE = False +try: + import rust_video + + RUST_VIDEO_AVAILABLE = True + logger.info("✅ Rust 视频处理模块加载成功") +except ImportError as e: + logger.warning(f"⚠️ Rust 视频处理模块加载失败: {e}") + logger.warning("⚠️ 视频识别功能将自动禁用") +except Exception as e: + logger.error(f"❌ 加载Rust模块时发生错误: {e}") + RUST_VIDEO_AVAILABLE = False + +# 全局正在处理的视频哈希集合,用于防止重复处理 +processing_videos = set() +processing_lock = asyncio.Lock() +# 为每个视频hash创建独立的锁和事件 +video_locks = {} +video_events = {} +video_lock_manager = asyncio.Lock() class VideoAnalyzer: - """基于 inkfox 的视频关键帧 + LLM 描述分析器""" + """优化的视频分析器类""" - def __init__(self) -> None: - cfg = getattr(global_config, "video_analysis", object()) - self.max_frames: int = getattr(cfg, "max_frames", 20) - self.frame_quality: int = getattr(cfg, "frame_quality", 85) - self.max_image_size: int = getattr(cfg, "max_image_size", 600) - self.enable_frame_timing: bool = getattr(cfg, "enable_frame_timing", True) - self.use_simd: bool = getattr(cfg, "rust_use_simd", True) - self.threads: int = getattr(cfg, "rust_threads", 0) - self.ffmpeg_path: str = getattr(cfg, "ffmpeg_path", "ffmpeg") - self.analysis_mode: str = getattr(cfg, "analysis_mode", "auto") - self.frame_analysis_delay: float = 0.3 - - # 人格与提示模板 + def __init__(self): + """初始化视频分析器""" + # 检查是否有任何可用的视频处理实现 + opencv_available = False try: - persona = global_config.personality - self.personality_core = getattr(persona, "personality_core", "是一个积极向上的女大学生") - self.personality_side = getattr(persona, "personality_side", "用一句话或几句话描述人格的侧面特点") - except Exception: # pragma: no cover - self.personality_core = "是一个积极向上的女大学生" - self.personality_side = "用一句话或几句话描述人格的侧面特点" + import cv2 - self.batch_analysis_prompt = getattr( - cfg, - "batch_analysis_prompt", - """请以第一人称视角阅读这些按时间顺序提取的关键帧。\n核心:{personality_core}\n人格:{personality_side}\n请详细描述视频(主题/人物与场景/动作与时间线/视觉风格/情绪氛围/特殊元素)。""", - ) + opencv_available = True + except ImportError: + pass + if not RUST_VIDEO_AVAILABLE and not opencv_available: + logger.error("❌ 没有可用的视频处理实现,视频分析器将被禁用") + self.disabled = True + return + elif not RUST_VIDEO_AVAILABLE: + logger.warning("⚠️ Rust视频处理模块不可用,将使用Python降级实现") + elif not opencv_available: + logger.warning("⚠️ OpenCV不可用,仅支持Rust关键帧模式") + + self.disabled = False + + # 使用专用的视频分析配置 try: self.video_llm = LLMRequest( model_set=model_config.model_task_config.video_analysis, request_type="video_analysis" ) - except Exception: + logger.debug("✅ 使用video_analysis模型配置") + except (AttributeError, KeyError) as e: + # 如果video_analysis不存在,使用vlm配置 self.video_llm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="vlm") + logger.warning(f"video_analysis配置不可用({e}),回退使用vlm配置") - self._log_system() + # 从配置文件读取参数,如果配置不存在则使用默认值 + config = global_config.video_analysis - # ---- 系统信息 ---- - def _log_system(self) -> None: + # 使用 getattr 统一获取配置参数,如果配置不存在则使用默认值 + self.max_frames = getattr(config, "max_frames", 6) + self.frame_quality = getattr(config, "frame_quality", 85) + self.max_image_size = getattr(config, "max_image_size", 600) + self.enable_frame_timing = getattr(config, "enable_frame_timing", True) + + # Rust模块相关配置 + self.rust_keyframe_threshold = getattr(config, "rust_keyframe_threshold", 2.0) + self.rust_use_simd = getattr(config, "rust_use_simd", True) + self.rust_block_size = getattr(config, "rust_block_size", 8192) + self.rust_threads = getattr(config, "rust_threads", 0) + self.ffmpeg_path = getattr(config, "ffmpeg_path", "ffmpeg") + + # 从personality配置中获取人格信息 try: - info = video.get_system_info() # type: ignore[attr-defined] - logger.info( - f"inkfox: threads={info.get('threads')} version={info.get('version')} simd={info.get('simd_supported')}" + personality_config = global_config.personality + self.personality_core = getattr(personality_config, "personality_core", "是一个积极向上的女大学生") + self.personality_side = getattr( + personality_config, "personality_side", "用一句话或几句话描述人格的侧面特点" ) - except Exception as e: # pragma: no cover - logger.debug(f"获取系统信息失败: {e}") + except AttributeError: + # 如果没有personality配置,使用默认值 + self.personality_core = "是一个积极向上的女大学生" + self.personality_side = "用一句话或几句话描述人格的侧面特点" - # ---- 关键帧提取 ---- - async def extract_keyframes(self, video_path: str) -> List[Tuple[str, float]]: - """提取关键帧并返回 (base64, timestamp_seconds) 列表""" - with tempfile.TemporaryDirectory() as tmp: - result = video.extract_keyframes_from_video( # type: ignore[attr-defined] - video_path=video_path, - output_dir=tmp, - max_keyframes=self.max_frames * 2, # 先多抓一点再截断 - max_save=self.max_frames, + self.batch_analysis_prompt = getattr( + config, + "batch_analysis_prompt", + """请以第一人称的视角来观看这一个视频,你看到的这些是从视频中按时间顺序提取的关键帧。 + +你的核心人设是:{personality_core}。 +你的人格细节是:{personality_side}。 + +请提供详细的视频内容描述,涵盖以下方面: +1. 视频的整体内容和主题 +2. 主要人物、对象和场景描述 +3. 动作、情节和时间线发展 +4. 视觉风格和艺术特点 +5. 整体氛围和情感表达 +6. 任何特殊的视觉效果或文字内容 + +请用中文回答,结果要详细准确。""", + ) + + # 新增的线程池配置 + self.use_multiprocessing = getattr(config, "use_multiprocessing", True) + self.max_workers = getattr(config, "max_workers", 2) + self.frame_extraction_mode = getattr(config, "frame_extraction_mode", "fixed_number") + self.frame_interval_seconds = getattr(config, "frame_interval_seconds", 2.0) + + # 将配置文件中的模式映射到内部使用的模式名称 + config_mode = getattr(config, "analysis_mode", "auto") + if config_mode == "batch_frames": + self.analysis_mode = "batch" + elif config_mode == "frame_by_frame": + self.analysis_mode = "sequential" + elif config_mode == "auto": + self.analysis_mode = "auto" + else: + logger.warning(f"无效的分析模式: {config_mode},使用默认的auto模式") + self.analysis_mode = "auto" + + self.frame_analysis_delay = 0.3 # API调用间隔(秒) + self.frame_interval = 1.0 # 抽帧时间间隔(秒) + self.batch_size = 3 # 批处理时每批处理的帧数 + self.timeout = 60.0 # 分析超时时间(秒) + + if config: + logger.debug("✅ 从配置文件读取视频分析参数") + else: + logger.warning("配置文件中缺少video_analysis配置,使用默认值") + + # 系统提示词 + self.system_prompt = "你是一个专业的视频内容分析助手。请仔细观察用户提供的视频关键帧,详细描述视频内容。" + + logger.debug(f"✅ 视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}") + + # 获取Rust模块系统信息 + self._log_system_info() + + def _log_system_info(self): + """记录系统信息""" + if not RUST_VIDEO_AVAILABLE: + logger.info("⚠️ Rust模块不可用,跳过系统信息获取") + return + + try: + system_info = rust_video.get_system_info() + logger.debug(f"🔧 系统信息: 线程数={system_info.get('threads', '未知')}") + + # 记录CPU特性 + features = [] + if system_info.get("avx2_supported"): + features.append("AVX2") + if system_info.get("sse2_supported"): + features.append("SSE2") + if system_info.get("simd_supported"): + features.append("SIMD") + + if features: + logger.debug(f"🚀 CPU特性: {', '.join(features)}") + else: + logger.debug("⚠️ 未检测到SIMD支持") + + logger.debug(f"📦 Rust模块版本: {system_info.get('version', '未知')}") + + except Exception as e: + logger.warning(f"获取系统信息失败: {e}") + + def _calculate_video_hash(self, video_data: bytes) -> str: + """计算视频文件的hash值""" + hash_obj = hashlib.sha256() + hash_obj.update(video_data) + return hash_obj.hexdigest() + + def _check_video_exists(self, video_hash: str) -> Optional[Videos]: + """检查视频是否已经分析过""" + try: + with get_db_session() as session: + # 明确刷新会话以确保看到其他事务的最新提交 + session.expire_all() + return session.query(Videos).filter(Videos.video_hash == video_hash).first() + except Exception as e: + logger.warning(f"检查视频是否存在时出错: {e}") + return None + + def _store_video_result( + self, video_hash: str, description: str, metadata: Optional[Dict] = None + ) -> Optional[Videos]: + """存储视频分析结果到数据库""" + # 检查描述是否为错误信息,如果是则不保存 + if description.startswith("❌"): + logger.warning(f"⚠️ 检测到错误信息,不保存到数据库: {description[:50]}...") + return None + + try: + with get_db_session() as session: + # 只根据video_hash查找 + existing_video = session.query(Videos).filter(Videos.video_hash == video_hash).first() + + if existing_video: + # 如果已存在,更新描述和计数 + existing_video.description = description + existing_video.count += 1 + existing_video.timestamp = time.time() + if metadata: + existing_video.duration = metadata.get("duration") + existing_video.frame_count = metadata.get("frame_count") + existing_video.fps = metadata.get("fps") + existing_video.resolution = metadata.get("resolution") + existing_video.file_size = metadata.get("file_size") + session.commit() + session.refresh(existing_video) + logger.info(f"✅ 更新已存在的视频记录,hash: {video_hash[:16]}..., count: {existing_video.count}") + return existing_video + else: + video_record = Videos( + video_hash=video_hash, description=description, timestamp=time.time(), count=1 + ) + if metadata: + video_record.duration = metadata.get("duration") + video_record.frame_count = metadata.get("frame_count") + video_record.fps = metadata.get("fps") + video_record.resolution = metadata.get("resolution") + video_record.file_size = metadata.get("file_size") + + session.add(video_record) + session.commit() + session.refresh(video_record) + logger.info(f"✅ 新视频分析结果已保存到数据库,hash: {video_hash[:16]}...") + return video_record + except Exception as e: + logger.error(f"❌ 存储视频分析结果时出错: {e}") + return None + + def set_analysis_mode(self, mode: str): + """设置分析模式""" + if mode in ["batch", "sequential", "auto"]: + self.analysis_mode = mode + # logger.info(f"分析模式已设置为: {mode}") + else: + logger.warning(f"无效的分析模式: {mode}") + + async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]: + """提取视频帧 - 智能选择最佳实现""" + # 检查是否应该使用Rust实现 + if RUST_VIDEO_AVAILABLE and self.frame_extraction_mode == "keyframe": + # 优先尝试Rust关键帧提取 + try: + return await self._extract_frames_rust_advanced(video_path) + except Exception as e: + logger.warning(f"Rust高级接口失败: {e},尝试基础接口") + try: + return await self._extract_frames_rust(video_path) + except Exception as e2: + logger.warning(f"Rust基础接口也失败: {e2},降级到Python实现") + return await self._extract_frames_python_fallback(video_path) + else: + # 使用Python实现(支持time_interval和fixed_number模式) + if not RUST_VIDEO_AVAILABLE: + logger.info("🔄 Rust模块不可用,使用Python抽帧实现") + else: + logger.info(f"🔄 抽帧模式为 {self.frame_extraction_mode},使用Python抽帧实现") + return await self._extract_frames_python_fallback(video_path) + + async def _extract_frames_rust_advanced(self, video_path: str) -> List[Tuple[str, float]]: + """使用 Rust 高级接口的帧提取""" + try: + logger.info("🔄 使用 Rust 高级接口提取关键帧...") + + # 创建 Rust 视频处理器,使用配置参数 + extractor = rust_video.VideoKeyframeExtractor( ffmpeg_path=self.ffmpeg_path, - use_simd=self.use_simd, - threads=self.threads, - verbose=False, + threads=self.rust_threads, + verbose=False, # 使用固定值,不需要配置 ) - files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames] - total_ms = getattr(result, "total_time_ms", 0) - frames: List[Tuple[str, float]] = [] - for i, f in enumerate(files): - img = Image.open(f).convert("RGB") - if max(img.size) > self.max_image_size: - scale = self.max_image_size / max(img.size) - img = img.resize((int(img.width * scale), int(img.height * scale)), Image.Resampling.LANCZOS) - buf = io.BytesIO() - img.save(buf, format="JPEG", quality=self.frame_quality) - b64 = base64.b64encode(buf.getvalue()).decode() - ts = (i / max(1, len(files) - 1)) * (total_ms / 1000.0) if total_ms else float(i) - frames.append((b64, ts)) + + # 1. 提取所有帧 + frames_data, width, height = extractor.extract_frames( + video_path=video_path, + max_frames=self.max_frames * 3, # 提取更多帧用于关键帧检测 + ) + + logger.info(f"提取到 {len(frames_data)} 帧,视频尺寸: {width}x{height}") + + # 2. 检测关键帧,使用配置参数 + keyframe_indices = extractor.extract_keyframes( + frames=frames_data, + threshold=self.rust_keyframe_threshold, + use_simd=self.rust_use_simd, + block_size=self.rust_block_size, + ) + + logger.info(f"检测到 {len(keyframe_indices)} 个关键帧") + + # 3. 转换选定的关键帧为 base64 + frames = [] + frame_count = 0 + + for idx in keyframe_indices[: self.max_frames]: + if idx < len(frames_data): + try: + frame = frames_data[idx] + frame_data = frame.get_data() + + # 将灰度数据转换为PIL图像 + frame_array = np.frombuffer(frame_data, dtype=np.uint8).reshape((frame.height, frame.width)) + pil_image = Image.fromarray( + frame_array, + mode="L", # 灰度模式 + ) + + # 转换为RGB模式以便保存为JPEG + pil_image = pil_image.convert("RGB") + + # 调整图像大小 + if max(pil_image.size) > self.max_image_size: + ratio = self.max_image_size / max(pil_image.size) + new_size = tuple(int(dim * ratio) for dim in pil_image.size) + pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS) + + # 转换为 base64 + buffer = io.BytesIO() + pil_image.save(buffer, format="JPEG", quality=self.frame_quality) + frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") + + # 估算时间戳 + estimated_timestamp = frame.frame_number * (1.0 / 30.0) # 假设30fps + + frames.append((frame_base64, estimated_timestamp)) + frame_count += 1 + + logger.debug( + f"处理关键帧 {frame_count}: 帧号 {frame.frame_number}, 时间 {estimated_timestamp:.2f}s" + ) + + except Exception as e: + logger.error(f"处理关键帧 {idx} 失败: {e}") + continue + + logger.info(f"✅ Rust 高级提取完成: {len(frames)} 关键帧") return frames - # ---- 批量分析 ---- - async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str: - from src.llm_models.payload_content.message import MessageBuilder, RoleType - from src.llm_models.utils_model import RequestType + except Exception as e: + logger.error(f"❌ Rust 高级帧提取失败: {e}") + # 回退到基础方法 + logger.info("回退到基础 Rust 方法") + return await self._extract_frames_rust(video_path) + + async def _extract_frames_rust(self, video_path: str) -> List[Tuple[str, float]]: + """使用 Rust 实现的帧提取""" + try: + logger.info("🔄 使用 Rust 模块提取关键帧...") + + # 创建临时输出目录 + with tempfile.TemporaryDirectory() as temp_dir: + # 使用便捷函数进行关键帧提取,使用配置参数 + result = rust_video.extract_keyframes_from_video( + video_path=video_path, + output_dir=temp_dir, + threshold=self.rust_keyframe_threshold, + max_frames=self.max_frames * 2, # 提取更多帧以便筛选 + max_save=self.max_frames, + ffmpeg_path=self.ffmpeg_path, + use_simd=self.rust_use_simd, + threads=self.rust_threads, + verbose=False, # 使用固定值,不需要配置 + ) + + logger.info( + f"Rust 处理完成: 总帧数 {result.total_frames}, 关键帧 {result.keyframes_extracted}, 处理速度 {result.processing_fps:.1f} FPS" + ) + + # 转换保存的关键帧为 base64 格式 + frames = [] + temp_dir_path = Path(temp_dir) + + # 获取所有保存的关键帧文件 + keyframe_files = sorted(temp_dir_path.glob("keyframe_*.jpg")) + + for i, keyframe_file in enumerate(keyframe_files): + if len(frames) >= self.max_frames: + break + + try: + # 读取关键帧文件 + with open(keyframe_file, "rb") as f: + image_data = f.read() + + # 转换为 PIL 图像并压缩 + pil_image = Image.open(io.BytesIO(image_data)) + + # 调整图像大小 + if max(pil_image.size) > self.max_image_size: + ratio = self.max_image_size / max(pil_image.size) + new_size = tuple(int(dim * ratio) for dim in pil_image.size) + pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS) + + # 转换为 base64 + buffer = io.BytesIO() + pil_image.save(buffer, format="JPEG", quality=self.frame_quality) + frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") + + # 估算时间戳(基于帧索引和总时长) + if result.total_frames > 0: + # 假设关键帧在时间上均匀分布 + estimated_timestamp = (i * result.total_time_ms / 1000.0) / result.keyframes_extracted + else: + estimated_timestamp = i * 1.0 # 默认每秒一帧 + + frames.append((frame_base64, estimated_timestamp)) + + logger.debug(f"处理关键帧 {i + 1}: 估算时间 {estimated_timestamp:.2f}s") + + except Exception as e: + logger.error(f"处理关键帧 {keyframe_file.name} 失败: {e}") + continue + + logger.info(f"✅ Rust 提取完成: {len(frames)} 关键帧") + return frames + + except Exception as e: + logger.error(f"❌ Rust 帧提取失败: {e}") + raise e + + async def _extract_frames_python_fallback(self, video_path: str) -> List[Tuple[str, float]]: + """Python降级抽帧实现 - 支持多种抽帧模式""" + try: + # 导入旧版本分析器 + from .utils_video_legacy import get_legacy_video_analyzer + + logger.info("🔄 使用Python降级抽帧实现...") + legacy_analyzer = get_legacy_video_analyzer() + + # 同步配置参数 + legacy_analyzer.max_frames = self.max_frames + legacy_analyzer.frame_quality = self.frame_quality + legacy_analyzer.max_image_size = self.max_image_size + legacy_analyzer.frame_extraction_mode = self.frame_extraction_mode + legacy_analyzer.frame_interval_seconds = self.frame_interval_seconds + legacy_analyzer.use_multiprocessing = self.use_multiprocessing + + # 使用旧版本的抽帧功能 + frames = await legacy_analyzer.extract_frames(video_path) + + logger.info(f"✅ Python降级抽帧完成: {len(frames)} 帧") + return frames + + except Exception as e: + logger.error(f"❌ Python降级抽帧失败: {e}") + return [] + + async def analyze_frames_batch(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: + """批量分析所有帧""" + logger.info(f"开始批量分析{len(frames)}帧") + + if not frames: + return "❌ 没有可分析的帧" + + # 构建提示词并格式化人格信息,要不然占位符的那个会爆炸 prompt = self.batch_analysis_prompt.format( personality_core=self.personality_core, personality_side=self.personality_side ) - if question: - prompt += f"\n用户关注: {question}" - desc = [ - (f"第{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i+1}帧") - for i, (_b, ts) in enumerate(frames) - ] - prompt += "\n帧列表: " + ", ".join(desc) - mb = MessageBuilder().set_role(RoleType.User).add_text_content(prompt) - for b64, _ in frames: - mb.add_image_content("jpeg", b64) - message = mb.build() + + if user_question: + prompt += f"\n\n用户问题: {user_question}" + + # 添加帧信息到提示词 + frame_info = [] + for i, (_frame_base64, timestamp) in enumerate(frames): + if self.enable_frame_timing: + frame_info.append(f"第{i + 1}帧 (时间: {timestamp:.2f}s)") + else: + frame_info.append(f"第{i + 1}帧") + + prompt += f"\n\n视频包含{len(frames)}帧图像:{', '.join(frame_info)}" + prompt += "\n\n请基于所有提供的帧图像进行综合分析,关注并描述视频的完整内容和故事发展。" + + try: + # 使用多图片分析 + response = await self._analyze_multiple_frames(frames, prompt) + logger.info("✅ 视频识别完成") + return response + + except Exception as e: + logger.error(f"❌ 视频识别失败: {e}") + raise e + + async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str: + """使用多图片分析方法""" + logger.info(f"开始构建包含{len(frames)}帧的分析请求") + + # 导入MessageBuilder用于构建多图片消息 + from src.llm_models.payload_content.message import MessageBuilder, RoleType + from src.llm_models.utils_model import RequestType + + # 构建包含多张图片的消息 + message_builder = MessageBuilder().set_role(RoleType.User).add_text_content(prompt) + + # 添加所有帧图像 + for _i, (frame_base64, _timestamp) in enumerate(frames): + message_builder.add_image_content("jpeg", frame_base64) + # logger.info(f"已添加第{i+1}帧到分析请求 (时间: {timestamp:.2f}s, 图片大小: {len(frame_base64)} chars)") + + message = message_builder.build() + # logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片") + + # 获取模型信息和客户端 model_info, api_provider, client = self.video_llm._select_model() - resp = await self.video_llm._execute_request( + # logger.info(f"使用模型: {model_info.name} 进行多帧分析") + + # 直接执行多图片请求 + api_response = await self.video_llm._execute_request( api_provider=api_provider, client=client, request_type=RequestType.RESPONSE, @@ -148,172 +548,365 @@ class VideoAnalyzer: temperature=None, max_tokens=None, ) - return resp.content or "❌ 未获得响应" - # ---- 逐帧分析 ---- - async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str: - results: List[str] = [] - for i, (b64, ts) in enumerate(frames): - prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "") - if question: - prompt += f"\n关注: {question}" + logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ") + return api_response.content or "❌ 未获得响应内容" + + async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: + """逐帧分析并汇总""" + logger.info(f"开始逐帧分析{len(frames)}帧") + + frame_analyses = [] + + for i, (frame_base64, timestamp) in enumerate(frames): try: - text, _ = await self.video_llm.generate_response_for_image( - prompt=prompt, image_base64=b64, image_format="jpeg" - ) - results.append(f"第{i+1}帧: {text}") - except Exception as e: # pragma: no cover - results.append(f"第{i+1}帧: 失败 {e}") - if i < len(frames) - 1: - await asyncio.sleep(self.frame_analysis_delay) - summary_prompt = "基于以下逐帧结果给出完整总结:\n\n" + "\n".join(results) - try: - final, _ = await self.video_llm.generate_response_for_image( - prompt=summary_prompt, image_base64=frames[-1][0], image_format="jpeg" - ) - return final - except Exception: # pragma: no cover - return "\n".join(results) + prompt = f"请分析这个视频的第{i + 1}帧" + if self.enable_frame_timing: + prompt += f" (时间: {timestamp:.2f}s)" + prompt += "。描述你看到的内容,包括人物、动作、场景、文字等。" - # ---- 主入口 ---- - async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]: - if not os.path.exists(video_path): - return False, "❌ 文件不存在" - frames = await self.extract_keyframes(video_path) - if not frames: - return False, "❌ 未提取到关键帧" - mode = self.analysis_mode - if mode == "auto": - mode = "batch" if len(frames) <= 20 else "sequential" - text = await (self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question)) - return True, text + if user_question: + prompt += f"\n特别关注: {user_question}" + + response, _ = await self.video_llm.generate_response_for_image( + prompt=prompt, image_base64=frame_base64, image_format="jpeg" + ) + + frame_analyses.append(f"第{i + 1}帧 ({timestamp:.2f}s): {response}") + logger.debug(f"✅ 第{i + 1}帧分析完成") + + # API调用间隔 + if i < len(frames) - 1: + await asyncio.sleep(self.frame_analysis_delay) + + except Exception as e: + logger.error(f"❌ 第{i + 1}帧分析失败: {e}") + frame_analyses.append(f"第{i + 1}帧: 分析失败 - {e}") + + # 生成汇总 + logger.info("开始生成汇总分析") + summary_prompt = f"""基于以下各帧的分析结果,请提供一个完整的视频内容总结: + +{chr(10).join(frame_analyses)} + +请综合所有帧的信息,描述视频的整体内容、故事线、主要元素和特点。""" + + if user_question: + summary_prompt += f"\n特别回答用户的问题: {user_question}" + + try: + # 使用最后一帧进行汇总分析 + if frames: + last_frame_base64, _ = frames[-1] + summary, _ = await self.video_llm.generate_response_for_image( + prompt=summary_prompt, image_base64=last_frame_base64, image_format="jpeg" + ) + logger.info("✅ 逐帧分析和汇总完成") + return summary + else: + return "❌ 没有可用于汇总的帧" + except Exception as e: + logger.error(f"❌ 汇总分析失败: {e}") + # 如果汇总失败,返回各帧分析结果 + return f"视频逐帧分析结果:\n\n{chr(10).join(frame_analyses)}" + + async def analyze_video(self, video_path: str, user_question: str = None) -> Tuple[bool, str]: + """分析视频的主要方法 + + Returns: + Tuple[bool, str]: (是否成功, 分析结果或错误信息) + """ + if self.disabled: + error_msg = "❌ 视频分析功能已禁用:没有可用的视频处理实现" + logger.warning(error_msg) + return (False, error_msg) + + try: + logger.info(f"开始分析视频: {os.path.basename(video_path)}") + + # 提取帧 + frames = await self.extract_frames(video_path) + if not frames: + error_msg = "❌ 无法从视频中提取有效帧" + return (False, error_msg) + + # 根据模式选择分析方法 + if self.analysis_mode == "auto": + # 智能选择:少于等于3帧用批量,否则用逐帧 + mode = "batch" if len(frames) <= 3 else "sequential" + logger.info(f"自动选择分析模式: {mode} (基于{len(frames)}帧)") + else: + mode = self.analysis_mode + + # 执行分析 + if mode == "batch": + result = await self.analyze_frames_batch(frames, user_question) + else: # sequential + result = await self.analyze_frames_sequential(frames, user_question) + + logger.info("✅ 视频分析完成") + return (True, result) + + except Exception as e: + error_msg = f"❌ 视频分析失败: {str(e)}" + logger.error(error_msg) + return (False, error_msg) async def analyze_video_from_bytes( - self, - video_bytes: bytes, - filename: Optional[str] = None, - prompt: Optional[str] = None, - question: Optional[str] = None, + self, video_bytes: bytes, filename: str = None, user_question: str = None, prompt: str = None ) -> Dict[str, str]: - """从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}.""" - if not video_bytes: - return {"summary": "❌ 空视频数据"} - # 兼容参数:prompt 优先,其次 question - q = prompt if prompt is not None else question - video_hash = hashlib.sha256(video_bytes).hexdigest() + """从字节数据分析视频 - # 查缓存(第一次,未加锁) - cached = await self._get_cached(video_hash) - if cached: - logger.info(f"视频缓存命中(预检查) hash={video_hash[:16]}") - return {"summary": cached} + Args: + video_bytes: 视频字节数据 + filename: 文件名(可选,仅用于日志) + user_question: 用户问题(旧参数名,保持兼容性) + prompt: 提示词(新参数名,与系统调用保持一致) - # 获取锁避免重复处理 - async with _locks_guard: - lock = _video_locks.get(video_hash) - if lock is None: - lock = asyncio.Lock() - _video_locks[video_hash] = lock - async with lock: - # 双检缓存 - cached2 = await self._get_cached(video_hash) - if cached2: - logger.info(f"视频缓存命中(锁后) hash={video_hash[:16]}") - return {"summary": cached2} + Returns: + Dict[str, str]: 包含分析结果的字典,格式为 {"summary": "分析结果"} + """ + if self.disabled: + return {"summary": "❌ 视频分析功能已禁用:没有可用的视频处理实现"} - try: - with tempfile.NamedTemporaryFile(delete=False) as fp: - fp.write(video_bytes) - temp_path = fp.name + video_hash = None + video_event = None + + try: + logger.info("开始从字节数据分析视频") + + # 兼容性处理:如果传入了prompt参数,使用prompt;否则使用user_question + question = prompt if prompt is not None else user_question + + # 检查视频数据是否有效 + if not video_bytes: + return {"summary": "❌ 视频数据为空"} + + # 计算视频hash值 + video_hash = self._calculate_video_hash(video_bytes) + logger.info(f"视频hash: {video_hash}") + + # 改进的并发控制:使用每个视频独立的锁和事件 + async with video_lock_manager: + if video_hash not in video_locks: + video_locks[video_hash] = asyncio.Lock() + video_events[video_hash] = asyncio.Event() + + video_lock = video_locks[video_hash] + video_event = video_events[video_hash] + + # 尝试获取该视频的专用锁 + if video_lock.locked(): + logger.info(f"⏳ 相同视频正在处理中,等待处理完成... (hash: {video_hash[:16]}...)") try: - ok, summary = await self.analyze_video(temp_path, q) - # 写入缓存(仅成功) - if ok: - await self._save_cache(video_hash, summary, len(video_bytes)) - return {"summary": summary} + # 等待处理完成的事件信号,最多等待60秒 + await asyncio.wait_for(video_event.wait(), timeout=60.0) + logger.info("✅ 等待结束,检查是否有处理结果") + + # 检查是否有结果了 + existing_video = self._check_video_exists(video_hash) + if existing_video: + logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})") + return {"summary": existing_video.description} + else: + logger.warning("⚠️ 等待完成但未找到结果,可能处理失败") + except asyncio.TimeoutError: + logger.warning("⚠️ 等待超时(60秒),放弃等待") + + # 获取锁开始处理 + async with video_lock: + logger.info(f"🔒 获得视频处理锁,开始处理 (hash: {video_hash[:16]}...)") + + # 再次检查数据库(可能在等待期间已经有结果了) + existing_video = self._check_video_exists(video_hash) + if existing_video: + logger.info(f"✅ 获得锁后发现已有结果,直接返回 (id: {existing_video.id})") + video_event.set() # 通知其他等待者 + return {"summary": existing_video.description} + + # 未找到已存在记录,开始新的分析 + logger.info("未找到已存在的视频记录,开始新的分析") + + # 创建临时文件进行分析 + with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file: + temp_file.write(video_bytes) + temp_path = temp_file.name + + try: + # 检查临时文件是否创建成功 + if not os.path.exists(temp_path): + video_event.set() # 通知等待者 + return {"summary": "❌ 临时文件创建失败"} + + # 使用临时文件进行分析 + success, result = await self.analyze_video(temp_path, question) + finally: + # 清理临时文件 if os.path.exists(temp_path): - try: - os.remove(temp_path) - except Exception: # pragma: no cover - pass - except Exception as e: # pragma: no cover - return {"summary": f"❌ 处理失败: {e}"} + os.unlink(temp_path) + + # 保存分析结果到数据库(仅保存成功的结果) + if success and not result.startswith("❌"): + metadata = {"filename": filename, "file_size": len(video_bytes), "analysis_timestamp": time.time()} + self._store_video_result(video_hash=video_hash, description=result, metadata=metadata) + logger.info("✅ 分析结果已保存到数据库") + else: + logger.warning("⚠️ 分析失败,不保存到数据库以便后续重试") + + # 处理完成,通知等待者并清理资源 + video_event.set() + async with video_lock_manager: + # 清理资源 + video_locks.pop(video_hash, None) + video_events.pop(video_hash, None) + + return {"summary": result} + + except Exception as e: + error_msg = f"❌ 从字节数据分析视频失败: {str(e)}" + logger.error(error_msg) + + # 不保存错误信息到数据库,允许后续重试 + logger.info("💡 错误信息不保存到数据库,允许后续重试") + + # 处理失败,通知等待者并清理资源 + try: + if video_hash and video_event: + async with video_lock_manager: + if video_hash in video_events: + video_events[video_hash].set() + video_locks.pop(video_hash, None) + video_events.pop(video_hash, None) + except Exception as cleanup_e: + logger.error(f"❌ 清理锁资源失败: {cleanup_e}") + + return {"summary": error_msg} + + def is_supported_video(self, file_path: str) -> bool: + """检查是否为支持的视频格式""" + supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"} + return Path(file_path).suffix.lower() in supported_formats + + def get_processing_capabilities(self) -> Dict[str, any]: + """获取处理能力信息""" + if not RUST_VIDEO_AVAILABLE: + return {"error": "Rust视频处理模块不可用", "available": False, "reason": "rust_video模块未安装或加载失败"} - # ---- 缓存辅助 ---- - async def _get_cached(self, video_hash: str) -> Optional[str]: try: - async with get_db_session() as session: # type: ignore - result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) # type: ignore - obj: Optional[Videos] = result.scalar_one_or_none() # type: ignore - if obj and obj.vlm_processed and obj.description: - # 更新使用次数 - try: - await session.execute( - update(Videos) - .where(Videos.id == obj.id) # type: ignore - .values(count=obj.count + 1 if obj.count is not None else 1) - ) - await session.commit() - except Exception: # pragma: no cover - await session.rollback() - return obj.description - except Exception: # pragma: no cover - pass - return None + system_info = rust_video.get_system_info() - async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None: - try: - async with get_db_session() as session: # type: ignore - stmt = insert(Videos).values( # type: ignore - video_id="", - video_hash=video_hash, - description=summary, - count=1, - timestamp=time.time(), - vlm_processed=True, - duration=None, - frame_count=None, - fps=None, - resolution=None, - file_size=file_size, - ) - try: - await session.execute(stmt) - await session.commit() - logger.debug(f"视频缓存写入 success hash={video_hash}") - except sa_exc.IntegrityError: # 可能并发已写入 - await session.rollback() - logger.debug(f"视频缓存已存在 hash={video_hash}") - except Exception: # pragma: no cover - logger.debug("视频缓存写入失败") + # 创建一个临时的extractor来获取CPU特性 + extractor = rust_video.VideoKeyframeExtractor(threads=0, verbose=False) + cpu_features = extractor.get_cpu_features() + + capabilities = { + "system": { + "threads": system_info.get("threads", 0), + "rust_version": system_info.get("version", "unknown"), + }, + "cpu_features": cpu_features, + "recommended_settings": self._get_recommended_settings(cpu_features), + "analysis_modes": ["auto", "batch", "sequential"], + "supported_formats": [".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"], + "available": True, + } + + return capabilities + + except Exception as e: + logger.error(f"获取处理能力信息失败: {e}") + return {"error": str(e), "available": False} + + def _get_recommended_settings(self, cpu_features: Dict[str, bool]) -> Dict[str, any]: + """根据CPU特性推荐最佳设置""" + settings = { + "use_simd": any(cpu_features.values()), + "block_size": 8192, + "threads": 0, # 自动检测 + } + + # 根据CPU特性调整设置 + if cpu_features.get("avx2", False): + settings["block_size"] = 16384 # AVX2支持更大的块 + settings["optimization_level"] = "avx2" + elif cpu_features.get("sse2", False): + settings["block_size"] = 8192 + settings["optimization_level"] = "sse2" + else: + settings["use_simd"] = False + settings["block_size"] = 4096 + settings["optimization_level"] = "scalar" + + return settings -# ---- 外部接口 ---- -_INSTANCE: Optional[VideoAnalyzer] = None +# 全局实例 +_video_analyzer = None def get_video_analyzer() -> VideoAnalyzer: - global _INSTANCE - if _INSTANCE is None: - _INSTANCE = VideoAnalyzer() - return _INSTANCE + """获取视频分析器实例(单例模式)""" + global _video_analyzer + if _video_analyzer is None: + _video_analyzer = VideoAnalyzer() + return _video_analyzer def is_video_analysis_available() -> bool: - return True + """检查视频分析功能是否可用 - -def get_video_analysis_status() -> Dict[str, Any]: + Returns: + bool: 如果有任何可用的视频处理实现则返回True + """ + # 现在即使Rust模块不可用,也可以使用Python降级实现 try: - info = video.get_system_info() # type: ignore[attr-defined] - except Exception as e: # pragma: no cover - return {"available": False, "error": str(e)} - inst = get_video_analyzer() - return { - "available": True, - "system": info, - "modes": ["auto", "batch", "sequential"], - "max_frames_default": inst.max_frames, - "implementation": "inkfox", + import cv2 + + return True + except ImportError: + return False + + +def get_video_analysis_status() -> Dict[str, any]: + """获取视频分析功能的详细状态信息 + + Returns: + Dict[str, any]: 包含功能状态信息的字典 + """ + # 检查OpenCV是否可用 + opencv_available = False + try: + import cv2 + + opencv_available = True + except ImportError: + pass + + status = { + "available": opencv_available or RUST_VIDEO_AVAILABLE, + "implementations": { + "rust_keyframe": { + "available": RUST_VIDEO_AVAILABLE, + "description": "Rust智能关键帧提取", + "supported_modes": ["keyframe"], + }, + "python_legacy": { + "available": opencv_available, + "description": "Python传统抽帧方法", + "supported_modes": ["fixed_number", "time_interval"], + }, + }, + "supported_modes": [], } + + # 汇总支持的模式 + if RUST_VIDEO_AVAILABLE: + status["supported_modes"].extend(["keyframe"]) + if opencv_available: + status["supported_modes"].extend(["fixed_number", "time_interval"]) + + if not status["available"]: + status.update({"error": "没有可用的视频处理实现", "solution": "请安装opencv-python或rust_video模块"}) + + return status diff --git a/src/common/data_models/__init__.py b/src/common/data_models/__init__.py index 222ff59ca..d104eec9c 100644 --- a/src/common/data_models/__init__.py +++ b/src/common/data_models/__init__.py @@ -6,6 +6,7 @@ class BaseDataModel: def deepcopy(self): return copy.deepcopy(self) + def temporarily_transform_class_to_dict(obj: Any) -> Any: # sourcery skip: assign-if-exp, reintroduce-else """ diff --git a/src/common/data_models/bot_interest_data_model.py b/src/common/data_models/bot_interest_data_model.py new file mode 100644 index 000000000..819b50a8f --- /dev/null +++ b/src/common/data_models/bot_interest_data_model.py @@ -0,0 +1,137 @@ +""" +机器人兴趣标签数据模型 +定义机器人的兴趣标签和相关的embedding数据结构 +""" + +from dataclasses import dataclass, field +from typing import List, Dict, Optional, Any +from datetime import datetime + +from . import BaseDataModel + + +@dataclass +class BotInterestTag(BaseDataModel): + """机器人兴趣标签""" + + tag_name: str + weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0) + embedding: Optional[List[float]] = None # 标签的embedding向量 + created_at: datetime = field(default_factory=datetime.now) + updated_at: datetime = field(default_factory=datetime.now) + is_active: bool = True + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式""" + return { + "tag_name": self.tag_name, + "weight": self.weight, + "embedding": self.embedding, + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + "is_active": self.is_active, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag": + """从字典创建对象""" + return cls( + tag_name=data["tag_name"], + weight=data.get("weight", 1.0), + embedding=data.get("embedding"), + created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(), + updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(), + is_active=data.get("is_active", True), + ) + + +@dataclass +class BotPersonalityInterests(BaseDataModel): + """机器人人格化兴趣配置""" + + personality_id: str + personality_description: str # 人设描述文本 + interest_tags: List[BotInterestTag] = field(default_factory=list) + embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型 + last_updated: datetime = field(default_factory=datetime.now) + version: int = 1 # 版本号,用于追踪更新 + + def get_active_tags(self) -> List[BotInterestTag]: + """获取活跃的兴趣标签""" + return [tag for tag in self.interest_tags if tag.is_active] + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式""" + return { + "personality_id": self.personality_id, + "personality_description": self.personality_description, + "interest_tags": [tag.to_dict() for tag in self.interest_tags], + "embedding_model": self.embedding_model, + "last_updated": self.last_updated.isoformat(), + "version": self.version, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests": + """从字典创建对象""" + return cls( + personality_id=data["personality_id"], + personality_description=data["personality_description"], + interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])], + embedding_model=data.get("embedding_model", "text-embedding-ada-002"), + last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(), + version=data.get("version", 1), + ) + + +@dataclass +class InterestMatchResult(BaseDataModel): + """兴趣匹配结果""" + + message_id: str + matched_tags: List[str] = field(default_factory=list) + match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score + overall_score: float = 0.0 + top_tag: Optional[str] = None + confidence: float = 0.0 # 匹配置信度 (0.0-1.0) + matched_keywords: List[str] = field(default_factory=list) + + def add_match(self, tag_name: str, score: float, keywords: List[str] = None): + """添加匹配结果""" + self.matched_tags.append(tag_name) + self.match_scores[tag_name] = score + if keywords: + self.matched_keywords.extend(keywords) + + def calculate_overall_score(self): + """计算总体匹配分数""" + if not self.match_scores: + self.overall_score = 0.0 + self.top_tag = None + return + + # 使用加权平均计算总体分数 + total_weight = len(self.match_scores) + if total_weight > 0: + self.overall_score = sum(self.match_scores.values()) / total_weight + # 设置最佳匹配标签 + self.top_tag = max(self.match_scores.items(), key=lambda x: x[1])[0] + else: + self.overall_score = 0.0 + self.top_tag = None + + # 计算置信度(基于匹配标签数量和分数分布) + if len(self.match_scores) > 0: + avg_score = self.overall_score + score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len( + self.match_scores + ) + # 分数越集中,置信度越高 + self.confidence = max(0.0, 1.0 - score_variance) + else: + self.confidence = 0.0 + + def get_top_matches(self, top_n: int = 3) -> List[tuple]: + """获取前N个最佳匹配""" + sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True) + return sorted_matches[:top_n] diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index bf4a5f527..4578d1481 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -79,6 +79,7 @@ class DatabaseMessages(BaseDataModel): is_command: bool = False, is_notify: bool = False, selected_expressions: Optional[str] = None, + is_read: bool = False, user_id: str = "", user_nickname: str = "", user_cardname: Optional[str] = None, @@ -94,6 +95,9 @@ class DatabaseMessages(BaseDataModel): chat_info_platform: str = "", chat_info_create_time: float = 0.0, chat_info_last_active_time: float = 0.0, + # 新增字段 + actions: Optional[list] = None, + should_reply: bool = False, **kwargs: Any, ): self.message_id = message_id @@ -102,6 +106,10 @@ class DatabaseMessages(BaseDataModel): self.reply_to = reply_to self.interest_value = interest_value + # 新增字段 + self.actions = actions + self.should_reply = should_reply + self.key_words = key_words self.key_words_lite = key_words_lite self.is_mentioned = is_mentioned @@ -122,6 +130,7 @@ class DatabaseMessages(BaseDataModel): self.is_notify = is_notify self.selected_expressions = selected_expressions + self.is_read = is_read self.group_info: Optional[DatabaseGroupInfo] = None self.user_info = DatabaseUserInfo( @@ -188,6 +197,10 @@ class DatabaseMessages(BaseDataModel): "is_command": self.is_command, "is_notify": self.is_notify, "selected_expressions": self.selected_expressions, + "is_read": self.is_read, + # 新增字段 + "actions": self.actions, + "should_reply": self.should_reply, "user_id": self.user_info.user_id, "user_nickname": self.user_info.user_nickname, "user_cardname": self.user_info.user_cardname, @@ -205,6 +218,61 @@ class DatabaseMessages(BaseDataModel): "chat_info_user_cardname": self.chat_info.user_info.user_cardname, } + def update_message_info(self, interest_value: float = None, actions: list = None, should_reply: bool = None): + """ + 更新消息信息 + + Args: + interest_value: 兴趣度值 + actions: 执行的动作列表 + should_reply: 是否应该回复 + """ + if interest_value is not None: + self.interest_value = interest_value + if actions is not None: + self.actions = actions + if should_reply is not None: + self.should_reply = should_reply + + def add_action(self, action: str): + """ + 添加执行的动作到消息中 + + Args: + action: 要添加的动作名称 + """ + if self.actions is None: + self.actions = [] + if action not in self.actions: # 避免重复添加 + self.actions.append(action) + + def get_actions(self) -> list: + """ + 获取执行的动作列表 + + Returns: + 动作列表,如果没有动作则返回空列表 + """ + return self.actions or [] + + def get_message_summary(self) -> Dict[str, Any]: + """ + 获取消息摘要信息 + + Returns: + 包含关键字段的消息摘要 + """ + return { + "message_id": self.message_id, + "time": self.time, + "interest_value": self.interest_value, + "actions": self.actions, + "should_reply": self.should_reply, + "user_nickname": self.user_info.user_nickname, + "display_message": self.display_message, + } + + @dataclass(init=False) class DatabaseActionRecords(BaseDataModel): def __init__( @@ -232,4 +300,4 @@ class DatabaseActionRecords(BaseDataModel): self.action_prompt_display = action_prompt_display self.chat_id = chat_id self.chat_info_stream_id = chat_info_stream_id - self.chat_info_platform = chat_info_platform \ No newline at end of file + self.chat_info_platform = chat_info_platform diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index 7de787060..ba45ab3c4 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -1,10 +1,12 @@ from dataclasses import dataclass, field from typing import Optional, Dict, List, TYPE_CHECKING +from src.plugin_system.base.component_types import ChatType from . import BaseDataModel if TYPE_CHECKING: - pass + from .database_data_model import DatabaseMessages + from src.plugin_system.base.component_types import ActionInfo, ChatMode @dataclass @@ -21,23 +23,37 @@ class ActionPlannerInfo(BaseDataModel): action_type: str = field(default_factory=str) reasoning: Optional[str] = None action_data: Optional[Dict] = None - action_message: Optional[Dict] = None + action_message: Optional["DatabaseMessages"] = None available_actions: Optional[Dict[str, "ActionInfo"]] = None +@dataclass +class InterestScore(BaseDataModel): + """兴趣度评分结果""" + + message_id: str + total_score: float + interest_match_score: float + relationship_score: float + mentioned_score: float + details: Dict[str, str] + + @dataclass class Plan(BaseDataModel): """ 统一规划数据模型 """ + chat_id: str mode: "ChatMode" - + + chat_type: "ChatType" # Generator 填充 available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict) chat_history: List["DatabaseMessages"] = field(default_factory=list) target_info: Optional[TargetPersonInfo] = None - + # Filter 填充 llm_prompt: Optional[str] = None decided_actions: Optional[List[ActionPlannerInfo]] = None diff --git a/src/common/data_models/llm_data_model.py b/src/common/data_models/llm_data_model.py index cd706bc55..a59b65391 100644 --- a/src/common/data_models/llm_data_model.py +++ b/src/common/data_models/llm_data_model.py @@ -6,6 +6,7 @@ from . import BaseDataModel if TYPE_CHECKING: pass + @dataclass class LLMGenerationDataModel(BaseDataModel): content: Optional[str] = None @@ -14,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel): tool_calls: Optional[List["ToolCall"]] = None prompt: Optional[str] = None selected_expressions: Optional[List[int]] = None - reply_set: Optional[List[Tuple[str, Any]]] = None \ No newline at end of file + reply_set: Optional[List[Tuple[str, Any]]] = None diff --git a/src/common/data_models/message_data_model.py b/src/common/data_models/message_data_model.py deleted file mode 100644 index bf08a0d6a..000000000 --- a/src/common/data_models/message_data_model.py +++ /dev/null @@ -1,36 +0,0 @@ -from dataclasses import dataclass, field -from typing import Optional, TYPE_CHECKING - -from . import BaseDataModel - -if TYPE_CHECKING: - pass - - -@dataclass -class MessageAndActionModel(BaseDataModel): - chat_id: str = field(default_factory=str) - time: float = field(default_factory=float) - user_id: str = field(default_factory=str) - user_platform: str = field(default_factory=str) - user_nickname: str = field(default_factory=str) - user_cardname: Optional[str] = None - processed_plain_text: Optional[str] = None - display_message: Optional[str] = None - chat_info_platform: str = field(default_factory=str) - is_action_record: bool = field(default=False) - action_name: Optional[str] = None - - @classmethod - def from_DatabaseMessages(cls, message: "DatabaseMessages"): - return cls( - chat_id=message.chat_id, - time=message.time, - user_id=message.user_info.user_id, - user_platform=message.user_info.platform, - user_nickname=message.user_info.user_nickname, - user_cardname=message.user_info.user_cardname, - processed_plain_text=message.processed_plain_text, - display_message=message.display_message, - chat_info_platform=message.chat_info.platform, - ) diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py new file mode 100644 index 000000000..f35c53573 --- /dev/null +++ b/src/common/data_models/message_manager_data_model.py @@ -0,0 +1,373 @@ +""" +消息管理模块数据模型 +定义消息管理器使用的数据结构 +""" + +import asyncio +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional, TYPE_CHECKING + +from . import BaseDataModel +from src.plugin_system.base.component_types import ChatMode, ChatType +from src.common.logger import get_logger + +if TYPE_CHECKING: + from .database_data_model import DatabaseMessages + +logger = get_logger("stream_context") + + +class MessageStatus(Enum): + """消息状态枚举""" + + UNREAD = "unread" # 未读消息 + READ = "read" # 已读消息 + PROCESSING = "processing" # 处理中 + + +@dataclass +class StreamContext(BaseDataModel): + """聊天流上下文信息""" + + stream_id: str + chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊 + chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式 + unread_messages: List["DatabaseMessages"] = field(default_factory=list) + history_messages: List["DatabaseMessages"] = field(default_factory=list) + last_check_time: float = field(default_factory=time.time) + is_active: bool = True + processing_task: Optional[asyncio.Task] = None + interruption_count: int = 0 # 打断计数器 + last_interruption_time: float = 0.0 # 上次打断时间 + afc_threshold_adjustment: float = 0.0 # afc阈值调整量 + + # 独立分发周期字段 + next_check_time: float = field(default_factory=time.time) # 下次检查时间 + distribution_interval: float = 5.0 # 当前分发周期(秒) + + # 新增字段以替代ChatMessageContext功能 + current_message: Optional["DatabaseMessages"] = None + priority_mode: Optional[str] = None + priority_info: Optional[dict] = None + + def add_message(self, message: "DatabaseMessages"): + """添加消息到上下文""" + message.is_read = False + self.unread_messages.append(message) + + # 自动检测和更新chat type + self._detect_chat_type(message) + + def update_message_info( + self, message_id: str, interest_value: float = None, actions: list = None, should_reply: bool = None + ): + """ + 更新消息信息 + + Args: + message_id: 消息ID + interest_value: 兴趣度值 + actions: 执行的动作列表 + should_reply: 是否应该回复 + """ + # 在未读消息中查找并更新 + for message in self.unread_messages: + if message.message_id == message_id: + message.update_message_info(interest_value, actions, should_reply) + break + + # 在历史消息中查找并更新 + for message in self.history_messages: + if message.message_id == message_id: + message.update_message_info(interest_value, actions, should_reply) + break + + def add_action_to_message(self, message_id: str, action: str): + """ + 向指定消息添加执行的动作 + + Args: + message_id: 消息ID + action: 要添加的动作名称 + """ + # 在未读消息中查找并更新 + for message in self.unread_messages: + if message.message_id == message_id: + message.add_action(action) + break + + # 在历史消息中查找并更新 + for message in self.history_messages: + if message.message_id == message_id: + message.add_action(action) + break + + def _detect_chat_type(self, message: "DatabaseMessages"): + """根据消息内容自动检测聊天类型""" + # 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型 + if len(self.unread_messages) == 1: # 只有这条消息 + # 如果消息包含群组信息,则为群聊 + if hasattr(message, "chat_info_group_id") and message.chat_info_group_id: + self.chat_type = ChatType.GROUP + elif hasattr(message, "chat_info_group_name") and message.chat_info_group_name: + self.chat_type = ChatType.GROUP + else: + self.chat_type = ChatType.PRIVATE + + def update_chat_type(self, chat_type: ChatType): + """手动更新聊天类型""" + self.chat_type = chat_type + + def set_chat_mode(self, chat_mode: ChatMode): + """设置聊天模式""" + self.chat_mode = chat_mode + + def is_group_chat(self) -> bool: + """检查是否为群聊""" + return self.chat_type == ChatType.GROUP + + def is_private_chat(self) -> bool: + """检查是否为私聊""" + return self.chat_type == ChatType.PRIVATE + + def get_chat_type_display(self) -> str: + """获取聊天类型的显示名称""" + if self.chat_type == ChatType.GROUP: + return "群聊" + elif self.chat_type == ChatType.PRIVATE: + return "私聊" + else: + return "未知类型" + + def mark_message_as_read(self, message_id: str): + """标记消息为已读""" + for msg in self.unread_messages: + if msg.message_id == message_id: + msg.is_read = True + self.history_messages.append(msg) + self.unread_messages.remove(msg) + break + + def get_unread_messages(self) -> List["DatabaseMessages"]: + """获取未读消息""" + return [msg for msg in self.unread_messages if not msg.is_read] + + def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]: + """获取历史消息""" + # 优先返回最近的历史消息和所有未读消息 + recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages + return recent_history + + def calculate_interruption_probability(self, max_limit: int, probability_factor: float) -> float: + """计算打断概率""" + if max_limit <= 0: + return 0.0 + + # 计算打断比例 + interruption_ratio = self.interruption_count / max_limit + + # 如果已达到或超过最大次数,完全禁止打断 + if self.interruption_count >= max_limit: + return 0.0 + + # 如果超过概率因子,概率下降 + if interruption_ratio > probability_factor: + # 使用指数衰减,超过限制越多,概率越低 + excess_ratio = interruption_ratio - probability_factor + probability = 0.8 * (0.5**excess_ratio) # 基础概率0.8,指数衰减 + else: + # 在限制内,保持较高概率 + probability = 0.8 + + return max(0.0, min(1.0, probability)) + + def increment_interruption_count(self): + """增加打断计数""" + self.interruption_count += 1 + self.last_interruption_time = time.time() + + # 同步打断计数到ChatStream + self._sync_interruption_count_to_stream() + + def reset_interruption_count(self): + """重置打断计数和afc阈值调整""" + self.interruption_count = 0 + self.last_interruption_time = 0.0 + self.afc_threshold_adjustment = 0.0 + + # 同步打断计数到ChatStream + self._sync_interruption_count_to_stream() + + def apply_interruption_afc_reduction(self, reduction_value: float): + """应用打断导致的afc阈值降低""" + self.afc_threshold_adjustment += reduction_value + logger.debug(f"应用afc阈值降低: {reduction_value}, 总调整量: {self.afc_threshold_adjustment}") + + def get_afc_threshold_adjustment(self) -> float: + """获取当前的afc阈值调整量""" + return self.afc_threshold_adjustment + + def _sync_interruption_count_to_stream(self): + """同步打断计数到ChatStream""" + try: + from src.chat.message_receive.chat_stream import get_chat_manager + + chat_manager = get_chat_manager() + if chat_manager: + chat_stream = chat_manager.get_stream(self.stream_id) + if chat_stream and hasattr(chat_stream, "interruption_count"): + # 在这里我们只是标记需要保存,实际的保存会在下次save时进行 + chat_stream.saved = False + logger.debug( + f"已同步StreamContext {self.stream_id} 的打断计数 {self.interruption_count} 到ChatStream" + ) + except Exception as e: + logger.warning(f"同步打断计数到ChatStream失败: {e}") + + def set_current_message(self, message: "DatabaseMessages"): + """设置当前消息""" + self.current_message = message + + def get_template_name(self) -> Optional[str]: + """获取模板名称""" + if ( + self.current_message + and hasattr(self.current_message, "additional_config") + and self.current_message.additional_config + ): + try: + import json + + config = json.loads(self.current_message.additional_config) + if config.get("template_info") and not config.get("template_default", True): + return config.get("template_name") + except (json.JSONDecodeError, AttributeError): + pass + return None + + def get_last_message(self) -> Optional["DatabaseMessages"]: + """获取最后一条消息""" + if self.current_message: + return self.current_message + if self.unread_messages: + return self.unread_messages[-1] + if self.history_messages: + return self.history_messages[-1] + return None + + def check_types(self, types: list) -> bool: + """ + 检查当前消息是否支持指定的类型 + + Args: + types: 需要检查的消息类型列表,如 ["text", "image", "emoji"] + + Returns: + bool: 如果消息支持所有指定的类型则返回True,否则返回False + """ + if not self.current_message: + return False + + if not types: + # 如果没有指定类型要求,默认为支持 + return True + + # 优先从additional_config中获取format_info + if hasattr(self.current_message, "additional_config") and self.current_message.additional_config: + try: + import orjson + + config = orjson.loads(self.current_message.additional_config) + + # 检查format_info结构 + if "format_info" in config: + format_info = config["format_info"] + + # 方法1: 直接检查accept_format字段 + if "accept_format" in format_info: + accept_format = format_info["accept_format"] + # 确保accept_format是列表类型 + if isinstance(accept_format, str): + accept_format = [accept_format] + elif isinstance(accept_format, list): + pass + else: + # 如果accept_format不是字符串或列表,尝试转换为列表 + accept_format = list(accept_format) if hasattr(accept_format, "__iter__") else [] + + # 检查所有请求的类型是否都被支持 + for requested_type in types: + if requested_type not in accept_format: + logger.debug(f"消息不支持类型 '{requested_type}',支持的类型: {accept_format}") + return False + return True + + # 方法2: 检查content_format字段(向后兼容) + elif "content_format" in format_info: + content_format = format_info["content_format"] + # 确保content_format是列表类型 + if isinstance(content_format, str): + content_format = [content_format] + elif isinstance(content_format, list): + pass + else: + content_format = list(content_format) if hasattr(content_format, "__iter__") else [] + + # 检查所有请求的类型是否都被支持 + for requested_type in types: + if requested_type not in content_format: + logger.debug(f"消息不支持类型 '{requested_type}',支持的内容格式: {content_format}") + return False + return True + + except (orjson.JSONDecodeError, AttributeError, TypeError) as e: + logger.debug(f"解析消息格式信息失败: {e}") + + # 备用方案:如果无法从additional_config获取格式信息,使用默认支持的类型 + # 大多数消息至少支持text类型 + default_supported_types = ["text", "emoji"] + for requested_type in types: + if requested_type not in default_supported_types: + logger.debug(f"使用默认类型检查,消息可能不支持类型 '{requested_type}'") + # 对于非基础类型,返回False以避免错误 + if requested_type not in ["text", "emoji", "reply"]: + return False + return True + + def get_priority_mode(self) -> Optional[str]: + """获取优先级模式""" + return self.priority_mode + + def get_priority_info(self) -> Optional[dict]: + """获取优先级信息""" + return self.priority_info + + +@dataclass +class MessageManagerStats(BaseDataModel): + """消息管理器统计信息""" + + total_streams: int = 0 + active_streams: int = 0 + total_unread_messages: int = 0 + total_processed_messages: int = 0 + start_time: float = field(default_factory=time.time) + + @property + def uptime(self) -> float: + """运行时间""" + return time.time() - self.start_time + + +@dataclass +class StreamStats(BaseDataModel): + """聊天流统计信息""" + + stream_id: str + is_active: bool + unread_count: int + history_count: int + last_check_time: float + has_active_task: bool diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 13ef39c1a..c832789b5 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -30,6 +30,7 @@ from src.common.database.sqlalchemy_models import ( Schedule, MaiZoneScheduleStatus, CacheEntries, + UserRelationships, ) from src.common.logger import get_logger @@ -54,6 +55,7 @@ MODEL_MAPPING = { "Schedule": Schedule, "MaiZoneScheduleStatus": MaiZoneScheduleStatus, "CacheEntries": CacheEntries, + "UserRelationships": UserRelationships, } diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 996dd5a45..2469fa642 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -55,7 +55,17 @@ class ChatStreams(Base): user_cardname = Column(Text, nullable=True) energy_value = Column(Float, nullable=True, default=5.0) sleep_pressure = Column(Float, nullable=True, default=0.0) - focus_energy = Column(Float, nullable=True, default=1.0) + focus_energy = Column(Float, nullable=True, default=0.5) + # 动态兴趣度系统字段 + base_interest_energy = Column(Float, nullable=True, default=0.5) + message_interest_total = Column(Float, nullable=True, default=0.0) + message_count = Column(Integer, nullable=True, default=0) + action_count = Column(Integer, nullable=True, default=0) + reply_count = Column(Integer, nullable=True, default=0) + last_interaction_time = Column(Float, nullable=True, default=None) + consecutive_no_reply = Column(Integer, nullable=True, default=0) + # 消息打断系统字段 + interruption_count = Column(Integer, nullable=True, default=0) __table_args__ = ( Index("idx_chatstreams_stream_id", "stream_id"), @@ -165,11 +175,16 @@ class Messages(Base): is_command = Column(Boolean, nullable=False, default=False) is_notify = Column(Boolean, nullable=False, default=False) + # 兴趣度系统字段 + actions = Column(Text, nullable=True) # JSON格式存储动作列表 + should_reply = Column(Boolean, nullable=True, default=False) + __table_args__ = ( Index("idx_messages_message_id", "message_id"), Index("idx_messages_chat_id", "chat_id"), Index("idx_messages_time", "time"), Index("idx_messages_user_id", "user_id"), + Index("idx_messages_should_reply", "should_reply"), ) @@ -300,6 +315,26 @@ class PersonInfo(Base): ) +class BotPersonalityInterests(Base): + """机器人人格兴趣标签模型""" + + __tablename__ = "bot_personality_interests" + + id = Column(Integer, primary_key=True, autoincrement=True) + personality_id = Column(get_string_field(100), nullable=False, index=True) + personality_description = Column(Text, nullable=False) + interest_tags = Column(Text, nullable=False) # JSON格式存储的兴趣标签列表 + embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002") + version = Column(Integer, nullable=False, default=1) + last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True) + + __table_args__ = ( + Index("idx_botpersonality_personality_id", "personality_id"), + Index("idx_botpersonality_version", "version"), + Index("idx_botpersonality_last_updated", "last_updated"), + ) + + class Memory(Base): """记忆模型""" @@ -722,3 +757,23 @@ class UserPermissions(Base): Index("idx_user_permission", "platform", "user_id", "permission_node"), Index("idx_permission_granted", "permission_node", "granted"), ) + + +class UserRelationships(Base): + """用户关系模型 - 存储用户与bot的关系数据""" + + __tablename__ = "user_relationships" + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(get_string_field(100), nullable=False, unique=True, index=True) # 用户ID + user_name = Column(get_string_field(100), nullable=True) # 用户名 + relationship_text = Column(Text, nullable=True) # 关系印象描述 + relationship_score = Column(Float, nullable=False, default=0.3) # 关系分数(0-1) + last_updated = Column(Float, nullable=False, default=time.time) # 最后更新时间 + created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间 + + __table_args__ = ( + Index("idx_user_relationship_id", "user_id"), + Index("idx_relationship_score", "relationship_score"), + Index("idx_relationship_updated", "last_updated"), + ) diff --git a/src/common/logger.py b/src/common/logger.py index 761caf9ef..fa5e27d04 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -350,6 +350,10 @@ MODULE_COLORS = { "memory": "\033[38;5;117m", # 天蓝色 "hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读 "action_manager": "\033[38;5;208m", # 橙色,不与replyer重复 + "message_manager": "\033[38;5;27m", # 深蓝色,消息管理器 + "chatter_manager": "\033[38;5;129m", # 紫色,聊天管理器 + "chatter_interest_scoring": "\033[38;5;214m", # 橙黄色,兴趣评分 + "plan_executor": "\033[38;5;172m", # 橙褐色,计划执行器 # 关系系统 "relation": "\033[38;5;139m", # 柔和的紫色,不刺眼 # 聊天相关模块 @@ -551,6 +555,10 @@ MODULE_ALIASES = { "llm_models": "模型", "person_info": "人物", "chat_stream": "聊天流", + "message_manager": "消息管理", + "chatter_manager": "聊天管理", + "chatter_interest_scoring": "兴趣评分", + "plan_executor": "计划执行", "planner": "规划器", "replyer": "言语", "config": "配置", diff --git a/src/common/message/api.py b/src/common/message/api.py index d24574d6e..37b7a7ddc 100644 --- a/src/common/message/api.py +++ b/src/common/message/api.py @@ -23,15 +23,15 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method maim_message_config = global_config.maim_message # 设置基本参数 - + host = os.getenv("HOST", "127.0.0.1") port_str = os.getenv("PORT", "8000") - + try: port = int(port_str) except ValueError: port = 8000 - + kwargs = { "host": host, "port": port, diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 96714db1f..992ad3320 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -22,10 +22,15 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]: """ 将 SQLAlchemy 模型实例转换为字典。 """ - return {col.name: getattr(instance, col.name) for col in instance.__table__.columns} + try: + return {col.name: getattr(instance, col.name) for col in instance.__table__.columns} + except Exception as e: + # 如果对象已经脱离会话,尝试从instance.__dict__中获取数据 + logger.warning(f"从数据库对象获取属性失败,尝试使用__dict__: {e}") + return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns} -async def find_messages( +def find_messages( message_filter: dict[str, Any], sort: Optional[List[tuple[str, int]]] = None, limit: int = 0, @@ -46,7 +51,7 @@ async def find_messages( 消息字典列表,如果出错则返回空列表。 """ try: - async with get_db_session() as session: + with get_db_session() as session: query = select(Messages) # 应用过滤器 @@ -96,7 +101,7 @@ async def find_messages( # 获取时间最早的 limit 条记录,已经是正序 query = query.order_by(Messages.time.asc()).limit(limit) try: - results = (await session.execute(query)).scalars().all() + results = session.execute(query).scalars().all() except Exception as e: logger.error(f"执行earliest查询失败: {e}") results = [] @@ -104,7 +109,7 @@ async def find_messages( # 获取时间最晚的 limit 条记录 query = query.order_by(Messages.time.desc()).limit(limit) try: - latest_results = (await session.execute(query)).scalars().all() + latest_results = session.execute(query).scalars().all() # 将结果按时间正序排列 results = sorted(latest_results, key=lambda msg: msg.time) except Exception as e: @@ -128,11 +133,12 @@ async def find_messages( if sort_terms: query = query.order_by(*sort_terms) try: - results = (await session.execute(query)).scalars().all() + results = session.execute(query).scalars().all() except Exception as e: logger.error(f"执行无限制查询失败: {e}") results = [] + # 在会话内将结果转换为字典,避免会话分离错误 return [_model_to_dict(msg) for msg in results] except Exception as e: log_message = ( @@ -143,7 +149,7 @@ async def find_messages( return [] -async def count_messages(message_filter: dict[str, Any]) -> int: +def count_messages(message_filter: dict[str, Any]) -> int: """ 根据提供的过滤器计算消息数量。 @@ -154,7 +160,7 @@ async def count_messages(message_filter: dict[str, Any]) -> int: 符合条件的消息数量,如果出错则返回 0。 """ try: - async with get_db_session() as session: + with get_db_session() as session: query = select(func.count(Messages.id)) # 应用过滤器 @@ -192,7 +198,7 @@ async def count_messages(message_filter: dict[str, Any]) -> int: if conditions: query = query.where(*conditions) - count = (await session.execute(query)).scalar() + count = session.execute(query).scalar() return count or 0 except Exception as e: log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" @@ -201,5 +207,5 @@ async def count_messages(message_filter: dict[str, Any]) -> int: # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 -# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 await session.commit()。 +# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 session.commit()。 # 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。 diff --git a/src/common/remote.py b/src/common/remote.py index 2aa750449..95202f810 100644 --- a/src/common/remote.py +++ b/src/common/remote.py @@ -31,7 +31,9 @@ class TelemetryHeartBeatTask(AsyncTask): self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore """客户端UUID""" - self.private_key_pem: str | None = local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None # type: ignore + self.private_key_pem: str | None = ( + local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None + ) # type: ignore """客户端私钥""" self.info_dict = self._get_sys_info() @@ -61,78 +63,65 @@ class TelemetryHeartBeatTask(AsyncTask): def _generate_signature(self, request_body: dict) -> tuple[str, str]: """ 生成RSA签名 - + Returns: tuple[str, str]: (timestamp, signature_b64) """ if not self.private_key_pem: raise ValueError("私钥未初始化") - + # 生成时间戳 timestamp = datetime.now(timezone.utc).isoformat() - + # 创建签名数据字符串 sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}" - + # 加载私钥 - private_key = serialization.load_pem_private_key( - self.private_key_pem.encode('utf-8'), - password=None - ) - + private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None) + # 确保是RSA私钥 if not isinstance(private_key, rsa.RSAPrivateKey): raise ValueError("私钥必须是RSA格式") - + # 生成签名 signature = private_key.sign( - sign_data.encode('utf-8'), - padding.PSS( - mgf=padding.MGF1(hashes.SHA256()), - salt_length=padding.PSS.MAX_LENGTH - ), - hashes.SHA256() + sign_data.encode("utf-8"), + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH), + hashes.SHA256(), ) - + # Base64编码 - signature_b64 = base64.b64encode(signature).decode('utf-8') - + signature_b64 = base64.b64encode(signature).decode("utf-8") + return timestamp, signature_b64 def _decrypt_challenge(self, challenge_b64: str) -> str: """ 解密挑战数据 - + Args: challenge_b64: Base64编码的挑战数据 - + Returns: str: 解密后的UUID字符串 """ if not self.private_key_pem: raise ValueError("私钥未初始化") - + # 加载私钥 - private_key = serialization.load_pem_private_key( - self.private_key_pem.encode('utf-8'), - password=None - ) - + private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None) + # 确保是RSA私钥 if not isinstance(private_key, rsa.RSAPrivateKey): raise ValueError("私钥必须是RSA格式") - + # 解密挑战数据 decrypted_bytes = private_key.decrypt( base64.b64decode(challenge_b64), - padding.OAEP( - mgf=padding.MGF1(hashes.SHA256()), - algorithm=hashes.SHA256(), - label=None - ) + padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None), ) - - return decrypted_bytes.decode('utf-8') + + return decrypted_bytes.decode("utf-8") async def _req_uuid(self) -> bool: """ @@ -155,28 +144,26 @@ class TelemetryHeartBeatTask(AsyncTask): if response.status != 200: response_text = await response.text() - logger.error( - f"注册步骤1失败,状态码: {response.status}, 响应内容: {response_text}" - ) + logger.error(f"注册步骤1失败,状态码: {response.status}, 响应内容: {response_text}") raise aiohttp.ClientResponseError( request_info=response.request_info, history=response.history, status=response.status, - message=f"Step1 failed: {response_text}" + message=f"Step1 failed: {response_text}", ) step1_data = await response.json() temp_uuid = step1_data.get("temp_uuid") private_key = step1_data.get("private_key") challenge = step1_data.get("challenge") - + if not all([temp_uuid, private_key, challenge]): logger.error("Step1响应缺少必要字段:temp_uuid, private_key 或 challenge") raise ValueError("Step1响应数据不完整") # 临时保存私钥用于解密 self.private_key_pem = private_key - + # 解密挑战数据 logger.debug("解密挑战数据...") try: @@ -184,21 +171,18 @@ class TelemetryHeartBeatTask(AsyncTask): except Exception as e: logger.error(f"解密挑战数据失败: {e}") raise - + # 验证解密结果 if decrypted_uuid != temp_uuid: logger.error(f"解密结果验证失败: 期望 {temp_uuid}, 实际 {decrypted_uuid}") raise ValueError("解密结果与临时UUID不匹配") - + logger.debug("挑战数据解密成功,开始注册步骤2") # Step 2: 发送解密结果完成注册 async with session.post( f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2", - json={ - "temp_uuid": temp_uuid, - "decrypted_uuid": decrypted_uuid - }, + json={"temp_uuid": temp_uuid, "decrypted_uuid": decrypted_uuid}, timeout=aiohttp.ClientTimeout(total=5), ) as response: logger.debug(f"Step2 Response status: {response.status}") @@ -206,7 +190,7 @@ class TelemetryHeartBeatTask(AsyncTask): if response.status == 200: step2_data = await response.json() mofox_uuid = step2_data.get("mofox_uuid") - + if mofox_uuid: # 将正式UUID和私钥存储到本地 local_storage["mofox_uuid"] = mofox_uuid @@ -225,23 +209,19 @@ class TelemetryHeartBeatTask(AsyncTask): raise ValueError(f"Step2失败: {response_text}") else: response_text = await response.text() - logger.error( - f"注册步骤2失败,状态码: {response.status}, 响应内容: {response_text}" - ) + logger.error(f"注册步骤2失败,状态码: {response.status}, 响应内容: {response_text}") raise aiohttp.ClientResponseError( request_info=response.request_info, history=response.history, status=response.status, - message=f"Step2 failed: {response_text}" + message=f"Step2 failed: {response_text}", ) except Exception as e: import traceback error_msg = str(e) or "未知错误" - logger.warning( - f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}" - ) + logger.warning(f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}") logger.debug(f"完整错误信息: {traceback.format_exc()}") # 请求失败,重试次数+1 @@ -264,13 +244,13 @@ class TelemetryHeartBeatTask(AsyncTask): try: # 生成签名 timestamp, signature = self._generate_signature(self.info_dict) - + headers = { "X-mofox-UUID": self.client_uuid, "X-mofox-Signature": signature, "X-mofox-Timestamp": timestamp, "User-Agent": f"MofoxClient/{self.client_uuid[:8]}", - "Content-Type": "application/json" + "Content-Type": "application/json", } logger.debug(f"正在发送心跳到服务器: {self.server_url}") @@ -347,4 +327,4 @@ class TelemetryHeartBeatTask(AsyncTask): logger.warning("客户端注册失败,跳过此次心跳") return - await self._send_heartbeat() \ No newline at end of file + await self._send_heartbeat() diff --git a/src/common/server.py b/src/common/server.py index a06cf1151..64299274b 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -99,14 +99,13 @@ def get_global_server() -> Server: """获取全局服务器实例""" global global_server if global_server is None: - host = os.getenv("HOST", "127.0.0.1") port_str = os.getenv("PORT", "8000") - + try: port = int(port_str) except ValueError: port = 8000 - + global_server = Server(host=host, port=port) return global_server diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index 0b1984a3c..f7e9fe514 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -137,7 +137,7 @@ class ModelTaskConfig(ValidatedConfigBase): monthly_plan_generator: TaskConfig = Field(..., description="月层计划生成模型配置") emoji_vlm: TaskConfig = Field(..., description="表情包识别模型配置") anti_injection: TaskConfig = Field(..., description="反注入检测专用模型配置") - + relationship_tracker: TaskConfig = Field(..., description="关系追踪模型配置") # 处理配置文件中命名不一致的问题 utils_video: TaskConfig = Field(..., description="视频分析模型配置(兼容配置文件中的命名)") diff --git a/src/config/config.py b/src/config/config.py index ac6204689..c338ed543 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -43,7 +43,8 @@ from src.config.official_configs import ( CrossContextConfig, PermissionConfig, CommandConfig, - PlanningSystemConfig + PlanningSystemConfig, + AffinityFlowConfig, ) from .api_ada_configs import ( @@ -66,7 +67,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template") # 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/ -MMC_VERSION = "0.10.0-alpha-2" +MMC_VERSION = "0.11.0-alpha-1" def get_key_comment(toml_table, key): @@ -417,6 +418,7 @@ class Config(ValidatedConfigBase): cross_context: CrossContextConfig = Field( default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置" ) + affinity_flow: AffinityFlowConfig = Field(default_factory=lambda: AffinityFlowConfig(), description="亲和流配置") class APIAdapterConfig(ValidatedConfigBase): diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 37e055bb3..7afedfae7 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -51,8 +51,12 @@ class PersonalityConfig(ValidatedConfigBase): personality_core: str = Field(..., description="核心人格") personality_side: str = Field(..., description="人格侧写") identity: str = Field(default="", description="身份特征") - background_story: str = Field(default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述") - safety_guidelines: List[str] = Field(default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则") + background_story: str = Field( + default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述" + ) + safety_guidelines: List[str] = Field( + default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则" + ) reply_style: str = Field(default="", description="表达风格") prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式") compress_personality: bool = Field(default=True, description="是否压缩人格") @@ -109,7 +113,8 @@ class ChatConfig(ValidatedConfigBase): talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整") focus_value: float = Field(default=1.0, description="专注值") focus_mode_quiet_groups: List[str] = Field( - default_factory=list, description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]' + default_factory=list, + description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]', ) force_reply_private: bool = Field(default=False, description="强制回复私聊") group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式") @@ -129,6 +134,31 @@ class ChatConfig(ValidatedConfigBase): ) delta_sigma: int = Field(default=120, description="采用正态分布随机时间间隔") + # 消息打断系统配置 + interruption_enabled: bool = Field(default=True, description="是否启用消息打断系统") + interruption_max_limit: int = Field(default=3, ge=0, description="每个聊天流的最大打断次数") + interruption_probability_factor: float = Field( + default=0.8, ge=0.0, le=1.0, description="打断概率因子,当前打断次数/最大打断次数超过此值时触发概率下降" + ) + interruption_afc_reduction: float = Field( + default=0.05, ge=0.0, le=1.0, description="每次连续打断降低的afc阈值数值" + ) + + # 动态消息分发系统配置 + dynamic_distribution_enabled: bool = Field(default=True, description="是否启用动态消息分发周期调整") + dynamic_distribution_base_interval: float = Field( + default=5.0, ge=1.0, le=60.0, description="基础分发间隔(秒)" + ) + dynamic_distribution_min_interval: float = Field( + default=1.0, ge=0.5, le=10.0, description="最小分发间隔(秒)" + ) + dynamic_distribution_max_interval: float = Field( + default=30.0, ge=5.0, le=300.0, description="最大分发间隔(秒)" + ) + dynamic_distribution_jitter_factor: float = Field( + default=0.2, ge=0.0, le=0.5, description="分发间隔随机扰动因子" + ) + def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float: """ 根据当前时间和聊天流获取对应的 talk_frequency @@ -376,6 +406,7 @@ class ExpressionConfig(ValidatedConfigBase): # 如果都没有匹配,返回默认值 return True, True, 1.0 + class ToolConfig(ValidatedConfigBase): """工具配置类""" @@ -510,7 +541,6 @@ class ExperimentalConfig(ValidatedConfigBase): pfc_chatting: bool = Field(default=False, description="启用PFC聊天") - class MaimMessageConfig(ValidatedConfigBase): """maim_message配置类""" @@ -635,8 +665,12 @@ class SleepSystemConfig(ValidatedConfigBase): sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉") fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间") fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间") - sleep_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机") - wake_up_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机") + sleep_time_offset_minutes: int = Field( + default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机" + ) + wake_up_time_offset_minutes: int = Field( + default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机" + ) wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒") private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度") group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度") @@ -651,10 +685,10 @@ class SleepSystemConfig(ValidatedConfigBase): # --- 失眠机制相关参数 --- enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统") insomnia_trigger_delay_minutes: List[int] = Field( - default_factory=lambda:[30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)" + default_factory=lambda: [30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)" ) insomnia_duration_minutes: List[int] = Field( - default_factory=lambda:[15, 45], description="单次失眠状态的持续时间范围(分钟)" + default_factory=lambda: [15, 45], description="单次失眠状态的持续时间范围(分钟)" ) sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值") deep_sleep_threshold: float = Field(default=80.0, description="进入“深度睡眠”的睡眠压力阈值") @@ -690,6 +724,8 @@ class CrossContextConfig(ValidatedConfigBase): enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能") groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表") + + class CommandConfig(ValidatedConfigBase): """命令系统配置类""" @@ -703,3 +739,34 @@ class PermissionConfig(ValidatedConfigBase): master_users: List[List[str]] = Field( default_factory=list, description="Master用户列表,格式: [[platform, user_id], ...]" ) + + +class AffinityFlowConfig(ValidatedConfigBase): + """亲和流配置类(兴趣度评分和人物关系系统)""" + + # 兴趣评分系统参数 + reply_action_interest_threshold: float = Field(default=0.4, description="回复动作兴趣阈值") + non_reply_action_interest_threshold: float = Field(default=0.2, description="非回复动作兴趣阈值") + high_match_interest_threshold: float = Field(default=0.8, description="高匹配兴趣阈值") + medium_match_interest_threshold: float = Field(default=0.5, description="中匹配兴趣阈值") + low_match_interest_threshold: float = Field(default=0.2, description="低匹配兴趣阈值") + high_match_keyword_multiplier: float = Field(default=1.5, description="高匹配关键词兴趣倍率") + medium_match_keyword_multiplier: float = Field(default=1.2, description="中匹配关键词兴趣倍率") + low_match_keyword_multiplier: float = Field(default=1.0, description="低匹配关键词兴趣倍率") + match_count_bonus: float = Field(default=0.1, description="匹配数关键词加成值") + max_match_bonus: float = Field(default=0.5, description="最大匹配数加成值") + + # 回复决策系统参数 + no_reply_threshold_adjustment: float = Field(default=0.1, description="不回复兴趣阈值调整值") + reply_cooldown_reduction: int = Field(default=2, description="回复后减少的不回复计数") + max_no_reply_count: int = Field(default=5, description="最大不回复计数次数") + + # 综合评分权重 + keyword_match_weight: float = Field(default=0.4, description="兴趣关键词匹配度权重") + mention_bot_weight: float = Field(default=0.3, description="提及bot分数权重") + relationship_weight: float = Field(default=0.3, description="人物关系分数权重") + + # 提及bot相关参数 + mention_bot_adjustment_threshold: float = Field(default=0.3, description="提及bot后的调整阈值") + mention_bot_interest_score: float = Field(default=0.6, description="提及bot的兴趣分") + base_relationship_score: float = Field(default=0.5, description="基础人物关系分") diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index a2e0f2621..a4a106387 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -64,6 +64,9 @@ class Individuality: else: logger.error("人设构建失败") + # 初始化智能兴趣系统 + await self._initialize_smart_interest_system(personality_result, identity_result) + # 如果任何一个发生变化,都需要清空数据库中的info_list(因为这影响整体人设) if personality_changed or identity_changed: logger.info("将清空数据库中原有的关键词缓存") @@ -75,6 +78,21 @@ class Individuality: } await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data) + async def _initialize_smart_interest_system(self, personality_result: str, identity_result: str): + """初始化智能兴趣系统""" + # 组合完整的人设描述 + full_personality = f"{personality_result},{identity_result}" + + # 获取全局兴趣评分系统实例 + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system as interest_scoring_system + + # 初始化智能兴趣系统 + await interest_scoring_system.initialize_smart_interests( + personality_description=full_personality, personality_id=self.bot_person_id + ) + + logger.info("智能兴趣系统初始化完成") + async def get_personality_block(self) -> str: bot_name = global_config.bot.nickname if global_config.bot.alias_names: diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index bf23f144a..34949e968 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -145,9 +145,9 @@ class LLMUsageRecorder: LLM使用情况记录器(SQLAlchemy版本) """ - @staticmethod async def record_usage_to_database( - model_info: ModelInfo, + self, + model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, @@ -161,7 +161,7 @@ class LLMUsageRecorder: session = None try: # 使用 SQLAlchemy 会话创建记录 - async with get_db_session() as session: + with get_db_session() as session: usage_record = LLMUsage( model_name=model_info.model_identifier, model_assign_name=model_info.name, @@ -179,7 +179,7 @@ class LLMUsageRecorder: ) session.add(usage_record) - await session.commit() + session.commit() logger.debug( f"Token使用情况 - 模型: {model_usage.model_name}, " diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 3efa9cd2d..cf2a7cb1c 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -# -*- coding: utf-8 -*- """ @desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。 它被设计为一个高度容错和可扩展的系统,包含以下主要组件: @@ -892,7 +891,7 @@ class LLMRequest: max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, ) - self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions") + await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions") if not response.content and not response.tool_calls: if raise_when_empty: @@ -917,14 +916,14 @@ class LLMRequest: embedding_input=embedding_input ) - self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings") + await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings") if not response.embedding: raise RuntimeError("获取embedding失败") return response.embedding, model_info.name - def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str): + async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str): """ 记录模型使用情况。 diff --git a/src/main.py b/src/main.py index 734502271..1ff96935c 100644 --- a/src/main.py +++ b/src/main.py @@ -1,35 +1,40 @@ # 再用这个就写一行注释来混提交的我直接全部🌿飞😡 import asyncio +import time import signal import sys -import time +from functools import partial +import traceback +from typing import Dict, Any from maim_message import MessageServer -from rich.traceback import install -from src.chat.emoji_system.emoji_manager import get_emoji_manager -from src.chat.memory_system.Hippocampus import hippocampus_manager -from src.chat.message_receive.bot import chat_bot -from src.chat.message_receive.chat_stream import get_chat_manager -from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask -from src.common.logger import get_logger -# 导入消息API和traceback模块 -from src.common.message import get_global_api from src.common.remote import TelemetryHeartBeatTask -from src.common.server import get_global_server, Server -from src.config.config import global_config -from src.individuality.individuality import get_individuality, Individuality from src.manager.async_task_manager import async_task_manager +from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask +from src.chat.emoji_system.emoji_manager import get_emoji_manager +from src.chat.message_receive.chat_stream import get_chat_manager +from src.config.config import global_config +from src.chat.message_receive.bot import chat_bot +from src.common.logger import get_logger +from src.individuality.individuality import get_individuality, Individuality +from src.common.server import get_global_server, Server from src.mood.mood_manager import mood_manager -from src.plugin_system.base.component_types import EventType +from rich.traceback import install +from src.schedule.schedule_manager import schedule_manager +from src.schedule.monthly_plan_manager import monthly_plan_manager from src.plugin_system.core.event_manager import event_manager -from src.plugin_system.core.plugin_hot_reload import hot_reload_manager +from src.plugin_system.base.component_types import EventType +# from src.api.main import start_api_server + # 导入新的插件管理器和热重载管理器 from src.plugin_system.core.plugin_manager import plugin_manager -from src.schedule.monthly_plan_manager import monthly_plan_manager -from src.schedule.schedule_manager import schedule_manager +from src.plugin_system.core.plugin_hot_reload import hot_reload_manager -# from src.api.main import start_api_server +# 导入消息API和traceback模块 +from src.common.message import get_global_api + +from src.chat.memory_system.Hippocampus import hippocampus_manager if not global_config.memory.enable_memory: import src.chat.memory_system.Hippocampus as hippocampus_module @@ -38,11 +43,7 @@ if not global_config.memory.enable_memory: def initialize(self): pass - async def initialize_async(self): - pass - - @staticmethod - def get_hippocampus(): + def get_hippocampus(self): return None async def build_memory(self): @@ -54,9 +55,9 @@ if not global_config.memory.enable_memory: async def consolidate_memory(self): pass - @staticmethod async def get_memory_from_text( - text: str, + self, + text: str, max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3, @@ -64,24 +65,20 @@ if not global_config.memory.enable_memory: ) -> list: return [] - @staticmethod async def get_memory_from_topic( - valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 + self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 ) -> list: return [] - @staticmethod async def get_activate_from_text( - text: str, max_depth: int = 3, fast_retrieval: bool = False + self, text: str, max_depth: int = 3, fast_retrieval: bool = False ) -> tuple[float, list[str]]: return 0.0, [] - @staticmethod - def get_memory_from_keyword(keyword: str, max_depth: int = 2) -> list: + def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: return [] - @staticmethod - def get_all_node_names() -> list: + def get_all_node_names(self) -> list: return [] hippocampus_module.hippocampus_manager = MockHippocampusManager() @@ -93,6 +90,20 @@ install(extra_lines=3) logger = get_logger("main") +def _task_done_callback(task: asyncio.Task, message_id: str, start_time: float): + """后台任务完成时的回调函数""" + end_time = time.time() + duration = end_time - start_time + try: + task.result() # 如果任务有异常,这里会重新抛出 + logger.debug(f"消息 {message_id} 的后台任务 (ID: {id(task)}) 已成功完成, 耗时: {duration:.2f}s") + except asyncio.CancelledError: + logger.warning(f"消息 {message_id} 的后台任务 (ID: {id(task)}) 被取消, 耗时: {duration:.2f}s") + except Exception: + logger.error(f"处理消息 {message_id} 的后台任务 (ID: {id(task)}) 出现未捕获的异常, 耗时: {duration:.2f}s:") + logger.error(traceback.format_exc()) + + class MainSystem: def __init__(self): self.hippocampus_manager = hippocampus_manager @@ -117,15 +128,28 @@ class MainSystem: signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - @staticmethod - def _cleanup(): + def _cleanup(self): """清理资源""" + try: + # 停止消息管理器 + from src.chat.message_manager import message_manager + import asyncio + + loop = asyncio.get_event_loop() + if loop.is_running(): + asyncio.create_task(message_manager.stop()) + else: + loop.run_until_complete(message_manager.stop()) + logger.info("🛑 消息管理器已停止") + except Exception as e: + logger.error(f"停止消息管理器时出错: {e}") + try: # 停止消息重组器 from src.plugin_system.core.event_manager import event_manager from src.plugin_system import EventType - import asyncio - asyncio.run(event_manager.trigger_event(EventType.ON_STOP,permission_group="SYSTEM")) + + asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM")) from src.utils.message_chunker import reassembler loop = asyncio.get_event_loop() @@ -159,6 +183,20 @@ class MainSystem: except Exception as e: logger.error(f"停止记忆管理器时出错: {e}") + async def _message_process_wrapper(self, message_data: Dict[str, Any]): + """并行处理消息的包装器""" + try: + start_time = time.time() + message_id = message_data.get("message_info", {}).get("message_id", "UNKNOWN") + # 创建后台任务 + task = asyncio.create_task(chat_bot.message_process(message_data)) + logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})") + # 添加一个回调函数,当任务完成时,它会被调用 + task.add_done_callback(partial(_task_done_callback, message_id=message_id, start_time=start_time)) + except Exception: + logger.error("在创建消息处理任务时发生严重错误:") + logger.error(traceback.format_exc()) + async def initialize(self): """初始化系统组件""" logger.info(f"正在唤醒{global_config.bot.nickname}......") @@ -211,7 +249,7 @@ MoFox_Bot(第三方修改版) # 添加统计信息输出任务 await async_task_manager.add_task(StatisticOutputTask()) - + # 添加遥测心跳任务 await async_task_manager.add_task(TelemetryHeartBeatTask()) @@ -223,7 +261,6 @@ MoFox_Bot(第三方修改版) from src.plugin_system.apis.permission_api import permission_api permission_manager = PermissionManager() - await permission_manager.initialize() permission_api.set_permission_manager(permission_manager) logger.info("权限管理器初始化成功") @@ -244,6 +281,18 @@ MoFox_Bot(第三方修改版) get_emoji_manager().initialize() logger.info("表情包管理器初始化成功") + # 初始化回复后关系追踪系统 + try: + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system + from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker + + relationship_tracker = ChatterRelationshipTracker(interest_scoring_system=chatter_interest_scoring_system) + chatter_interest_scoring_system.relationship_tracker = relationship_tracker + logger.info("回复后关系追踪系统初始化成功") + except Exception as e: + logger.error(f"回复后关系追踪系统初始化失败: {e}") + relationship_tracker = None + # 启动情绪管理器 await mood_manager.start() logger.info("情绪管理器初始化成功") @@ -256,11 +305,12 @@ MoFox_Bot(第三方修改版) logger.info("聊天管理器初始化成功") # 初始化记忆系统 - await self.hippocampus_manager.initialize_async() + self.hippocampus_manager.initialize() logger.info("记忆系统初始化成功") # 初始化LPMM知识库 from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge + initialize_lpmm_knowledge() logger.info("LPMM知识库初始化成功") @@ -276,7 +326,7 @@ MoFox_Bot(第三方修改版) # await asyncio.sleep(0.5) #防止logger输出飞了 # 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中 - self.app.register_message_handler(chat_bot.message_process) + self.app.register_message_handler(self._message_process_wrapper) # 启动消息重组器的清理任务 from src.utils.message_chunker import reassembler @@ -284,6 +334,12 @@ MoFox_Bot(第三方修改版) await reassembler.start_cleanup_task() logger.info("消息重组器已启动") + # 启动消息管理器 + from src.chat.message_manager import message_manager + + await message_manager.start() + logger.info("消息管理器已启动") + # 初始化个体特征 await self.individuality.initialize() @@ -291,7 +347,7 @@ MoFox_Bot(第三方修改版) if global_config.planning_system.monthly_plan_enable: logger.info("正在初始化月度计划管理器...") try: - await monthly_plan_manager.initialize() + await monthly_plan_manager.start_monthly_plan_generation() logger.info("月度计划管理器初始化成功") except Exception as e: logger.error(f"月度计划管理器初始化失败: {e}") @@ -299,7 +355,8 @@ MoFox_Bot(第三方修改版) # 初始化日程管理器 if global_config.planning_system.schedule_enable: logger.info("日程表功能已启用,正在初始化管理器...") - await schedule_manager.initialize() + await schedule_manager.load_or_generate_today_schedule() + await schedule_manager.start_daily_schedule_generation() logger.info("日程表管理器初始化成功。") try: diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 5138a7d5d..caba99ad6 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -5,6 +5,7 @@ import time from src.common.logger import get_logger from src.config.config import global_config, model_config from src.chat.message_receive.message import MessageRecv +from src.common.data_models.database_data_model import DatabaseMessages from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive @@ -65,7 +66,7 @@ class ChatMood: self.last_change_time: float = 0 - async def update_mood_by_message(self, message: MessageRecv, interested_rate: float): + async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float): # 如果当前聊天处于失眠状态,则锁定情绪,不允许更新 if self.chat_id in mood_manager.insomnia_chats: logger.debug(f"{self.log_prefix} 处于失眠状态,情绪已锁定,跳过更新。") @@ -73,7 +74,13 @@ class ChatMood: self.regression_count = 0 - during_last_time = message.message_info.time - self.last_change_time # type: ignore + # 处理不同类型的消息对象 + if isinstance(message, MessageRecv): + message_time = message.message_info.time + else: # DatabaseMessages + message_time = message.time + + during_last_time = message_time - self.last_change_time base_probability = 0.05 time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time)) @@ -96,16 +103,14 @@ class ChatMood: logger.debug( f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}" ) - - message_time: float = message.message_info.time # type: ignore - message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, limit=int(global_config.chat.max_context_size / 3), limit_mode="last", ) - chat_talking_prompt = await build_readable_messages( + chat_talking_prompt = build_readable_messages( message_list_before_now, replace_bot_name=True, merge_messages=False, @@ -135,26 +140,26 @@ class ChatMood: prompt=prompt, temperature=0.7 ) if global_config.debug.show_prompt: - logger.info(f"{self.log_prefix} prompt: {prompt}") - logger.info(f"{self.log_prefix} response: {response}") - logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}") + logger.debug(f"{self.log_prefix} prompt: {prompt}") + logger.debug(f"{self.log_prefix} response: {response}") + logger.debug(f"{self.log_prefix} reasoning_content: {reasoning_content}") logger.info(f"{self.log_prefix} 情绪状态更新为: {response}") self.mood_state = response self.last_change_time = message_time - + async def regress_mood(self): message_time = time.time() - message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, limit=15, limit_mode="last", ) - chat_talking_prompt = await build_readable_messages( + chat_talking_prompt = build_readable_messages( message_list_before_now, replace_bot_name=True, merge_messages=False, @@ -185,9 +190,9 @@ class ChatMood: ) if global_config.debug.show_prompt: - logger.info(f"{self.log_prefix} prompt: {prompt}") - logger.info(f"{self.log_prefix} response: {response}") - logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}") + logger.debug(f"{self.log_prefix} prompt: {prompt}") + logger.debug(f"{self.log_prefix} response: {response}") + logger.debug(f"{self.log_prefix} reasoning_content: {reasoning_content}") logger.info(f"{self.log_prefix} 情绪状态转变为: {response}") diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index f5bf8a515..cc84b8967 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -94,10 +94,51 @@ class PersonInfoManager: if "-" in platform: platform = platform.split("-")[1] - + # 在此处打一个补丁,如果platform为qq,尝试生成id后检查是否存在,如果不存在,则将平台换为napcat后再次检查,如果存在,则更新原id为platform为qq的id components = [platform, str(user_id)] key = "_".join(components) - return hashlib.md5(key.encode()).hexdigest() + + # 如果不是 qq 平台,直接返回计算的 id + if platform != "qq": + return hashlib.md5(key.encode()).hexdigest() + + qq_id = hashlib.md5(key.encode()).hexdigest() + + # 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回 + def _db_check_and_migrate_sync(p_id: str, raw_user_id: str): + try: + with get_db_session() as session: + # 检查 qq_id 是否存在 + existing_qq = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + if existing_qq: + return p_id + + # 如果 qq_id 不存在,尝试使用 napcat 作为平台生成对应 id 并检查 + nap_components = ["napcat", str(raw_user_id)] + nap_key = "_".join(nap_components) + nap_id = hashlib.md5(nap_key.encode()).hexdigest() + + existing_nap = session.execute(select(PersonInfo).where(PersonInfo.person_id == nap_id)).scalar() + if not existing_nap: + # napcat 也不存在,返回 qq_id(未命中) + return p_id + + # napcat 存在,迁移该记录:更新 person_id 与 platform -> qq + try: + # 更新现有 napcat 记录 + existing_nap.person_id = p_id + existing_nap.platform = "qq" + existing_nap.user_id = str(raw_user_id) + session.commit() + return p_id + except Exception: + session.rollback() + return p_id + except Exception as e: + logger.error(f"检查/迁移 napcat->qq 时出错: {e}") + return p_id + + return _db_check_and_migrate_sync(qq_id, user_id) async def is_person_known(self, platform: str, user_id: int): """判断是否认识某人""" @@ -127,7 +168,28 @@ class PersonInfoManager: except Exception as e: logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}") return "" - + + @staticmethod + async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str): + """判断是否认识某人""" + person_id = PersonInfoManager.get_person_id(platform, user_id) + # 生成唯一的 person_name + person_info_manager = get_person_info_manager() + unique_nickname = await person_info_manager._generate_unique_person_name(user_nickname) + data = { + "platform": platform, + "user_id": user_id, + "nickname": user_nickname, + "konw_time": int(time.time()), + "person_name": unique_nickname, # 使用唯一的 person_name + } + # 先创建用户基本信息,使用安全创建方法避免竞态条件 + await person_info_manager._safe_create_person_info(person_id=person_id, data=data) + # 更新昵称 + await person_info_manager.update_one_field( + person_id=person_id, field_name="nickname", value=user_nickname, data=data + ) + @staticmethod async def create_person_info(person_id: str, data: Optional[dict] = None): """创建一个项""" @@ -155,16 +217,16 @@ class PersonInfoManager: # Ensure person_id is correctly set from the argument final_data["person_id"] = person_id # 你们的英文注释是何意味? - + # 检查并修复关键字段为None的情况喵 if final_data.get("user_id") is None: logger.warning(f"user_id为None,使用'unknown'作为默认值 person_id={person_id}") final_data["user_id"] = "unknown" - + if final_data.get("platform") is None: logger.warning(f"platform为None,使用'unknown'作为默认值 person_id={person_id}") final_data["platform"] = "unknown" - + # 这里的目的是为了防止在识别出错的情况下有一个最小回退,不只是针对@消息识别成视频后的报错问题 # Serialize JSON fields @@ -215,12 +277,12 @@ class PersonInfoManager: # Ensure person_id is correctly set from the argument final_data["person_id"] = person_id - + # 检查并修复关键字段为None的情况 if final_data.get("user_id") is None: logger.warning(f"user_id为None,使用'unknown'作为默认值 person_id={person_id}") final_data["user_id"] = "unknown" - + if final_data.get("platform") is None: logger.warning(f"platform为None,使用'unknown'作为默认值 person_id={person_id}") final_data["platform"] = "unknown" @@ -315,12 +377,12 @@ class PersonInfoManager: creation_data["platform"] = data["platform"] if data and "user_id" in data: creation_data["user_id"] = data["user_id"] - + # 额外检查关键字段,如果为None则使用默认值 if creation_data.get("user_id") is None: logger.warning(f"创建用户时user_id为None,使用'unknown'作为默认值 person_id={person_id}") creation_data["user_id"] = "unknown" - + if creation_data.get("platform") is None: logger.warning(f"创建用户时platform为None,使用'unknown'作为默认值 person_id={person_id}") creation_data["platform"] = "unknown" diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 4b25f6b14..89632dd73 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -94,90 +94,144 @@ class RelationshipFetcher: if not self.info_fetched_cache[person_id]: del self.info_fetched_cache[person_id] - async def build_relation_info(self, person_id, points_num=3): + async def build_relation_info(self, person_id, points_num=5): + """构建详细的人物关系信息,包含从数据库中查询的丰富关系描述""" # 清理过期的信息缓存 self._cleanup_expired_cache() person_info_manager = get_person_info_manager() - person_info = await person_info_manager.get_values( - person_id, ["person_name", "short_impression", "nickname", "platform", "points"] - ) - person_name = person_info.get("person_name") - short_impression = person_info.get("short_impression") - nickname_str = person_info.get("nickname") - platform = person_info.get("platform") + person_name = await person_info_manager.get_value(person_id, "person_name") + short_impression = await person_info_manager.get_value(person_id, "short_impression") + full_impression = await person_info_manager.get_value(person_id, "impression") + attitude = await person_info_manager.get_value(person_id, "attitude") or 50 - if person_name == nickname_str and not short_impression: - return "" + nickname_str = await person_info_manager.get_value(person_id, "nickname") + platform = await person_info_manager.get_value(person_id, "platform") + know_times = await person_info_manager.get_value(person_id, "know_times") or 0 + know_since = await person_info_manager.get_value(person_id, "know_since") + last_know = await person_info_manager.get_value(person_id, "last_know") - current_points = person_info.get("points") - if isinstance(current_points, str): - current_points = orjson.loads(current_points) + # 如果用户没有基本信息,返回默认描述 + if person_name == nickname_str and not short_impression and not full_impression: + return f"你完全不认识{person_name},这是你们第一次交流。" + + # 获取用户特征点 + current_points = await person_info_manager.get_value(person_id, "points") or [] + forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or [] + + # 按时间排序并选择最有代表性的特征点 + all_points = current_points + forgotten_points + if all_points: + # 按权重和时效性综合排序 + all_points.sort( + key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True + ) + selected_points = all_points[:points_num] + points_text = "\n".join([f"- {point[0]}({point[2]})" for point in selected_points if len(point) > 2]) else: - current_points = current_points or [] + points_text = "" - # 按时间排序forgotten_points - current_points.sort(key=lambda x: x[2]) - # 按权重加权随机抽取最多3个不重复的points,point[1]的值在1-10之间,权重越高被抽到概率越大 - if len(current_points) > points_num: - # point[1] 取值范围1-10,直接作为权重 - weights = [max(1, min(10, int(point[1]))) for point in current_points] - # 使用加权采样不放回,保证不重复 - indices = list(range(len(current_points))) - points = [] - for _ in range(points_num): - if not indices: - break - sub_weights = [weights[i] for i in indices] - chosen_idx = random.choices(indices, weights=sub_weights, k=1)[0] - points.append(current_points[chosen_idx]) - indices.remove(chosen_idx) + # 构建详细的关系描述 + relation_parts = [] + + # 1. 基本信息 + if nickname_str and person_name != nickname_str: + relation_parts.append(f"用户{person_name}在{platform}平台的昵称是{nickname_str}") + + # 2. 认识时间和频率 + if know_since: + from datetime import datetime + + know_time = datetime.fromtimestamp(know_since).strftime("%Y年%m月%d日") + relation_parts.append(f"你从{know_time}开始认识{person_name}") + + if know_times > 0: + relation_parts.append(f"你们已经交流过{int(know_times)}次") + + if last_know: + from datetime import datetime + + last_time = datetime.fromtimestamp(last_know).strftime("%m月%d日") + relation_parts.append(f"最近一次交流是在{last_time}") + + # 3. 态度和印象 + attitude_desc = self._get_attitude_description(attitude) + relation_parts.append(f"你对{person_name}的态度是{attitude_desc}") + + if short_impression: + relation_parts.append(f"你对ta的总体印象:{short_impression}") + + if full_impression: + relation_parts.append(f"更详细的了解:{full_impression}") + + # 4. 特征点和记忆 + if points_text: + relation_parts.append(f"你记得关于{person_name}的一些事情:\n{points_text}") + + # 5. 从UserRelationships表获取额外关系信息 + try: + from src.common.database.sqlalchemy_database_api import db_query + from src.common.database.sqlalchemy_models import UserRelationships + + # 查询用户关系数据 + relationships = await db_query( + UserRelationships, + filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))], + limit=1, + ) + + if relationships: + rel_data = relationships[0] + if rel_data.relationship_text: + relation_parts.append(f"关系记录:{rel_data.relationship_text}") + if rel_data.relationship_score: + score_desc = self._get_relationship_score_description(rel_data.relationship_score) + relation_parts.append(f"关系亲密程度:{score_desc}") + + except Exception as e: + logger.debug(f"查询UserRelationships表失败: {e}") + + # 构建最终的关系信息字符串 + if relation_parts: + relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join( + [f"• {part}" for part in relation_parts] + ) else: - points = current_points - - # 构建points文本 - points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points]) - - nickname_str = "" - if person_name != nickname_str: - nickname_str = f"(ta在{platform}上的昵称是{nickname_str})" - - relation_info = "" - - if short_impression and relation_info: - if points_text: - relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}。具体来说:{relation_info}。你还记得ta最近做的事:{points_text}" - else: - relation_info = ( - f"你对{person_name}的印象是{nickname_str}:{short_impression}。具体来说:{relation_info}" - ) - elif short_impression: - if points_text: - relation_info = ( - f"你对{person_name}的印象是{nickname_str}:{short_impression}。你还记得ta最近做的事:{points_text}" - ) - else: - relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}" - elif relation_info: - if points_text: - relation_info = ( - f"你对{person_name}的了解{nickname_str}:{relation_info}。你还记得ta最近做的事:{points_text}" - ) - else: - relation_info = f"你对{person_name}的了解{nickname_str}:{relation_info}" - elif points_text: - relation_info = f"你记得{person_name}{nickname_str}最近做的事:{points_text}" - else: - relation_info = "" + relation_info = f"你对{person_name}了解不多,这是比较初步的交流。" return relation_info + def _get_attitude_description(self, attitude: int) -> str: + """根据态度分数返回描述性文字""" + if attitude >= 80: + return "非常喜欢和欣赏" + elif attitude >= 60: + return "比较有好感" + elif attitude >= 40: + return "中立态度" + elif attitude >= 20: + return "有些反感" + else: + return "非常厌恶" + + def _get_relationship_score_description(self, score: float) -> str: + """根据关系分数返回描述性文字""" + if score >= 0.8: + return "非常亲密的好友" + elif score >= 0.6: + return "关系不错的朋友" + elif score >= 0.4: + return "普通熟人" + elif score >= 0.2: + return "认识但不熟悉" + else: + return "陌生人" + async def _build_fetch_query(self, person_id, target_message, chat_history): nickname_str = ",".join(global_config.bot.alias_names) name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。" person_info_manager = get_person_info_manager() - person_info = await person_info_manager.get_values(person_id, ["person_name"]) - person_name: str = person_info.get("person_name") # type: ignore + person_name: str = await person_info_manager.get_value(person_id, "person_name") # type: ignore info_cache_block = self._build_info_cache_block() @@ -259,8 +313,7 @@ class RelationshipFetcher: person_info_manager = get_person_info_manager() # 首先检查 info_list 缓存 - person_info = await person_info_manager.get_values(person_id, ["info_list"]) - info_list = person_info.get("info_list") or [] + info_list = await person_info_manager.get_value(person_id, "info_list") or [] cached_info = None # 查找对应的 info_type @@ -287,9 +340,8 @@ class RelationshipFetcher: # 如果缓存中没有,尝试从用户档案中提取 try: - person_info = await person_info_manager.get_values(person_id, ["impression", "points"]) - person_impression = person_info.get("impression") - points = person_info.get("points") + person_impression = await person_info_manager.get_value(person_id, "impression") + points = await person_info_manager.get_value(person_id, "points") # 构建印象信息块 if person_impression: @@ -381,8 +433,7 @@ class RelationshipFetcher: person_info_manager = get_person_info_manager() # 获取现有的 info_list - person_info = await person_info_manager.get_values(person_id, ["info_list"]) - info_list = person_info.get("info_list") or [] + info_list = await person_info_manager.get_value(person_id, "info_list") or [] # 查找是否已存在相同 info_type 的记录 found_index = -1 diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index e74044866..dcdce65e2 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -121,6 +121,13 @@ async def generate_reply( if not extra_info and action_data: extra_info = action_data.get("extra_info", "") + # 如果action_data中有thinking,添加到extra_info中 + if action_data and (thinking := action_data.get("thinking")): + if extra_info: + extra_info += f"\n\n思考过程:{thinking}" + else: + extra_info = f"思考过程:{thinking}" + # 调用回复器生成回复 success, llm_response_dict, prompt = await replyer.generate_reply_with_context( reply_to=reply_to, diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 390397a22..b21bd6b3e 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -80,7 +80,7 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa message_info = { "platform": message_dict.get("chat_info_platform", ""), - "message_id": message_dict.get("message_id"), + "message_id": message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id"), "time": message_dict.get("time"), "group_info": group_info, "user_info": user_info, @@ -89,15 +89,16 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa "template_info": template_info, } - message_dict = { + new_message_dict = { "message_info": message_info, "raw_message": message_dict.get("processed_plain_text"), "processed_plain_text": message_dict.get("processed_plain_text"), } - message_recv = MessageRecv(message_dict) + message_recv = MessageRecv(new_message_dict) logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}") + logger.info(message_recv) return message_recv @@ -246,7 +247,7 @@ async def text_to_stream( typing: bool = False, reply_to: str = "", reply_to_message: Optional[Dict[str, Any]] = None, - set_reply: bool = False, + set_reply: bool = True, storage_message: bool = True, ) -> bool: """向指定流发送文本消息 @@ -275,7 +276,7 @@ async def text_to_stream( async def emoji_to_stream( - emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False + emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = True ) -> bool: """向指定流发送表情包 @@ -293,7 +294,7 @@ async def emoji_to_stream( async def image_to_stream( - image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False + image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = True ) -> bool: """向指定流发送图片 @@ -315,7 +316,7 @@ async def command_to_stream( stream_id: str, storage_message: bool = True, display_message: str = "", - set_reply: bool = False, + set_reply: bool = True, ) -> bool: """向指定流发送命令 @@ -340,7 +341,7 @@ async def custom_to_stream( typing: bool = False, reply_to: str = "", reply_to_message: Optional[Dict[str, Any]] = None, - set_reply: bool = False, + set_reply: bool = True, storage_message: bool = True, show_log: bool = True, ) -> bool: diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 9400032f8..725619adb 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -93,7 +93,6 @@ class BaseAction(ABC): self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy() self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL) - # ============================================================================= # 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解) # ============================================================================= @@ -398,6 +397,7 @@ class BaseAction(ABC): try: # 1. 从注册中心获取Action类 from src.plugin_system.core.component_registry import component_registry + action_class = component_registry.get_component_class(action_name, ComponentType.ACTION) if not action_class: logger.error(f"{log_prefix} 未找到Action: {action_name}") @@ -406,7 +406,7 @@ class BaseAction(ABC): # 2. 准备实例化参数 # 复用当前Action的大部分上下文信息 called_action_data = action_data if action_data is not None else self.action_data - + component_info = component_registry.get_component_info(action_name, ComponentType.ACTION) if not component_info: logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}") diff --git a/src/plugin_system/base/base_chatter.py b/src/plugin_system/base/base_chatter.py new file mode 100644 index 000000000..1bdb79c31 --- /dev/null +++ b/src/plugin_system/base/base_chatter.py @@ -0,0 +1,55 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, TYPE_CHECKING +from src.common.data_models.message_manager_data_model import StreamContext +from .component_types import ChatType +from src.plugin_system.base.component_types import ChatterInfo, ComponentType + +if TYPE_CHECKING: + from src.chat.planner_actions.action_manager import ChatterActionManager + from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner + +class BaseChatter(ABC): + chatter_name: str = "" + """Chatter组件的名称""" + chatter_description: str = "" + """Chatter组件的描述""" + chat_types: List[ChatType] = [ChatType.PRIVATE, ChatType.GROUP] + + def __init__(self, stream_id: str, action_manager: 'ChatterActionManager'): + """ + 初始化聊天处理器 + + Args: + stream_id: 聊天流ID + action_manager: 动作管理器 + """ + self.stream_id = stream_id + self.action_manager = action_manager + + @abstractmethod + async def execute(self, context: StreamContext) -> dict: + """ + 执行聊天处理流程 + + Args: + context: StreamContext对象,包含聊天流的所有消息信息 + + Returns: + 处理结果字典 + """ + pass + + @classmethod + def get_chatter_info(cls) -> "ChatterInfo": + """从类属性生成ChatterInfo + Returns: + ChatterInfo对象 + """ + + return ChatterInfo( + name=cls.chatter_name, + description=cls.chatter_description or "No description provided.", + chat_type_allow=cls.chat_types[0], + component_type=ComponentType.CHATTER, + ) + diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 2bcdca8c5..212634d5d 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -73,7 +73,7 @@ class BaseCommand(ABC): return True # 检查是否为群聊消息 - is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message + is_group = self.message.message_info.group_info if self.chat_type_allow == ChatType.GROUP and is_group: return True diff --git a/src/plugin_system/base/base_events_handler.py b/src/plugin_system/base/base_events_handler.py index 6b8ed1d73..517de92c2 100644 --- a/src/plugin_system/base/base_events_handler.py +++ b/src/plugin_system/base/base_events_handler.py @@ -98,7 +98,7 @@ class BaseEventHandler(ABC): weight=cls.weight, intercept_message=cls.intercept_message, ) - + def set_plugin_name(self, plugin_name: str) -> None: """设置插件名称 @@ -107,9 +107,9 @@ class BaseEventHandler(ABC): """ self.plugin_name = plugin_name - def set_plugin_config(self,plugin_config) -> None: + def set_plugin_config(self, plugin_config) -> None: self.plugin_config = plugin_config - + def get_config(self, key: str, default=None): """获取插件配置值,支持嵌套键访问 diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 0bcb0060e..3fc943bd5 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -17,6 +17,7 @@ class ComponentType(Enum): TOOL = "tool" # 工具组件 SCHEDULER = "scheduler" # 定时任务组件(预留) EVENT_HANDLER = "event_handler" # 事件处理组件 + CHATTER = "chatter" # 聊天处理器组件 def __str__(self) -> str: return self.value @@ -39,8 +40,8 @@ class ActionActivationType(Enum): # 聊天模式枚举 class ChatMode(Enum): """聊天模式枚举""" - - FOCUS = "focus" # Focus聊天模式 + + FOCUS = "focus" # 专注模式 NORMAL = "normal" # Normal聊天模式 PROACTIVE = "proactive" # 主动思考模式 PRIORITY = "priority" # 优先级聊天模式 @@ -54,8 +55,8 @@ class ChatMode(Enum): class ChatType(Enum): """聊天类型枚举,用于限制插件在不同聊天环境中的使用""" - GROUP = "group" # 仅群聊可用 PRIVATE = "private" # 仅私聊可用 + GROUP = "group" # 仅群聊可用 ALL = "all" # 群聊和私聊都可用 def __str__(self): @@ -69,7 +70,7 @@ class EventType(Enum): """ ON_START = "on_start" # 启动事件,用于调用按时任务 - ON_STOP ="on_stop" + ON_STOP = "on_stop" ON_MESSAGE = "on_message" ON_PLAN = "on_plan" POST_LLM = "post_llm" @@ -210,6 +211,17 @@ class EventHandlerInfo(ComponentInfo): self.component_type = ComponentType.EVENT_HANDLER +@dataclass +class ChatterInfo(ComponentInfo): + """聊天处理器组件信息""" + + chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型 + + def __post_init__(self): + super().__post_init__() + self.component_type = ComponentType.CHATTER + + @dataclass class EventInfo(ComponentInfo): """事件组件信息""" diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 0f64aa9ec..63594c53e 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -1,7 +1,7 @@ from pathlib import Path import re -from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Any, Pattern, Tuple, Union, Type from src.common.logger import get_logger from src.plugin_system.base.component_types import ( @@ -11,14 +11,17 @@ from src.plugin_system.base.component_types import ( CommandInfo, PlusCommandInfo, EventHandlerInfo, + ChatterInfo, PluginInfo, ComponentType, ) + from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.plus_command import PlusCommand +from src.plugin_system.base.base_chatter import BaseChatter logger = get_logger("component_registry") @@ -31,42 +34,45 @@ class ComponentRegistry: def __init__(self): # 命名空间式组件名构成法 f"{component_type}.{component_name}" - self._plus_command_registry: Dict[str, Type[PlusCommand]] = {} - self._components: Dict[str, ComponentInfo] = {} + self._components: Dict[str, 'ComponentInfo'] = {} """组件注册表 命名空间式组件名 -> 组件信息""" - self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType} + self._components_by_type: Dict['ComponentType', Dict[str, 'ComponentInfo']] = {types: {} for types in ComponentType} """类型 -> 组件原名称 -> 组件信息""" self._components_classes: Dict[ - str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler, PlusCommand]] + str, Type[Union['BaseCommand', 'BaseAction', 'BaseTool', 'BaseEventHandler', 'PlusCommand', 'BaseChatter']] ] = {} """命名空间式组件名 -> 组件类""" # 插件注册表 - self._plugins: Dict[str, PluginInfo] = {} + self._plugins: Dict[str, 'PluginInfo'] = {} """插件名 -> 插件信息""" # Action特定注册表 - self._action_registry: Dict[str, Type[BaseAction]] = {} + self._action_registry: Dict[str, Type['BaseAction']] = {} """Action注册表 action名 -> action类""" - self._default_actions: Dict[str, ActionInfo] = {} + self._default_actions: Dict[str, 'ActionInfo'] = {} """默认动作集,即启用的Action集,用于重置ActionManager状态""" # Command特定注册表 - self._command_registry: Dict[str, Type[BaseCommand]] = {} + self._command_registry: Dict[str, Type['BaseCommand']] = {} """Command类注册表 command名 -> command类""" self._command_patterns: Dict[Pattern, str] = {} """编译后的正则 -> command名""" # 工具特定注册表 - self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类 - self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类 + self._tool_registry: Dict[str, Type['BaseTool']] = {} # 工具名 -> 工具类 + self._llm_available_tools: Dict[str, Type['BaseTool']] = {} # llm可用的工具名 -> 工具类 # EventHandler特定注册表 - self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} + self._event_handler_registry: Dict[str, Type['BaseEventHandler']] = {} """event_handler名 -> event_handler类""" - self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {} + self._enabled_event_handlers: Dict[str, Type['BaseEventHandler']] = {} """启用的事件处理器 event_handler名 -> event_handler类""" + self._chatter_registry: Dict[str, Type['BaseChatter']] = {} + """chatter名 -> chatter类""" + self._enabled_chatter_registry: Dict[str, Type['BaseChatter']] = {} + """启用的chatter名 -> chatter类""" logger.info("组件注册中心初始化完成") # == 注册方法 == @@ -93,7 +99,7 @@ class ComponentRegistry: def register_component( self, component_info: ComponentInfo, - component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]], + component_class: Type[Union['BaseCommand', 'BaseAction', 'BaseEventHandler', 'BaseTool', 'BaseChatter']], ) -> bool: """注册组件 @@ -151,6 +157,10 @@ class ComponentRegistry: assert isinstance(component_info, EventHandlerInfo) assert issubclass(component_class, BaseEventHandler) ret = self._register_event_handler_component(component_info, component_class) + case ComponentType.CHATTER: + assert isinstance(component_info, ChatterInfo) + assert issubclass(component_class, BaseChatter) + ret = self._register_chatter_component(component_info, component_class) case _: logger.warning(f"未知组件类型: {component_type}") @@ -162,7 +172,7 @@ class ComponentRegistry: ) return True - def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool: + def _register_action_component(self, action_info: 'ActionInfo', action_class: Type['BaseAction']) -> bool: """注册Action组件到Action特定注册表""" if not (action_name := action_info.name): logger.error(f"Action组件 {action_class.__name__} 必须指定名称") @@ -182,7 +192,7 @@ class ComponentRegistry: return True - def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool: + def _register_command_component(self, command_info: 'CommandInfo', command_class: Type['BaseCommand']) -> bool: """注册Command组件到Command特定注册表""" if not (command_name := command_info.name): logger.error(f"Command组件 {command_class.__name__} 必须指定名称") @@ -209,7 +219,7 @@ class ComponentRegistry: return True def _register_plus_command_component( - self, plus_command_info: PlusCommandInfo, plus_command_class: Type[PlusCommand] + self, plus_command_info: 'PlusCommandInfo', plus_command_class: Type['PlusCommand'] ) -> bool: """注册PlusCommand组件到特定注册表""" plus_command_name = plus_command_info.name @@ -223,7 +233,7 @@ class ComponentRegistry: # 创建专门的PlusCommand注册表(如果还没有) if not hasattr(self, "_plus_command_registry"): - self._plus_command_registry: Dict[str, Type[PlusCommand]] = {} + self._plus_command_registry: Dict[str, Type['PlusCommand']] = {} plus_command_class.plugin_name = plus_command_info.plugin_name # 设置插件配置 @@ -233,7 +243,7 @@ class ComponentRegistry: logger.debug(f"已注册PlusCommand组件: {plus_command_name}") return True - def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool: + def _register_tool_component(self, tool_info: 'ToolInfo', tool_class: Type['BaseTool']) -> bool: """注册Tool组件到Tool特定注册表""" tool_name = tool_info.name @@ -249,7 +259,7 @@ class ComponentRegistry: return True def _register_event_handler_component( - self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler] + self, handler_info: 'EventHandlerInfo', handler_class: Type['BaseEventHandler'] ) -> bool: if not (handler_name := handler_info.name): logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称") @@ -271,11 +281,38 @@ class ComponentRegistry: # 使用EventManager进行事件处理器注册 from src.plugin_system.core.event_manager import event_manager - return event_manager.register_event_handler(handler_class,self.get_plugin_config(handler_info.plugin_name) or {}) + return event_manager.register_event_handler( + handler_class, self.get_plugin_config(handler_info.plugin_name) or {} + ) + + def _register_chatter_component(self, chatter_info: 'ChatterInfo', chatter_class: Type['BaseChatter']) -> bool: + """注册Chatter组件到Chatter特定注册表""" + chatter_name = chatter_info.name + + if not chatter_name: + logger.error(f"Chatter组件 {chatter_class.__name__} 必须指定名称") + return False + if not isinstance(chatter_info, ChatterInfo) or not issubclass(chatter_class, BaseChatter): + logger.error(f"注册失败: {chatter_name} 不是有效的Chatter") + return False + + chatter_class.plugin_name = chatter_info.plugin_name + # 设置插件配置 + chatter_class.plugin_config = self.get_plugin_config(chatter_info.plugin_name) or {} + + self._chatter_registry[chatter_name] = chatter_class + + if not chatter_info.enabled: + logger.warning(f"Chatter组件 {chatter_name} 未启用") + return True # 未启用,但是也是注册成功 + self._enabled_chatter_registry[chatter_name] = chatter_class + + logger.debug(f"已注册Chatter组件: {chatter_name}") + return True # === 组件移除相关 === - async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool: + async def remove_component(self, component_name: str, component_type: 'ComponentType', plugin_name: str) -> bool: target_component_class = self.get_component_class(component_name, component_type) if not target_component_class: logger.warning(f"组件 {component_name} 未注册,无法移除") @@ -323,6 +360,12 @@ class ComponentRegistry: except Exception as e: logger.warning(f"移除EventHandler事件订阅时出错: {e}") + case ComponentType.CHATTER: + # 移除Chatter注册 + if hasattr(self, '_chatter_registry'): + self._chatter_registry.pop(component_name, None) + logger.debug(f"已移除Chatter组件: {component_name}") + case _: logger.warning(f"未知的组件类型: {component_type}") return False @@ -441,8 +484,8 @@ class ComponentRegistry: # === 组件查询方法 === def get_component_info( - self, component_name: str, component_type: Optional[ComponentType] = None - ) -> Optional[ComponentInfo]: + self, component_name: str, component_type: Optional['ComponentType'] = None + ) -> Optional['ComponentInfo']: # sourcery skip: class-extract-method """获取组件信息,支持自动命名空间解析 @@ -486,8 +529,8 @@ class ComponentRegistry: def get_component_class( self, component_name: str, - component_type: Optional[ComponentType] = None, - ) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]: + component_type: Optional['ComponentType'] = None, + ) -> Optional[Union[Type['BaseCommand'], Type['BaseAction'], Type['BaseEventHandler'], Type['BaseTool']]]: """获取组件类,支持自动命名空间解析 Args: @@ -504,7 +547,7 @@ class ComponentRegistry: # 2. 如果指定了组件类型,构造命名空间化的名称查找 if component_type: namespaced_name = f"{component_type.value}.{component_name}" - return self._components_classes.get(namespaced_name) + return self._components_classes.get(namespaced_name) # type: ignore[valid-type] # 3. 如果没有指定类型,尝试在所有命名空间中查找 candidates = [] @@ -529,22 +572,22 @@ class ComponentRegistry: # 4. 都没找到 return None - def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]: + def get_components_by_type(self, component_type: 'ComponentType') -> Dict[str, 'ComponentInfo']: """获取指定类型的所有组件""" return self._components_by_type.get(component_type, {}).copy() - def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]: + def get_enabled_components_by_type(self, component_type: 'ComponentType') -> Dict[str, 'ComponentInfo']: """获取指定类型的所有启用组件""" components = self.get_components_by_type(component_type) return {name: info for name, info in components.items() if info.enabled} # === Action特定查询方法 === - def get_action_registry(self) -> Dict[str, Type[BaseAction]]: + def get_action_registry(self) -> Dict[str, Type['BaseAction']]: """获取Action注册表""" return self._action_registry.copy() - def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]: + def get_registered_action_info(self, action_name: str) -> Optional['ActionInfo']: """获取Action信息""" info = self.get_component_info(action_name, ComponentType.ACTION) return info if isinstance(info, ActionInfo) else None @@ -555,11 +598,11 @@ class ComponentRegistry: # === Command特定查询方法 === - def get_command_registry(self) -> Dict[str, Type[BaseCommand]]: + def get_command_registry(self) -> Dict[str, Type['BaseCommand']]: """获取Command注册表""" return self._command_registry.copy() - def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]: + def get_registered_command_info(self, command_name: str) -> Optional['CommandInfo']: """获取Command信息""" info = self.get_component_info(command_name, ComponentType.COMMAND) return info if isinstance(info, CommandInfo) else None @@ -568,7 +611,7 @@ class ComponentRegistry: """获取Command模式注册表""" return self._command_patterns.copy() - def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]: + def find_command_by_text(self, text: str) -> Optional[Tuple[Type['BaseCommand'], dict, 'CommandInfo']]: # sourcery skip: use-named-expression, use-next """根据文本查找匹配的命令 @@ -595,15 +638,15 @@ class ComponentRegistry: return None # === Tool 特定查询方法 === - def get_tool_registry(self) -> Dict[str, Type[BaseTool]]: + def get_tool_registry(self) -> Dict[str, Type['BaseTool']]: """获取Tool注册表""" return self._tool_registry.copy() - def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]: + def get_llm_available_tools(self) -> Dict[str, Type['BaseTool']]: """获取LLM可用的Tool列表""" return self._llm_available_tools.copy() - def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]: + def get_registered_tool_info(self, tool_name: str) -> Optional['ToolInfo']: """获取Tool信息 Args: @@ -616,13 +659,13 @@ class ComponentRegistry: return info if isinstance(info, ToolInfo) else None # === PlusCommand 特定查询方法 === - def get_plus_command_registry(self) -> Dict[str, Type[PlusCommand]]: + def get_plus_command_registry(self) -> Dict[str, Type['PlusCommand']]: """获取PlusCommand注册表""" if not hasattr(self, "_plus_command_registry"): - pass + self._plus_command_registry: Dict[str, Type[PlusCommand]] = {} return self._plus_command_registry.copy() - def get_registered_plus_command_info(self, command_name: str) -> Optional[PlusCommandInfo]: + def get_registered_plus_command_info(self, command_name: str) -> Optional['PlusCommandInfo']: """获取PlusCommand信息 Args: @@ -636,26 +679,44 @@ class ComponentRegistry: # === EventHandler 特定查询方法 === - def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]: + def get_event_handler_registry(self) -> Dict[str, Type['BaseEventHandler']]: """获取事件处理器注册表""" return self._event_handler_registry.copy() - def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]: + def get_registered_event_handler_info(self, handler_name: str) -> Optional['EventHandlerInfo']: """获取事件处理器信息""" info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER) return info if isinstance(info, EventHandlerInfo) else None - def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]: + def get_enabled_event_handlers(self) -> Dict[str, Type['BaseEventHandler']]: """获取启用的事件处理器""" return self._enabled_event_handlers.copy() + # === Chatter 特定查询方法 === + def get_chatter_registry(self) -> Dict[str, Type['BaseChatter']]: + """获取Chatter注册表""" + if not hasattr(self, '_chatter_registry'): + self._chatter_registry: Dict[str, Type[BaseChatter]] = {} + return self._chatter_registry.copy() + + def get_enabled_chatter_registry(self) -> Dict[str, Type['BaseChatter']]: + """获取启用的Chatter注册表""" + if not hasattr(self, '_enabled_chatter_registry'): + self._enabled_chatter_registry: Dict[str, Type[BaseChatter]] = {} + return self._enabled_chatter_registry.copy() + + def get_registered_chatter_info(self, chatter_name: str) -> Optional['ChatterInfo']: + """获取Chatter信息""" + info = self.get_component_info(chatter_name, ComponentType.CHATTER) + return info if isinstance(info, ChatterInfo) else None + # === 插件查询方法 === - def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]: + def get_plugin_info(self, plugin_name: str) -> Optional['PluginInfo']: """获取插件信息""" return self._plugins.get(plugin_name) - def get_all_plugins(self) -> Dict[str, PluginInfo]: + def get_all_plugins(self) -> Dict[str, 'PluginInfo']: """获取所有插件""" return self._plugins.copy() @@ -663,13 +724,12 @@ class ComponentRegistry: # """获取所有启用的插件""" # return {name: info for name, info in self._plugins.items() if info.enabled} - def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]: + def get_plugin_components(self, plugin_name: str) -> List['ComponentInfo']: """获取插件的所有组件""" plugin_info = self.get_plugin_info(plugin_name) return plugin_info.components if plugin_info else [] - @staticmethod - def get_plugin_config(plugin_name: str) -> dict: + def get_plugin_config(self, plugin_name: str) -> dict: """获取插件配置 Args: @@ -684,19 +744,20 @@ class ComponentRegistry: plugin_instance = plugin_manager.get_plugin_instance(plugin_name) if plugin_instance and plugin_instance.config: return plugin_instance.config - + # 如果插件实例不存在,尝试从配置文件读取 try: import toml + config_path = Path("config") / "plugins" / plugin_name / "config.toml" if config_path.exists(): - with open(config_path, 'r', encoding='utf-8') as f: + with open(config_path, "r", encoding="utf-8") as f: config_data = toml.load(f) logger.debug(f"从配置文件读取插件 {plugin_name} 的配置") return config_data except Exception as e: logger.debug(f"读取插件 {plugin_name} 配置文件失败: {e}") - + return {} def get_registry_stats(self) -> Dict[str, Any]: @@ -706,6 +767,7 @@ class ComponentRegistry: tool_components: int = 0 events_handlers: int = 0 plus_command_components: int = 0 + chatter_components: int = 0 for component in self._components.values(): if component.component_type == ComponentType.ACTION: action_components += 1 @@ -717,12 +779,15 @@ class ComponentRegistry: events_handlers += 1 elif component.component_type == ComponentType.PLUS_COMMAND: plus_command_components += 1 + elif component.component_type == ComponentType.CHATTER: + chatter_components += 1 return { "action_components": action_components, "command_components": command_components, "tool_components": tool_components, "event_handlers": events_handlers, "plus_command_components": plus_command_components, + "chatter_components": chatter_components, "total_components": len(self._components), "total_plugins": len(self._plugins), "components_by_type": { @@ -730,6 +795,8 @@ class ComponentRegistry: }, "enabled_components": len([c for c in self._components.values() if c.enabled]), "enabled_plugins": len([p for p in self._plugins.values() if p.enabled]), + "enabled_components": len([c for c in self._components.values() if c.enabled]), + "enabled_plugins": len([p for p in self._plugins.values() if p.enabled]), } # === 组件移除相关 === diff --git a/src/plugin_system/core/event_manager.py b/src/plugin_system/core/event_manager.py index 4108adad0..dac75b88f 100644 --- a/src/plugin_system/core/event_manager.py +++ b/src/plugin_system/core/event_manager.py @@ -146,7 +146,9 @@ class EventManager: logger.info(f"事件 {event_name} 已禁用") return True - def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool: + def register_event_handler( + self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None + ) -> bool: """注册事件处理器 Args: @@ -168,7 +170,7 @@ class EventManager: # 创建事件处理器实例,传递插件配置 handler_instance = handler_class() handler_instance.plugin_config = plugin_config - if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'): + if plugin_config is not None and hasattr(handler_instance, "set_plugin_config"): handler_instance.set_plugin_config(plugin_config) self._event_handlers[handler_name] = handler_instance diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index e0a39ac25..cc7a54d4c 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -129,9 +129,7 @@ class PluginManager: self._show_plugin_components(plugin_name) # 检查并调用 on_plugin_loaded 钩子(如果存在) - if hasattr(plugin_instance, "on_plugin_loaded") and callable( - plugin_instance.on_plugin_loaded - ): + if hasattr(plugin_instance, "on_plugin_loaded") and callable(plugin_instance.on_plugin_loaded): logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子") try: # 使用 asyncio.create_task 确保它不会阻塞加载流程 @@ -380,13 +378,14 @@ class PluginManager: tool_count = stats.get("tool_components", 0) event_handler_count = stats.get("event_handlers", 0) plus_command_count = stats.get("plus_command_components", 0) + chatter_count = stats.get("chatter_components", 0) total_components = stats.get("total_components", 0) # 📋 显示插件加载总览 if total_registered > 0: logger.info("🎉 插件系统加载完成!") logger.info( - f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count})" + f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count})" ) # 显示详细的插件列表 @@ -442,6 +441,12 @@ class PluginManager: if plus_command_components: plus_command_names = [c.name for c in plus_command_components] logger.info(f" ⚡ PlusCommand组件: {', '.join(plus_command_names)}") + chatter_components = [ + c for c in plugin_info.components if c.component_type == ComponentType.CHATTER + ] + if chatter_components: + chatter_names = [c.name for c in chatter_components] + logger.info(f" 🗣️ Chatter组件: {', '.join(chatter_names)}") if event_handler_components: event_handler_names = [c.name for c in event_handler_components] logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}") diff --git a/src/plugins/built_in/affinity_flow_chatter/README.md b/src/plugins/built_in/affinity_flow_chatter/README.md new file mode 100644 index 000000000..26add6a34 --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/README.md @@ -0,0 +1,125 @@ +# 亲和力聊天处理器插件 + +## 概述 + +这是一个内置的chatter插件,实现了基于亲和力流的智能聊天处理器,具有兴趣度评分和人物关系构建功能。 + +## 功能特性 + +- **智能兴趣度评分**: 自动识别和评估用户兴趣话题 +- **人物关系系统**: 根据互动历史建立和维持用户关系 +- **多聊天类型支持**: 支持私聊和群聊场景 +- **插件化架构**: 完全集成到插件系统中 + +## 组件架构 + +### BaseChatter (抽象基类) +- 位置: `src/plugin_system/base/base_chatter.py` +- 功能: 定义所有chatter组件的基础接口 +- 必须实现的方法: `execute(context: StreamContext) -> dict` + +### ChatterManager (管理器) +- 位置: `src/chat/chatter_manager.py` +- 功能: 管理和调度所有chatter组件 +- 特性: 自动从插件系统注册和发现chatter组件 + +### AffinityChatter (具体实现) +- 位置: `src/plugins/built_in/chatter/affinity_chatter.py` +- 功能: 亲和力流聊天处理器的具体实现 +- 支持的聊天类型: PRIVATE, GROUP + +## 使用方法 + +### 1. 基本使用 + +```python +from src.chat.chatter_manager import ChatterManager +from src.chat.planner_actions.action_manager import ChatterActionManager + +# 初始化 +action_manager = ChatterActionManager() +chatter_manager = ChatterManager(action_manager) + +# 处理消息流 +result = await chatter_manager.process_stream_context(stream_id, context) +``` + +### 2. 创建自定义Chatter + +```python +from src.plugin_system.base.base_chatter import BaseChatter +from src.plugin_system.base.component_types import ChatType, ComponentType +from src.plugin_system.base.component_types import ChatterInfo + +class CustomChatter(BaseChatter): + chat_types = [ChatType.PRIVATE] # 只支持私聊 + + async def execute(self, context: StreamContext) -> dict: + # 实现你的聊天逻辑 + return {"success": True, "message": "处理完成"} + +# 在插件中注册 +async def on_load(self): + chatter_info = ChatterInfo( + name="custom_chatter", + component_type=ComponentType.CHATTER, + description="自定义聊天处理器", + enabled=True, + plugin_name=self.name, + chat_type_allow=ChatType.PRIVATE + ) + + ComponentRegistry.register_component( + component_info=chatter_info, + component_class=CustomChatter + ) +``` + +## 配置 + +### 插件配置文件 +- 位置: `src/plugins/built_in/chatter/_manifest.json` +- 包含插件信息和组件配置 + +### 聊天类型 +- `PRIVATE`: 私聊 +- `GROUP`: 群聊 +- `ALL`: 所有类型 + +## 核心概念 + +### 1. 兴趣值系统 +- 自动识别同类话题 +- 兴趣值会根据聊天频率增减 +- 支持新话题的自动学习 + +### 2. 人物关系系统 +- 根据互动质量建立关系分 +- 不同关系分对应不同的回复风格 +- 支持情感化的交流 + +### 3. 执行流程 +1. 接收StreamContext +2. 使用ActionPlanner进行规划 +3. 执行相应的Action +4. 返回处理结果 + +## 扩展开发 + +### 添加新的Chatter类型 +1. 继承BaseChatter类 +2. 实现execute方法 +3. 在插件中注册组件 +4. 配置支持的聊天类型 + +### 集成现有功能 +- 使用ActionPlanner进行动作规划 +- 通过ActionManager执行动作 +- 利用现有的记忆和知识系统 + +## 注意事项 + +1. 所有chatter组件必须实现`execute`方法 +2. 插件注册时需要指定支持的聊天类型 +3. 组件名称不能包含点号(.) +4. 确保在插件卸载时正确清理资源 \ No newline at end of file diff --git a/src/plugins/built_in/affinity_flow_chatter/__init__.py b/src/plugins/built_in/affinity_flow_chatter/__init__.py new file mode 100644 index 000000000..bc8ebb733 --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/__init__.py @@ -0,0 +1,7 @@ +""" +亲和力聊天处理器插件 +""" + +from .plugin import AffinityChatterPlugin + +__all__ = ["AffinityChatterPlugin"] diff --git a/src/plugins/built_in/affinity_flow_chatter/_manifest.json b/src/plugins/built_in/affinity_flow_chatter/_manifest.json new file mode 100644 index 000000000..253365b87 --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/_manifest.json @@ -0,0 +1,23 @@ +{ + "manifest_version": 1, + "name": "affinity_chatter", + "display_name": "Affinity Flow Chatter", + "description": "Built-in chatter plugin for affinity flow with interest scoring and relationship building", + "version": "1.0.0", + "author": "MoFox", + "plugin_class": "AffinityChatterPlugin", + "enabled": true, + "is_built_in": true, + "components": [ + { + "name": "affinity_chatter", + "type": "chatter", + "description": "Affinity flow chatter with intelligent interest scoring and relationship building", + "enabled": true, + "chat_type_allow": ["all"] + } + ], + "host_application": { "min_version": "0.8.0" }, + "keywords": ["chatter", "affinity", "conversation"], + "categories": ["Chat", "AI"] +} \ No newline at end of file diff --git a/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py b/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py new file mode 100644 index 000000000..08f5f7098 --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py @@ -0,0 +1,236 @@ +""" +亲和力聊天处理器 +基于现有的AffinityFlowChatter重构为插件化组件 +""" + +import asyncio +import time +import traceback +from datetime import datetime +from typing import Dict, Any + +from src.plugin_system.base.base_chatter import BaseChatter +from src.plugin_system.base.component_types import ChatType +from src.common.data_models.message_manager_data_model import StreamContext +from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner +from src.chat.planner_actions.action_manager import ChatterActionManager +from src.common.logger import get_logger +from src.chat.express.expression_learner import expression_learner_manager + +logger = get_logger("affinity_chatter") + +# 定义颜色 +SOFT_GREEN = "\033[38;5;118m" # 一个更柔和的绿色 +RESET_COLOR = "\033[0m" + + +class AffinityChatter(BaseChatter): + """亲和力聊天处理器""" + + chatter_name: str = "AffinityChatter" + chatter_description: str = "基于亲和力模型的智能聊天处理器,支持多种聊天类型" + chat_types: list[ChatType] = [ChatType.ALL] # 支持所有聊天类型 + + def __init__(self, stream_id: str, action_manager: ChatterActionManager): + """ + 初始化亲和力聊天处理器 + + Args: + stream_id: 聊天流ID + planner: 动作规划器 + action_manager: 动作管理器 + """ + super().__init__(stream_id, action_manager) + self.planner = ChatterActionPlanner(stream_id, action_manager) + + # 处理器统计 + self.stats = { + "messages_processed": 0, + "plans_created": 0, + "actions_executed": 0, + "successful_executions": 0, + "failed_executions": 0, + } + self.last_activity_time = time.time() + + async def execute(self, context: StreamContext) -> dict: + """ + 处理StreamContext对象 + + Args: + context: StreamContext对象,包含聊天流的所有消息信息 + + Returns: + 处理结果字典 + """ + try: + # 触发表达学习 + learner = expression_learner_manager.get_expression_learner(self.stream_id) + asyncio.create_task(learner.trigger_learning_for_chat()) + + unread_messages = context.get_unread_messages() + + # 使用增强版规划器处理消息 + actions, target_message = await self.planner.plan(context=context) + self.stats["plans_created"] += 1 + + # 执行动作(如果规划器返回了动作) + execution_result = {"executed_count": len(actions) if actions else 0} + if actions: + logger.debug(f"聊天流 {self.stream_id} 生成了 {len(actions)} 个动作") + + # 更新统计 + self.stats["messages_processed"] += 1 + self.stats["actions_executed"] += execution_result.get("executed_count", 0) + self.stats["successful_executions"] += 1 + self.last_activity_time = time.time() + + result = { + "success": True, + "stream_id": self.stream_id, + "plan_created": True, + "actions_count": len(actions) if actions else 0, + "has_target_message": target_message is not None, + "unread_messages_processed": len(unread_messages), + **execution_result, + } + + logger.debug( + f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}" + ) + + return result + + except Exception as e: + logger.error(f"亲和力聊天处理器 {self.stream_id} 处理StreamContext时出错: {e}\n{traceback.format_exc()}") + self.stats["failed_executions"] += 1 + self.last_activity_time = time.time() + + return { + "success": False, + "stream_id": self.stream_id, + "error_message": str(e), + "executed_count": 0, + } + + def get_stats(self) -> Dict[str, Any]: + """ + 获取处理器统计信息 + + Returns: + 统计信息字典 + """ + return self.stats.copy() + + def get_planner_stats(self) -> Dict[str, Any]: + """ + 获取规划器统计信息 + + Returns: + 规划器统计信息字典 + """ + return self.planner.get_planner_stats() + + def get_interest_scoring_stats(self) -> Dict[str, Any]: + """ + 获取兴趣度评分统计信息 + + Returns: + 兴趣度评分统计信息字典 + """ + return self.planner.get_interest_scoring_stats() + + def get_relationship_stats(self) -> Dict[str, Any]: + """ + 获取用户关系统计信息 + + Returns: + 用户关系统计信息字典 + """ + return self.planner.get_relationship_stats() + + def get_current_mood_state(self) -> str: + """ + 获取当前聊天的情绪状态 + + Returns: + 当前情绪状态描述 + """ + return self.planner.get_current_mood_state() + + def get_mood_stats(self) -> Dict[str, Any]: + """ + 获取情绪状态统计信息 + + Returns: + 情绪状态统计信息字典 + """ + return self.planner.get_mood_stats() + + def get_user_relationship(self, user_id: str) -> float: + """ + 获取用户关系分 + + Args: + user_id: 用户ID + + Returns: + 用户关系分 (0.0-1.0) + """ + return self.planner.get_user_relationship(user_id) + + def update_interest_keywords(self, new_keywords: dict): + """ + 更新兴趣关键词 + + Args: + new_keywords: 新的兴趣关键词字典 + """ + self.planner.update_interest_keywords(new_keywords) + logger.info(f"聊天流 {self.stream_id} 已更新兴趣关键词: {list(new_keywords.keys())}") + + def reset_stats(self): + """重置统计信息""" + self.stats = { + "messages_processed": 0, + "plans_created": 0, + "actions_executed": 0, + "successful_executions": 0, + "failed_executions": 0, + } + + def is_active(self, max_inactive_minutes: int = 60) -> bool: + """ + 检查处理器是否活跃 + + Args: + max_inactive_minutes: 最大不活跃分钟数 + + Returns: + 是否活跃 + """ + current_time = time.time() + max_inactive_seconds = max_inactive_minutes * 60 + return (current_time - self.last_activity_time) < max_inactive_seconds + + def get_activity_time(self) -> float: + """ + 获取最后活动时间 + + Returns: + 最后活动时间戳 + """ + return self.last_activity_time + + def __str__(self) -> str: + """字符串表示""" + return f"AffinityChatter(stream_id={self.stream_id}, messages={self.stats['messages_processed']})" + + def __repr__(self) -> str: + """详细字符串表示""" + return ( + f"AffinityChatter(stream_id={self.stream_id}, " + f"messages_processed={self.stats['messages_processed']}, " + f"plans_created={self.stats['plans_created']}, " + f"last_activity={datetime.fromtimestamp(self.last_activity_time)})" + ) diff --git a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py new file mode 100644 index 000000000..0538090bc --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py @@ -0,0 +1,333 @@ +""" +兴趣度评分系统 +基于多维度评分机制,包括兴趣匹配度、用户关系分、提及度和时间因子 +现在使用embedding计算智能兴趣匹配 +""" + +import traceback +from typing import Dict, List, Any + +from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.info_data_model import InterestScore +from src.chat.interest_system import bot_interest_manager +from src.common.logger import get_logger +from src.config.config import global_config +from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker +logger = get_logger("chatter_interest_scoring") + +# 定义颜色 +SOFT_BLUE = "\033[38;5;67m" +RESET_COLOR = "\033[0m" + + +class ChatterInterestScoringSystem: + """兴趣度评分系统""" + + def __init__(self): + # 智能兴趣匹配配置 + self.use_smart_matching = True + + # 从配置加载评分权重 + affinity_config = global_config.affinity_flow + self.score_weights = { + "interest_match": affinity_config.keyword_match_weight, # 兴趣匹配度权重 + "relationship": affinity_config.relationship_weight, # 关系分权重 + "mentioned": affinity_config.mention_bot_weight, # 是否提及bot权重 + } + + # 评分阈值 + self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值 + self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值 + + # 连续不回复概率提升 + self.no_reply_count = 0 + self.max_no_reply_count = affinity_config.max_no_reply_count + self.probability_boost_per_no_reply = ( + affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count + ) # 每次不回复增加的概率 + + # 用户关系数据 + self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score + + async def calculate_interest_scores( + self, messages: List[DatabaseMessages], bot_nickname: str + ) -> List[InterestScore]: + """计算消息的兴趣度评分""" + user_messages = [msg for msg in messages if str(msg.user_info.user_id) != str(global_config.bot.qq_account)] + if not user_messages: + return [] + + scores = [] + for _, msg in enumerate(user_messages, 1): + score = await self._calculate_single_message_score(msg, bot_nickname) + scores.append(score) + + return scores + + async def _calculate_single_message_score(self, message: DatabaseMessages, bot_nickname: str) -> InterestScore: + """计算单条消息的兴趣度评分""" + + keywords = self._extract_keywords_from_database(message) + interest_match_score = await self._calculate_interest_match_score(message.processed_plain_text, keywords) + relationship_score = self._calculate_relationship_score(message.user_info.user_id) + mentioned_score = self._calculate_mentioned_score(message, bot_nickname) + + total_score = ( + interest_match_score * self.score_weights["interest_match"] + + relationship_score * self.score_weights["relationship"] + + mentioned_score * self.score_weights["mentioned"] + ) + + details = { + "interest_match": f"兴趣匹配: {interest_match_score:.3f}", + "relationship": f"关系: {relationship_score:.3f}", + "mentioned": f"提及: {mentioned_score:.3f}", + } + + logger.debug( + f"消息得分详情: {total_score:.3f} (匹配: {interest_match_score:.2f}, 关系: {relationship_score:.2f}, 提及: {mentioned_score:.2f})" + ) + + return InterestScore( + message_id=message.message_id, + total_score=total_score, + interest_match_score=interest_match_score, + relationship_score=relationship_score, + mentioned_score=mentioned_score, + details=details, + ) + + async def _calculate_interest_match_score(self, content: str, keywords: List[str] = None) -> float: + """计算兴趣匹配度 - 使用智能embedding匹配""" + if not content: + return 0.0 + + # 使用智能匹配(embedding) + if self.use_smart_matching and bot_interest_manager.is_initialized: + return await self._calculate_smart_interest_match(content, keywords) + else: + # 智能匹配未初始化,返回默认分数 + return 0.3 + + async def _calculate_smart_interest_match(self, content: str, keywords: List[str] = None) -> float: + """使用embedding计算智能兴趣匹配""" + try: + # 如果没有传入关键词,则提取 + if not keywords: + keywords = self._extract_keywords_from_content(content) + + # 使用机器人兴趣管理器计算匹配度 + match_result = await bot_interest_manager.calculate_interest_match(content, keywords) + + if match_result: + # 返回匹配分数,考虑置信度和匹配标签数量 + affinity_config = global_config.affinity_flow + match_count_bonus = min( + len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus + ) + final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus + return final_score + else: + return 0.0 + + except Exception as e: + logger.error(f"智能兴趣匹配计算失败: {e}") + return 0.0 + + def _extract_keywords_from_database(self, message: DatabaseMessages) -> List[str]: + """从数据库消息中提取关键词""" + keywords = [] + + # 尝试从 key_words 字段提取(存储的是JSON字符串) + if message.key_words: + try: + import orjson + + keywords = orjson.loads(message.key_words) + if not isinstance(keywords, list): + keywords = [] + except (orjson.JSONDecodeError, TypeError): + keywords = [] + + # 如果没有 keywords,尝试从 key_words_lite 提取 + if not keywords and message.key_words_lite: + try: + import orjson + + keywords = orjson.loads(message.key_words_lite) + if not isinstance(keywords, list): + keywords = [] + except (orjson.JSONDecodeError, TypeError): + keywords = [] + + # 如果还是没有,从消息内容中提取(降级方案) + if not keywords: + keywords = self._extract_keywords_from_content(message.processed_plain_text) + + return keywords[:15] # 返回前15个关键词 + + def _extract_keywords_from_content(self, content: str) -> List[str]: + """从内容中提取关键词(降级方案)""" + import re + + # 清理文本 + content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字 + words = content.split() + + # 过滤和关键词提取 + keywords = [] + for word in words: + word = word.strip() + if ( + len(word) >= 2 # 至少2个字符 + and word.isalnum() # 字母数字 + and not word.isdigit() + ): # 不是纯数字 + keywords.append(word.lower()) + + # 去重并限制数量 + unique_keywords = list(set(keywords)) + return unique_keywords[:10] # 返回前10个唯一关键词 + + def _calculate_relationship_score(self, user_id: str) -> float: + """计算关系分 - 从数据库获取关系分""" + # 优先使用内存中的关系分 + if user_id in self.user_relationships: + relationship_value = self.user_relationships[user_id] + return min(relationship_value, 1.0) + + # 如果内存中没有,尝试从关系追踪器获取 + if hasattr(self, "relationship_tracker") and self.relationship_tracker: + try: + relationship_score = self.relationship_tracker.get_user_relationship_score(user_id) + # 同时更新内存缓存 + self.user_relationships[user_id] = relationship_score + return relationship_score + except Exception: + pass + else: + # 尝试从全局关系追踪器获取 + try: + from .relationship_tracker import ChatterRelationshipTracker + + global_tracker = ChatterRelationshipTracker() + if global_tracker: + relationship_score = global_tracker.get_user_relationship_score(user_id) + # 同时更新内存缓存 + self.user_relationships[user_id] = relationship_score + return relationship_score + except Exception: + pass + + # 默认新用户的基础分 + return global_config.affinity_flow.base_relationship_score + + def _calculate_mentioned_score(self, msg: DatabaseMessages, bot_nickname: str) -> float: + """计算提及分数""" + if not msg.processed_plain_text: + return 0.0 + + # 检查是否被提及 + bot_aliases = [bot_nickname] + global_config.bot.alias_names + is_mentioned = msg.is_mentioned or any(alias in msg.processed_plain_text for alias in bot_aliases if alias) + + # 如果被提及或是私聊,都视为提及了bot + if is_mentioned or not hasattr(msg, "chat_info_group_id"): + return global_config.affinity_flow.mention_bot_interest_score + + return 0.0 + + def should_reply(self, score: InterestScore, message: "DatabaseMessages") -> bool: + """判断是否应该回复""" + base_threshold = self.reply_threshold + + # 如果被提及,降低阈值 + if score.mentioned_score >= global_config.affinity_flow.mention_bot_adjustment_threshold: + base_threshold = self.mention_threshold + + # 计算连续不回复的概率提升 + probability_boost = min(self.no_reply_count * self.probability_boost_per_no_reply, 0.8) + effective_threshold = base_threshold - probability_boost + + # 做出决策 + should_reply = score.total_score >= effective_threshold + decision = "回复" if should_reply else "不回复" + logger.info( + f"{SOFT_BLUE}决策: {decision} (兴趣度: {score.total_score:.3f} / 阈值: {effective_threshold:.3f}){RESET_COLOR}" + ) + + return should_reply, score.total_score + + def record_reply_action(self, did_reply: bool): + """记录回复动作""" + old_count = self.no_reply_count + if did_reply: + self.no_reply_count = max(0, self.no_reply_count - global_config.affinity_flow.reply_cooldown_reduction) + action = "回复" + else: + self.no_reply_count += 1 + action = "不回复" + + # 限制最大计数 + self.no_reply_count = min(self.no_reply_count, self.max_no_reply_count) + logger.info(f"动作: {action}, 连续不回复次数: {old_count} -> {self.no_reply_count}") + + def update_user_relationship(self, user_id: str, relationship_change: float): + """更新用户关系""" + old_score = self.user_relationships.get( + user_id, global_config.affinity_flow.base_relationship_score + ) # 默认新用户分数 + new_score = max(0.0, min(1.0, old_score + relationship_change)) + + self.user_relationships[user_id] = new_score + + logger.info(f"用户关系: {user_id} | {old_score:.3f} → {new_score:.3f}") + + def get_user_relationship(self, user_id: str) -> float: + """获取用户关系分""" + return self.user_relationships.get(user_id, 0.3) + + def get_scoring_stats(self) -> Dict: + """获取评分系统统计""" + return { + "no_reply_count": self.no_reply_count, + "max_no_reply_count": self.max_no_reply_count, + "reply_threshold": self.reply_threshold, + "mention_threshold": self.mention_threshold, + "user_relationships": len(self.user_relationships), + } + + def reset_stats(self): + """重置统计信息""" + self.no_reply_count = 0 + logger.info("重置兴趣度评分系统统计") + + async def initialize_smart_interests(self, personality_description: str, personality_id: str = "default"): + """初始化智能兴趣系统""" + try: + logger.info("开始初始化智能兴趣系统...") + logger.info(f"人设ID: {personality_id}, 描述长度: {len(personality_description)}") + + await bot_interest_manager.initialize(personality_description, personality_id) + logger.info("智能兴趣系统初始化完成。") + + # 显示初始化后的统计信息 + bot_interest_manager.get_interest_stats() + + except Exception as e: + logger.error(f"初始化智能兴趣系统失败: {e}") + traceback.print_exc() + + def get_matching_config(self) -> Dict[str, Any]: + """获取匹配配置信息""" + return { + "use_smart_matching": self.use_smart_matching, + "smart_system_initialized": bot_interest_manager.is_initialized, + "smart_system_stats": bot_interest_manager.get_interest_stats() + if bot_interest_manager.is_initialized + else None, + } + + +# 创建全局兴趣评分系统实例 +chatter_interest_scoring_system = ChatterInterestScoringSystem() diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py new file mode 100644 index 000000000..3aa1a28c0 --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py @@ -0,0 +1,368 @@ +""" +PlanExecutor: 接收 Plan 对象并执行其中的所有动作。 +集成用户关系追踪机制,自动记录交互并更新关系。 +""" + +import asyncio +import time +from typing import Dict, List + +from src.config.config import global_config +from src.chat.planner_actions.action_manager import ChatterActionManager +from src.common.data_models.info_data_model import Plan, ActionPlannerInfo +from src.common.logger import get_logger + +logger = get_logger("plan_executor") + + +class ChatterPlanExecutor: + """ + 增强版PlanExecutor,集成用户关系追踪机制。 + + 功能: + 1. 执行Plan中的所有动作 + 2. 自动记录用户交互并添加到关系追踪 + 3. 分类执行回复动作和其他动作 + 4. 提供完整的执行统计和监控 + """ + + def __init__(self, action_manager: ChatterActionManager): + """ + 初始化增强版PlanExecutor。 + + Args: + action_manager (ChatterActionManager): 用于实际执行各种动作的管理器实例。 + """ + self.action_manager = action_manager + + # 执行统计 + self.execution_stats = { + "total_executed": 0, + "successful_executions": 0, + "failed_executions": 0, + "reply_executions": 0, + "other_action_executions": 0, + "execution_times": [], + } + + # 用户关系追踪引用 + self.relationship_tracker = None + + def set_relationship_tracker(self, relationship_tracker): + """设置关系追踪器""" + self.relationship_tracker = relationship_tracker + + async def execute(self, plan: Plan) -> Dict[str, any]: + """ + 遍历并执行Plan对象中`decided_actions`列表里的所有动作。 + + Args: + plan (Plan): 包含待执行动作列表的Plan对象。 + + Returns: + Dict[str, any]: 执行结果统计信息 + """ + if not plan.decided_actions: + logger.info("没有需要执行的动作。") + return {"executed_count": 0, "results": []} + + # 像hfc一样,提前打印将要执行的动作 + action_types = [action.action_type for action in plan.decided_actions] + logger.info(f"选择动作: {', '.join(action_types) if action_types else '无'}") + + execution_results = [] + reply_actions = [] + other_actions = [] + + # 分类动作:回复动作和其他动作 + for action_info in plan.decided_actions: + if action_info.action_type in ["reply", "proactive_reply"]: + reply_actions.append(action_info) + else: + other_actions.append(action_info) + + # 执行回复动作(优先执行) + if reply_actions: + reply_result = await self._execute_reply_actions(reply_actions, plan) + execution_results.extend(reply_result["results"]) + self.execution_stats["reply_executions"] += len(reply_actions) + + # 将其他动作放入后台任务执行,避免阻塞主流程 + if other_actions: + asyncio.create_task(self._execute_other_actions(other_actions, plan)) + logger.info(f"已将 {len(other_actions)} 个其他动作放入后台任务执行。") + # 注意:后台任务的结果不会立即计入本次返回的统计数据 + + # 更新总体统计 + self.execution_stats["total_executed"] += len(plan.decided_actions) + successful_count = sum(1 for r in execution_results if r["success"]) + self.execution_stats["successful_executions"] += successful_count + self.execution_stats["failed_executions"] += len(execution_results) - successful_count + + logger.info( + f"规划执行完成: 总数={len(plan.decided_actions)}, 成功={successful_count}, 失败={len(execution_results) - successful_count}" + ) + + return { + "executed_count": len(plan.decided_actions), + "successful_count": successful_count, + "failed_count": len(execution_results) - successful_count, + "results": execution_results, + } + + async def _execute_reply_actions(self, reply_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]: + """执行回复动作""" + results = [] + + for action_info in reply_actions: + result = await self._execute_single_reply_action(action_info, plan) + results.append(result) + + return {"results": results} + + async def _execute_single_reply_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]: + """执行单个回复动作""" + start_time = time.time() + success = False + error_message = "" + reply_content = "" + + try: + logger.info(f"执行回复动作: {action_info.action_type} (原因: {action_info.reasoning})") + + # 获取用户ID - 兼容对象和字典 + if hasattr(action_info.action_message, "user_info"): + user_id = action_info.action_message.user_info.user_id + else: + user_id = action_info.action_message.get("user_info", {}).get("user_id") + + if user_id == str(global_config.bot.qq_account): + logger.warning("尝试回复自己,跳过此动作以防止死循环。") + return { + "action_type": action_info.action_type, + "success": False, + "error_message": "尝试回复自己,跳过此动作以防止死循环。", + "execution_time": 0, + "reasoning": action_info.reasoning, + "reply_content": "", + } + # 构建回复动作参数 + action_params = { + "chat_id": plan.chat_id, + "target_message": action_info.action_message, + "reasoning": action_info.reasoning, + "action_data": action_info.action_data or {}, + } + + logger.debug(f"📬 [PlanExecutor] 准备调用 ActionManager,target_message: {action_info.action_message}") + + # 通过动作管理器执行回复 + reply_content = await self.action_manager.execute_action( + action_name=action_info.action_type, **action_params + ) + + success = True + logger.info(f"回复动作 '{action_info.action_type}' 执行成功。") + + except Exception as e: + error_message = str(e) + logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}") + + # 记录用户关系追踪 + if success and action_info.action_message: + await self._track_user_interaction(action_info, plan, reply_content) + + execution_time = time.time() - start_time + self.execution_stats["execution_times"].append(execution_time) + + return { + "action_type": action_info.action_type, + "success": success, + "error_message": error_message, + "execution_time": execution_time, + "reasoning": action_info.reasoning, + "reply_content": reply_content[:200] + "..." if len(reply_content) > 200 else reply_content, + } + + async def _execute_other_actions(self, other_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]: + """执行其他动作""" + results = [] + + # 并行执行其他动作 + tasks = [] + for action_info in other_actions: + task = self._execute_single_other_action(action_info, plan) + tasks.append(task) + + if tasks: + executed_results = await asyncio.gather(*tasks, return_exceptions=True) + for i, result in enumerate(executed_results): + if isinstance(result, Exception): + logger.error(f"执行动作 {other_actions[i].action_type} 时发生异常: {result}") + results.append( + { + "action_type": other_actions[i].action_type, + "success": False, + "error_message": str(result), + "execution_time": 0, + "reasoning": other_actions[i].reasoning, + } + ) + else: + results.append(result) + + return {"results": results} + + async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]: + """执行单个其他动作""" + start_time = time.time() + success = False + error_message = "" + + try: + logger.info(f"执行其他动作: {action_info.action_type} (原因: {action_info.reasoning})") + + action_data = action_info.action_data or {} + + # 针对 poke_user 动作,特殊处理 + if action_info.action_type == "poke_user": + target_message = action_info.action_message + if target_message: + # 优先直接获取 user_id,这才是最可靠的信息 + user_id = target_message.get("user_id") + if user_id: + action_data["user_id"] = user_id + logger.info(f"检测到戳一戳动作,目标用户ID: {user_id}") + else: + # 如果没有 user_id,再尝试用 user_nickname 作为备用方案 + user_name = target_message.get("user_nickname") + if user_name: + action_data["user_name"] = user_name + logger.info(f"检测到戳一戳动作,目标用户: {user_name}") + else: + logger.warning("无法从戳一戳消息中获取用户ID或昵称。") + + # 传递原始消息ID以支持引用 + action_data["target_message_id"] = target_message.get("message_id") + + # 构建动作参数 + action_params = { + "chat_id": plan.chat_id, + "target_message": action_info.action_message, + "reasoning": action_info.reasoning, + "action_data": action_data, + } + + # 通过动作管理器执行动作 + await self.action_manager.execute_action(action_name=action_info.action_type, **action_params) + + success = True + logger.info(f"其他动作 '{action_info.action_type}' 执行成功。") + + except Exception as e: + error_message = str(e) + logger.error(f"执行其他动作失败: {action_info.action_type}, 错误: {error_message}") + + execution_time = time.time() - start_time + self.execution_stats["execution_times"].append(execution_time) + + return { + "action_type": action_info.action_type, + "success": success, + "error_message": error_message, + "execution_time": execution_time, + "reasoning": action_info.reasoning, + } + + async def _track_user_interaction(self, action_info: ActionPlannerInfo, plan: Plan, reply_content: str): + """追踪用户交互 - 集成回复后关系追踪""" + try: + if not action_info.action_message: + return + + # 获取用户信息 - 处理对象和字典两种情况 + if hasattr(action_info.action_message, "user_info"): + # 对象情况 + user_info = action_info.action_message.user_info + user_id = user_info.user_id + user_name = user_info.user_nickname or user_id + user_message = action_info.action_message.content + else: + # 字典情况 + user_info = action_info.action_message.get("user_info", {}) + user_id = user_info.get("user_id") + user_name = user_info.get("user_nickname") or user_id + user_message = action_info.action_message.get("content", "") + + if not user_id: + logger.debug("跳过追踪:缺少用户ID") + return + + # 如果有设置关系追踪器,执行回复后关系追踪 + if self.relationship_tracker: + # 记录基础交互信息(保持向后兼容) + self.relationship_tracker.add_interaction( + user_id=user_id, + user_name=user_name, + user_message=user_message, + bot_reply=reply_content, + reply_timestamp=time.time(), + ) + + # 执行新的回复后关系追踪 + await self.relationship_tracker.track_reply_relationship( + user_id=user_id, user_name=user_name, bot_reply_content=reply_content, reply_timestamp=time.time() + ) + + logger.debug(f"已执行用户交互追踪: {user_id}") + + except Exception as e: + logger.error(f"追踪用户交互时出错: {e}") + logger.debug(f"action_message类型: {type(action_info.action_message)}") + logger.debug(f"action_message内容: {action_info.action_message}") + + def get_execution_stats(self) -> Dict[str, any]: + """获取执行统计信息""" + stats = self.execution_stats.copy() + + # 计算平均执行时间 + if stats["execution_times"]: + avg_time = sum(stats["execution_times"]) / len(stats["execution_times"]) + stats["average_execution_time"] = avg_time + stats["max_execution_time"] = max(stats["execution_times"]) + stats["min_execution_time"] = min(stats["execution_times"]) + else: + stats["average_execution_time"] = 0 + stats["max_execution_time"] = 0 + stats["min_execution_time"] = 0 + + # 移除执行时间列表以避免返回过大数据 + stats.pop("execution_times", None) + + return stats + + def reset_stats(self): + """重置统计信息""" + self.execution_stats = { + "total_executed": 0, + "successful_executions": 0, + "failed_executions": 0, + "reply_executions": 0, + "other_action_executions": 0, + "execution_times": [], + } + + def get_recent_performance(self, limit: int = 10) -> List[Dict[str, any]]: + """获取最近的执行性能""" + recent_times = self.execution_stats["execution_times"][-limit:] + if not recent_times: + return [] + + return [ + { + "execution_index": i + 1, + "execution_time": time_val, + "timestamp": time.time() - (len(recent_times) - i) * 60, # 估算时间戳 + } + for i, time_val in enumerate(recent_times) + ] diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py new file mode 100644 index 000000000..09d7c5b67 --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py @@ -0,0 +1,678 @@ +""" +PlanFilter: 接收 Plan 对象,根据不同模式的逻辑进行筛选,决定最终要执行的动作。 +""" + +import orjson +import time +import traceback +import re +from datetime import datetime +from typing import Any, Dict, List, Optional + +from json_repair import repair_json + +from src.chat.memory_system.Hippocampus import hippocampus_manager +from src.chat.utils.chat_message_builder import ( + build_readable_actions, + build_readable_messages_with_id, + get_actions_by_timestamp_with_chat, +) +from src.chat.utils.prompt import global_prompt_manager +from src.common.data_models.info_data_model import ActionPlannerInfo, Plan +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest +from src.mood.mood_manager import mood_manager +from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType +from src.schedule.schedule_manager import schedule_manager + +logger = get_logger("plan_filter") + +SAKURA_PINK = "\033[38;5;175m" +SKY_BLUE = "\033[38;5;117m" +RESET_COLOR = "\033[0m" + + +class ChatterPlanFilter: + """ + 根据 Plan 中的模式和信息,筛选并决定最终的动作。 + """ + + def __init__(self, chat_id: str, available_actions: List[str]): + """ + 初始化动作计划筛选器。 + + Args: + chat_id (str): 当前聊天的唯一标识符。 + available_actions (List[str]): 当前可用的动作列表。 + """ + self.chat_id = chat_id + self.available_actions = available_actions + self.planner_llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="planner") + self.last_obs_time_mark = 0.0 + + async def filter(self, reply_not_available: bool, plan: Plan) -> Plan: + """ + 执行筛选逻辑,并填充 Plan 对象的 decided_actions 字段。 + """ + try: + prompt, used_message_id_list = await self._build_prompt(plan) + plan.llm_prompt = prompt + + llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt) + + if llm_content: + try: + parsed_json = orjson.loads(repair_json(llm_content)) + except orjson.JSONDecodeError: + parsed_json = { + "thinking": "", + "actions": {"action_type": "no_action", "reason": "返回内容无法解析为JSON"}, + } + + if "reply" in plan.available_actions and reply_not_available: + # 如果reply动作不可用,但llm返回的仍然有reply,则改为no_reply + if ( + isinstance(parsed_json, dict) + and parsed_json.get("actions", {}).get("action_type", "") == "reply" + ): + parsed_json["actions"]["action_type"] = "no_reply" + elif isinstance(parsed_json, list): + for item in parsed_json: + if isinstance(item, dict) and item.get("actions", {}).get("action_type", "") == "reply": + item["actions"]["action_type"] = "no_reply" + item["actions"]["reason"] += " (但由于兴趣度不足,reply动作不可用,已改为no_reply)" + + if isinstance(parsed_json, dict): + parsed_json = [parsed_json] + + if isinstance(parsed_json, list): + final_actions = [] + reply_action_added = False + # 定义回复类动作的集合,方便扩展 + reply_action_types = {"reply", "proactive_reply"} + + for item in parsed_json: + if not isinstance(item, dict): + continue + + # 预解析 action_type 来进行判断 + thinking = item.get("thinking", "未提供思考过程") + actions_obj = item.get("actions", {}) + + # 处理actions字段可能是字典或列表的情况 + if isinstance(actions_obj, dict): + action_type = actions_obj.get("action_type", "no_action") + elif isinstance(actions_obj, list) and actions_obj: + # 如果是列表,取第一个元素的action_type + first_action = actions_obj[0] + if isinstance(first_action, dict): + action_type = first_action.get("action_type", "no_action") + else: + action_type = "no_action" + else: + action_type = "no_action" + + if action_type in reply_action_types: + if not reply_action_added: + final_actions.extend( + await self._parse_single_action(item, used_message_id_list, plan) + ) + reply_action_added = True + else: + # 非回复类动作直接添加 + final_actions.extend(await self._parse_single_action(item, used_message_id_list, plan)) + + if thinking and thinking != "未提供思考过程": + logger.info(f"\n{SAKURA_PINK}思考: {thinking}{RESET_COLOR}\n") + plan.decided_actions = self._filter_no_actions(final_actions) + + except Exception as e: + logger.error(f"筛选 Plan 时出错: {e}\n{traceback.format_exc()}") + plan.decided_actions = [ActionPlannerInfo(action_type="no_action", reasoning=f"筛选时出错: {e}")] + + # 在返回最终计划前,打印将要执行的动作 + action_types = [action.action_type for action in plan.decided_actions] + logger.info(f"选择动作: [{SKY_BLUE}{', '.join(action_types) if action_types else '无'}{RESET_COLOR}]") + + return plan + + async def _build_prompt(self, plan: Plan) -> tuple[str, list]: + """ + 根据 Plan 对象构建提示词。 + """ + try: + time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + bot_name = global_config.bot.nickname + bot_nickname = ( + f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else "" + ) + bot_core_personality = global_config.personality.personality_core + identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:" + + schedule_block = "" + # 优先检查是否被吵醒 + from src.chat.message_manager.message_manager import message_manager + angry_prompt_addition = "" + wakeup_mgr = message_manager.wakeup_manager + + # 双重检查确保愤怒状态不会丢失 + # 检查1: 直接从 wakeup_manager 获取 + if wakeup_mgr.is_in_angry_state(): + angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition() + + # 检查2: 如果上面没获取到,再从 mood_manager 确认 + if not angry_prompt_addition: + chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id) + if chat_mood_for_check.is_angry_from_wakeup: + angry_prompt_addition = global_config.sleep_system.angry_prompt + + if angry_prompt_addition: + schedule_block = angry_prompt_addition + elif global_config.planning_system.schedule_enable: + if current_activity := schedule_manager.get_current_activity(): + schedule_block = f"你当前正在:{current_activity},但注意它与群聊的聊天无关。" + + mood_block = "" + # 如果被吵醒,则心情也是愤怒的,不需要另外的情绪模块 + if not angry_prompt_addition and global_config.mood.enable_mood: + chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id) + mood_block = f"你现在的心情是:{chat_mood.mood_state}" + + if plan.mode == ChatMode.PROACTIVE: + long_term_memory_block = await self._get_long_term_memory_context() + + chat_content_block, message_id_list = build_readable_messages_with_id( + messages=[msg.flatten() for msg in plan.chat_history], + timestamp_mode="normal", + truncate=False, + show_actions=False, + ) + + prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt") + actions_before_now = get_actions_by_timestamp_with_chat( + chat_id=plan.chat_id, + timestamp_start=time.time() - 3600, + timestamp_end=time.time(), + limit=5, + ) + actions_before_now_block = build_readable_actions(actions=actions_before_now) + actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}" + + prompt = prompt_template.format( + time_block=time_block, + identity_block=identity_block, + schedule_block=schedule_block, + mood_block=mood_block, + long_term_memory_block=long_term_memory_block, + chat_content_block=chat_content_block or "最近没有聊天内容。", + actions_before_now_block=actions_before_now_block, + ) + return prompt, message_id_list + + # 构建已读/未读历史消息 + read_history_block, unread_history_block, message_id_list = await self._build_read_unread_history_blocks( + plan + ) + + # 为了兼容性,保留原有的chat_content_block + chat_content_block, _ = build_readable_messages_with_id( + messages=[msg.flatten() for msg in plan.chat_history], + timestamp_mode="normal", + read_mark=self.last_obs_time_mark, + truncate=True, + show_actions=True, + ) + + actions_before_now = get_actions_by_timestamp_with_chat( + chat_id=plan.chat_id, + timestamp_start=time.time() - 3600, + timestamp_end=time.time(), + limit=5, + ) + + actions_before_now_block = build_readable_actions(actions=actions_before_now) + actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}" + + self.last_obs_time_mark = time.time() + + mentioned_bonus = "" + if global_config.chat.mentioned_bot_inevitable_reply: + mentioned_bonus = "\n- 有人提到你" + if global_config.chat.at_bot_inevitable_reply: + mentioned_bonus = "\n- 有人提到你,或者at你" + + if plan.mode == ChatMode.FOCUS: + no_action_block = """ +动作:no_action +动作描述:不选择任何动作 +{{ + "action": "no_action", + "reason":"不动作的原因" +}} + +动作:no_reply +动作描述:不进行回复,等待合适的回复时机 +- 当你刚刚发送了消息,没有人回复时,选择no_reply +- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply +{{ + "action": "no_reply", + "reason":"不回复的原因" +}} +""" + else: # normal Mode + no_action_block = """重要说明: +- 'reply' 表示只进行普通聊天回复,不执行任何额外动作 +- 其他action表示在普通回复的基础上,执行相应的额外动作 +{{ + "action": "reply", + "target_message_id":"触发action的消息id", + "reason":"回复的原因" +}}""" + + is_group_chat = plan.chat_type == ChatType.GROUP + chat_context_description = "你现在正在一个群聊中" + if not is_group_chat and plan.target_info: + chat_target_name = plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方" + chat_context_description = f"你正在和 {chat_target_name} 私聊" + + action_options_block = await self._build_action_options(plan.available_actions) + + moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。" + + custom_prompt_block = "" + if global_config.custom_prompt.planner_custom_prompt_content: + custom_prompt_block = global_config.custom_prompt.planner_custom_prompt_content + + users_in_chat_str = "" # TODO: Re-implement user list fetching if needed + + planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") + prompt = planner_prompt_template.format( + schedule_block=schedule_block, + mood_block=mood_block, + time_block=time_block, + chat_context_description=chat_context_description, + read_history_block=read_history_block, + unread_history_block=unread_history_block, + actions_before_now_block=actions_before_now_block, + mentioned_bonus=mentioned_bonus, + no_action_block=no_action_block, + action_options_text=action_options_block, + moderation_prompt=moderation_prompt_block, + identity_block=identity_block, + custom_prompt_block=custom_prompt_block, + bot_name=bot_name, + users_in_chat=users_in_chat_str, + ) + return prompt, message_id_list + except Exception as e: + logger.error(f"构建 Planner 提示词时出错: {e}") + logger.error(traceback.format_exc()) + return "构建 Planner Prompt 时出错", [] + + async def _build_read_unread_history_blocks(self, plan: Plan) -> tuple[str, str, list]: + """构建已读/未读历史消息块""" + try: + # 从message_manager获取真实的已读/未读消息 + from src.chat.message_manager.message_manager import message_manager + from src.chat.utils.utils import assign_message_ids + from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat + + # 获取聊天流的上下文 + stream_context = message_manager.stream_contexts.get(plan.chat_id) + + # 获取真正的已读和未读消息 + read_messages = stream_context.history_messages # 已读消息存储在history_messages中 + if not read_messages: + from src.common.data_models.database_data_model import DatabaseMessages + # 如果内存中没有已读消息(比如刚启动),则从数据库加载最近的上下文 + fallback_messages_dicts = get_raw_msg_before_timestamp_with_chat( + chat_id=plan.chat_id, + timestamp=time.time(), + limit=global_config.chat.max_context_size, + ) + # 将字典转换为DatabaseMessages对象 + read_messages = [DatabaseMessages(**msg_dict) for msg_dict in fallback_messages_dicts] + + unread_messages = stream_context.get_unread_messages() # 获取未读消息 + + # 构建已读历史消息块 + if read_messages: + read_content, read_ids = build_readable_messages_with_id( + messages=[msg.flatten() for msg in read_messages[-50:]], # 限制数量 + timestamp_mode="normal_no_YMD", + truncate=False, + show_actions=False, + ) + read_history_block = f"{read_content}" + else: + read_history_block = "暂无已读历史消息" + + # 构建未读历史消息块(包含兴趣度) + if unread_messages: + # 扁平化未读消息用于计算兴趣度和格式化 + flattened_unread = [msg.flatten() for msg in unread_messages] + + # 尝试获取兴趣度评分(返回以真实 message_id 为键的字典) + interest_scores = await self._get_interest_scores_for_messages(flattened_unread) + + # 为未读消息分配短 id(保持与 build_readable_messages_with_id 的一致结构) + message_id_list = assign_message_ids(flattened_unread) + + unread_lines = [] + for idx, msg in enumerate(flattened_unread): + mapped = message_id_list[idx] + synthetic_id = mapped.get("id") + original_msg_id = msg.get("message_id") or msg.get("id") + msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time()))) + user_nickname = msg.get("user_nickname", "未知用户") + msg_content = msg.get("processed_plain_text", "") + + # 不再显示兴趣度,但保留合成ID供模型内部使用 + # 同时,为了让模型更好地理解上下文,我们显示用户名 + unread_lines.append(f"<{synthetic_id}> {msg_time} {user_nickname}: {msg_content}") + + unread_history_block = "\n".join(unread_lines) + else: + unread_history_block = "暂无未读历史消息" + + return read_history_block, unread_history_block, message_id_list + + except Exception as e: + logger.error(f"构建已读/未读历史消息块时出错: {e}") + return "构建已读历史消息时出错", "构建未读历史消息时出错", [] + + async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]: + """为消息获取兴趣度评分""" + interest_scores = {} + + try: + from .interest_scoring import chatter_interest_scoring_system + from src.common.data_models.database_data_model import DatabaseMessages + + # 使用插件内部的兴趣度评分系统计算评分 + for msg_dict in messages: + try: + # 将字典转换为DatabaseMessages对象 + db_message = DatabaseMessages( + message_id=msg_dict.get("message_id", ""), + user_info=msg_dict.get("user_info", {}), + processed_plain_text=msg_dict.get("processed_plain_text", ""), + key_words=msg_dict.get("key_words", "[]"), + is_mentioned=msg_dict.get("is_mentioned", False) + ) + + # 计算消息兴趣度 + interest_score_obj = await chatter_interest_scoring_system._calculate_single_message_score( + message=db_message, + bot_nickname=global_config.bot.nickname + ) + interest_score = interest_score_obj.total_score + + # 构建兴趣度字典 + interest_scores[msg_dict.get("message_id", "")] = interest_score + + except Exception as e: + logger.warning(f"计算消息兴趣度失败: {e}") + continue + + except Exception as e: + logger.warning(f"获取兴趣度评分失败: {e}") + + return interest_scores + + async def _parse_single_action( + self, action_json: dict, message_id_list: list, plan: Plan + ) -> List[ActionPlannerInfo]: + parsed_actions = [] + try: + # 从新的actions结构中获取动作信息 + actions_obj = action_json.get("actions", {}) + + # 处理actions字段可能是字典或列表的情况 + actions_to_process = [] + if isinstance(actions_obj, dict): + actions_to_process.append(actions_obj) + elif isinstance(actions_obj, list): + actions_to_process.extend(actions_obj) + + if not actions_to_process: + actions_to_process.append({"action_type": "no_action", "reason": "actions格式错误"}) + + for single_action_obj in actions_to_process: + if not isinstance(single_action_obj, dict): + continue + + action = single_action_obj.get("action_type", "no_action") + reasoning = single_action_obj.get("reasoning", "未提供原因") # 兼容旧的reason字段 + action_data = single_action_obj.get("action_data", {}) + + # 为了向后兼容,如果action_data不存在,则从顶层字段获取 + if not action_data: + action_data = {k: v for k, v in single_action_obj.items() if k not in ["action_type", "reason", "reasoning", "thinking"]} + + # 保留原始的thinking字段(如果有) + thinking = action_json.get("thinking", "") + if thinking and thinking != "未提供思考过程": + action_data["thinking"] = thinking + + target_message_obj = None + if action not in ["no_action", "no_reply", "do_nothing", "proactive_reply"]: + if target_message_id := action_data.get("target_message_id"): + target_message_dict = self._find_message_by_id(target_message_id, message_id_list) + else: + # 如果LLM没有指定target_message_id,进行特殊处理 + if action == "poke_user": + # 对于poke_user,尝试找到触发它的那条戳一戳消息 + target_message_dict = self._find_poke_notice(message_id_list) + if not target_message_dict: + # 如果找不到,再使用最新消息作为兜底 + target_message_dict = self._get_latest_message(message_id_list) + else: + # 其他动作,默认选择最新的一条消息 + target_message_dict = self._get_latest_message(message_id_list) + + if target_message_dict: + # 直接使用字典作为action_message,避免DatabaseMessages对象创建失败 + target_message_obj = target_message_dict + # 替换action_data中的临时ID为真实ID + if "target_message_id" in action_data: + real_message_id = target_message_dict.get("message_id") or target_message_dict.get("id") + if real_message_id: + action_data["target_message_id"] = real_message_id + + # 确保 action_message 中始终有 message_id 字段 + if "message_id" not in target_message_obj and "id" in target_message_obj: + target_message_obj["message_id"] = target_message_obj["id"] + else: + # 如果找不到目标消息,对于reply动作来说这是必需的,应该记录警告 + if action == "reply": + logger.warning( + f"reply动作找不到目标消息,target_message_id: {action_data.get('target_message_id')}" + ) + # 将reply动作改为no_action,避免后续执行时出错 + action = "no_action" + reasoning = f"找不到目标消息进行回复。原始理由: {reasoning}" + + if ( + action not in ["no_action", "no_reply", "reply", "do_nothing", "proactive_reply"] + and action not in plan.available_actions + ): + reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}" + action = "no_action" + + parsed_actions.append( + ActionPlannerInfo( + action_type=action, + reasoning=reasoning, + action_data=action_data, + action_message=target_message_obj, + available_actions=plan.available_actions, + ) + ) + except Exception as e: + logger.error(f"解析单个action时出错: {e}") + parsed_actions.append( + ActionPlannerInfo( + action_type="no_action", + reasoning=f"解析action时出错: {e}", + ) + ) + return parsed_actions + + def _filter_no_actions(self, action_list: List[ActionPlannerInfo]) -> List[ActionPlannerInfo]: + non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]] + if non_no_actions: + return non_no_actions + return action_list[:1] if action_list else [] + + async def _get_long_term_memory_context(self) -> str: + try: + now = datetime.now() + keywords = ["今天", "日程", "计划"] + if 5 <= now.hour < 12: + keywords.append("早上") + elif 12 <= now.hour < 18: + keywords.append("中午") + else: + keywords.append("晚上") + + retrieved_memories = await hippocampus_manager.get_memory_from_topic( + valid_keywords=keywords, max_memory_num=5, max_memory_length=1 + ) + + if not retrieved_memories: + return "最近没有什么特别的记忆。" + + memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories] + return " ".join(memory_statements) + except Exception as e: + logger.error(f"获取长期记忆时出错: {e}") + return "回忆时出现了一些问题。" + + async def _build_action_options(self, current_available_actions: Dict[str, ActionInfo]) -> str: + action_options_block = "" + for action_name, action_info in current_available_actions.items(): + # 构建参数的JSON示例 + params_json_list = [] + if action_info.action_parameters: + for p_name, p_desc in action_info.action_parameters.items(): + # 为参数描述添加一个通用示例值 + if action_name == "set_emoji_like" and p_name == "emoji": + # 特殊处理set_emoji_like的emoji参数 + from plugins.set_emoji_like.qq_emoji_list import qq_face + emoji_options = [re.search(r"\[表情:(.+?)\]", name).group(1) for name in qq_face.values() if re.search(r"\[表情:(.+?)\]", name)] + example_value = f"<从'{', '.join(emoji_options[:10])}...'中选择一个>" + else: + example_value = f"<{p_desc}>" + params_json_list.append(f' "{p_name}": "{example_value}"') + + # 基础动作信息 + action_description = action_info.description + action_require = "\n".join(f"- {req}" for req in action_info.action_require) + + # 构建完整的JSON使用范例 + json_example_lines = [ + " {", + f' "action_type": "{action_name}"', + ] + # 将参数列表合并到JSON示例中 + if params_json_list: + # 移除最后一行的逗号 + json_example_lines.extend([line.rstrip(',') for line in params_json_list]) + + json_example_lines.append(' "reason": "<执行该动作的详细原因>"') + json_example_lines.append(" }") + + # 使用逗号连接内部元素,除了最后一个 + json_parts = [] + for i, line in enumerate(json_example_lines): + # "{" 和 "}" 不需要逗号 + if line.strip() in ["{", "}"]: + json_parts.append(line) + continue + + # 检查是否是最后一个需要逗号的元素 + is_last_item = True + for next_line in json_example_lines[i+1:]: + if next_line.strip() not in ["}"]: + is_last_item = False + break + + if not is_last_item: + json_parts.append(f"{line},") + else: + json_parts.append(line) + + json_example = "\n".join(json_parts) + + # 使用新的、更详细的action_prompt模板 + using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt_with_example") + action_options_block += using_action_prompt.format( + action_name=action_name, + action_description=action_description, + action_require=action_require, + json_example=json_example, + ) + return action_options_block + + def _find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]: + # 兼容多种 message_id 格式:数字、m123、buffered-xxxx + # 如果是纯数字,补上 m 前缀以兼容旧格式 + candidate_ids = {message_id} + if message_id.isdigit(): + candidate_ids.add(f"m{message_id}") + + # 如果是 m 开头且后面是数字,尝试去掉 m 前缀的数字形式 + if message_id.startswith("m") and message_id[1:].isdigit(): + candidate_ids.add(message_id[1:]) + + # 逐项匹配 message_id_list(每项可能为 {'id':..., 'message':...}) + for item in message_id_list: + # 支持 message_id_list 中直接是字符串/ID 的情形 + if isinstance(item, str): + if item in candidate_ids: + # 没有 message 对象,返回None + return None + continue + + if not isinstance(item, dict): + continue + + item_id = item.get("id") + # 直接匹配分配的短 id + if item_id and item_id in candidate_ids: + return item.get("message") + + # 有时 message 存储里会有原始的 message_id 字段(如 buffered-xxxx) + message_obj = item.get("message") + if isinstance(message_obj, dict): + orig_mid = message_obj.get("message_id") or message_obj.get("id") + if orig_mid and orig_mid in candidate_ids: + return message_obj + + # 作为兜底,尝试在 message_id_list 中找到 message.message_id 匹配 + for item in message_id_list: + if isinstance(item, dict) and isinstance(item.get("message"), dict): + mid = item["message"].get("message_id") or item["message"].get("id") + if mid == message_id: + return item["message"] + + return None + + def _get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]: + if not message_id_list: + return None + return message_id_list[-1].get("message") + + def _find_poke_notice(self, message_id_list: list) -> Optional[Dict[str, Any]]: + """在消息列表中寻找戳一戳的通知消息""" + for item in reversed(message_id_list): + message = item.get("message") + if ( + isinstance(message, dict) + and message.get("type") == "notice" + and "戳" in message.get("processed_plain_text", "") + ): + return message + return None diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_generator.py b/src/plugins/built_in/affinity_flow_chatter/plan_generator.py new file mode 100644 index 000000000..bd3f6185d --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/plan_generator.py @@ -0,0 +1,168 @@ +""" +PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的"原始计划" (Plan)。 +""" + +import time +from typing import Dict + +from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat +from src.chat.utils.utils import get_chat_type_and_target_info +from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.info_data_model import Plan, TargetPersonInfo +from src.config.config import global_config +from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType +from src.plugin_system.core.component_registry import component_registry + + +class ChatterPlanGenerator: + """ + ChatterPlanGenerator 负责在规划流程的初始阶段收集所有必要信息。 + + 它会汇总以下信息来构建一个"原始"的 Plan 对象,该对象后续会由 PlanFilter 进行筛选: + - 当前聊天信息 (ID, 目标用户) + - 当前可用的动作列表 + - 最近的聊天历史记录 + + Attributes: + chat_id (str): 当前聊天的唯一标识符。 + action_manager (ActionManager): 用于获取可用动作列表的管理器。 + """ + + def __init__(self, chat_id: str): + """ + 初始化 ChatterPlanGenerator。 + + Args: + chat_id (str): 当前聊天的 ID。 + """ + from src.chat.planner_actions.action_manager import ChatterActionManager + + self.chat_id = chat_id + # 注意:ChatterActionManager 可能需要根据实际情况初始化 + self.action_manager = ChatterActionManager() + + async def generate(self, mode: ChatMode) -> Plan: + """ + 收集所有信息,生成并返回一个初始的 Plan 对象。 + + 这个 Plan 对象包含了决策所需的所有上下文信息。 + + Args: + mode (ChatMode): 当前的聊天模式。 + + Returns: + Plan: 包含所有上下文信息的初始计划对象。 + """ + try: + # 获取聊天类型和目标信息 + chat_type, target_info = get_chat_type_and_target_info(self.chat_id) + + # 获取可用动作列表 + available_actions = await self._get_available_actions(chat_type, mode) + + # 获取聊天历史记录 + recent_messages = await self._get_recent_messages() + + # 构建计划对象 + plan = Plan( + chat_id=self.chat_id, + chat_type=chat_type, + mode=mode, + target_info=target_info, + available_actions=available_actions, + chat_history=recent_messages, + ) + + return plan + + except Exception: + # 如果生成失败,返回一个基本的空计划 + return Plan( + chat_id=self.chat_id, + mode=mode, + target_info=TargetPersonInfo(), + available_actions={}, + chat_history=[], + ) + + async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> Dict[str, ActionInfo]: + """ + 获取当前可用的动作列表。 + + Args: + chat_type (ChatType): 聊天类型。 + mode (ChatMode): 聊天模式。 + + Returns: + Dict[str, ActionInfo]: 可用动作的字典。 + """ + try: + # 从组件注册表获取可用动作 + available_actions = component_registry.get_enabled_actions() + + # 根据聊天类型和模式筛选动作 + filtered_actions = {} + for action_name, action_info in available_actions.items(): + # 检查动作是否支持当前聊天类型 + if chat_type in action_info.chat_types: + # 检查动作是否支持当前模式 + if mode in action_info.chat_modes: + filtered_actions[action_name] = action_info + + return filtered_actions + + except Exception: + # 如果获取失败,返回空字典 + return {} + + async def _get_recent_messages(self) -> list[DatabaseMessages]: + """ + 获取最近的聊天历史记录。 + + Returns: + list[DatabaseMessages]: 最近的聊天消息列表。 + """ + try: + # 获取最近的消息记录 + raw_messages = get_raw_msg_before_timestamp_with_chat( + chat_id=self.chat_id, timestamp=time.time(), limit=global_config.memory.short_memory_length + ) + + # 转换为 DatabaseMessages 对象 + recent_messages = [] + for msg in raw_messages: + try: + db_msg = DatabaseMessages( + message_id=msg.get("message_id", ""), + time=float(msg.get("time", 0)), + chat_id=msg.get("chat_id", ""), + processed_plain_text=msg.get("processed_plain_text", ""), + user_id=msg.get("user_id", ""), + user_nickname=msg.get("user_nickname", ""), + user_platform=msg.get("user_platform", ""), + ) + recent_messages.append(db_msg) + except Exception: + # 跳过格式错误的消息 + continue + + return recent_messages + + except Exception: + # 如果获取失败,返回空列表 + return [] + + def get_generator_stats(self) -> Dict: + """ + 获取生成器统计信息。 + + Returns: + Dict: 统计信息字典。 + """ + return { + "chat_id": self.chat_id, + "action_count": len(self.action_manager._using_actions) + if hasattr(self.action_manager, "_using_actions") + else 0, + "generation_time": time.time(), + } diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py new file mode 100644 index 000000000..57f98954e --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/planner.py @@ -0,0 +1,269 @@ +""" +主规划器入口,负责协调 PlanGenerator, PlanFilter, 和 PlanExecutor。 +集成兴趣度评分系统和用户关系追踪机制,实现智能化的聊天决策。 +""" + +from dataclasses import asdict +import time +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor +from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter +from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator +from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system +from src.mood.mood_manager import mood_manager + + +from src.common.logger import get_logger +from src.config.config import global_config + +if TYPE_CHECKING: + from src.common.data_models.message_manager_data_model import StreamContext + from src.common.data_models.info_data_model import Plan + from src.chat.planner_actions.action_manager import ChatterActionManager + +# 导入提示词模块以确保其被初始化 +from src.plugins.built_in.affinity_flow_chatter import planner_prompts # noqa + +logger = get_logger("planner") + + +class ChatterActionPlanner: + """ + 增强版ActionPlanner,集成兴趣度评分和用户关系追踪机制。 + + 核心功能: + 1. 兴趣度评分系统:根据兴趣匹配度、关系分、提及度、时间因子对消息评分 + 2. 用户关系追踪:自动追踪用户交互并更新关系分 + 3. 智能回复决策:基于兴趣度阈值和连续不回复概率的智能决策 + 4. 完整的规划流程:生成→筛选→执行的完整三阶段流程 + """ + + def __init__(self, chat_id: str, action_manager: "ChatterActionManager"): + """ + 初始化增强版ActionPlanner。 + + Args: + chat_id (str): 当前聊天的 ID。 + action_manager (ChatterActionManager): 一个 ChatterActionManager 实例。 + """ + self.chat_id = chat_id + self.action_manager = action_manager + self.generator = ChatterPlanGenerator(chat_id) + self.executor = ChatterPlanExecutor(action_manager) + + # 使用新的统一兴趣度管理系统 + + # 规划器统计 + self.planner_stats = { + "total_plans": 0, + "successful_plans": 0, + "failed_plans": 0, + "replies_generated": 0, + "other_actions_executed": 0, + } + + async def plan(self, context: "StreamContext" = None) -> Tuple[List[Dict], Optional[Dict]]: + """ + 执行完整的增强版规划流程。 + + Args: + context (StreamContext): 包含聊天流消息的上下文对象。 + + Returns: + Tuple[List[Dict], Optional[Dict]]: 一个元组,包含: + - final_actions_dict (List[Dict]): 最终确定的动作列表(字典格式)。 + - final_target_message_dict (Optional[Dict]): 最终的目标消息(字典格式)。 + """ + try: + self.planner_stats["total_plans"] += 1 + + return await self._enhanced_plan_flow(context) + + except Exception as e: + logger.error(f"规划流程出错: {e}") + self.planner_stats["failed_plans"] += 1 + return [], None + + async def _enhanced_plan_flow(self, context: "StreamContext") -> Tuple[List[Dict], Optional[Dict]]: + """执行增强版规划流程""" + try: + # 在规划前,先进行动作修改 + from src.chat.planner_actions.action_modifier import ActionModifier + action_modifier = ActionModifier(self.action_manager, self.chat_id) + await action_modifier.modify_actions() + + # 1. 生成初始 Plan + initial_plan = await self.generator.generate(context.chat_mode) + + # 确保Plan中包含所有当前可用的动作 + initial_plan.available_actions = self.action_manager.get_using_actions() + + unread_messages = context.get_unread_messages() if context else [] + # 2. 使用新的兴趣度管理系统进行评分 + score = 0.0 + should_reply = False + reply_not_available = False + + if unread_messages: + # 获取用户ID,优先从user_info.user_id获取,其次从user_id属性获取 + user_id = None + first_message = unread_messages[0] + user_id = first_message.user_info.user_id + + # 构建计算上下文 + calc_context = { + "stream_id": self.chat_id, + "user_id": user_id, + } + + # 为每条消息计算兴趣度 + for message in unread_messages: + try: + # 使用插件内部的兴趣度评分系统计算 + interest_score = await chatter_interest_scoring_system._calculate_single_message_score( + message=message, + bot_nickname=global_config.bot.nickname + ) + message_interest = interest_score.total_score + + # 更新消息的兴趣度 + message.interest_value = message_interest + + # 简单的回复决策逻辑:兴趣度超过阈值则回复 + message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold + + logger.debug(f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}") + + # 更新StreamContext中的消息信息并刷新focus_energy + if context: + from src.chat.message_manager.message_manager import message_manager + message_manager.update_message( + stream_id=self.chat_id, + message_id=message.message_id, + interest_value=message_interest, + should_reply=message.should_reply + ) + + # 更新数据库中的消息记录 + try: + from src.chat.message_receive.storage import MessageStorage + MessageStorage.update_message_interest_value(message.message_id, message_interest) + logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}") + except Exception as e: + logger.warning(f"更新数据库消息兴趣度失败: {e}") + + # 记录最高分 + if message_interest > score: + score = message_interest + if message.should_reply: + should_reply = True + else: + reply_not_available = True + + except Exception as e: + logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}") + # 设置默认值 + message.interest_value = 0.0 + message.should_reply = False + + # 检查兴趣度是否达到非回复动作阈值 + non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold + if score < non_reply_action_interest_threshold: + logger.info(f"兴趣度 {score:.3f} 低于阈值 {non_reply_action_interest_threshold:.3f},不执行动作") + # 直接返回 no_action + from src.common.data_models.info_data_model import ActionPlannerInfo + + no_action = ActionPlannerInfo( + action_type="no_action", + reasoning=f"兴趣度评分 {score:.3f} 未达阈值 {non_reply_action_interest_threshold:.3f}", + action_data={}, + action_message=None, + ) + filtered_plan = initial_plan + filtered_plan.decided_actions = [no_action] + else: + # 4. 筛选 Plan + available_actions = list(initial_plan.available_actions.keys()) + plan_filter = ChatterPlanFilter(self.chat_id, available_actions) + filtered_plan = await plan_filter.filter(reply_not_available, initial_plan) + + # 检查filtered_plan是否有reply动作,用于统计 + has_reply_action = any(decision.action_type == "reply" for decision in filtered_plan.decided_actions) + + # 5. 使用 PlanExecutor 执行 Plan + execution_result = await self.executor.execute(filtered_plan) + + # 6. 根据执行结果更新统计信息 + self._update_stats_from_execution_result(execution_result) + + # 7. 返回结果 + return self._build_return_result(filtered_plan) + + except Exception as e: + logger.error(f"增强版规划流程出错: {e}") + self.planner_stats["failed_plans"] += 1 + return [], None + + def _update_stats_from_execution_result(self, execution_result: Dict[str, any]): + """根据执行结果更新规划器统计""" + if not execution_result: + return + + successful_count = execution_result.get("successful_count", 0) + + # 更新成功执行计数 + self.planner_stats["successful_plans"] += successful_count + + # 统计回复动作和其他动作 + reply_count = 0 + other_count = 0 + + for result in execution_result.get("results", []): + action_type = result.get("action_type", "") + if action_type in ["reply", "proactive_reply"]: + reply_count += 1 + else: + other_count += 1 + + self.planner_stats["replies_generated"] += reply_count + self.planner_stats["other_actions_executed"] += other_count + + def _build_return_result(self, plan: "Plan") -> Tuple[List[Dict], Optional[Dict]]: + """构建返回结果""" + final_actions = plan.decided_actions or [] + final_target_message = next((act.action_message for act in final_actions if act.action_message), None) + + final_actions_dict = [asdict(act) for act in final_actions] + + if final_target_message: + if hasattr(final_target_message, "__dataclass_fields__"): + final_target_message_dict = asdict(final_target_message) + else: + final_target_message_dict = final_target_message + else: + final_target_message_dict = None + + return final_actions_dict, final_target_message_dict + + def get_planner_stats(self) -> Dict[str, any]: + """获取规划器统计""" + return self.planner_stats.copy() + + def get_current_mood_state(self) -> str: + """获取当前聊天的情绪状态""" + chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id) + return chat_mood.mood_state + + def get_mood_stats(self) -> Dict[str, any]: + """获取情绪状态统计""" + chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id) + return { + "current_mood": chat_mood.mood_state, + "is_angry_from_wakeup": chat_mood.is_angry_from_wakeup, + "regression_count": chat_mood.regression_count, + "last_change_time": chat_mood.last_change_time, + } + + +# 全局兴趣度评分系统实例 - 在 individuality 模块中创建 diff --git a/src/plugins/built_in/affinity_flow_chatter/planner_prompts.py b/src/plugins/built_in/affinity_flow_chatter/planner_prompts.py new file mode 100644 index 000000000..c8f448067 --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/planner_prompts.py @@ -0,0 +1,290 @@ +""" +本文件集中管理所有与规划器(Planner)相关的提示词(Prompt)模板。 + +通过将提示词与代码逻辑分离,可以更方便地对模型的行为进行迭代和优化, +而无需修改核心代码。 +""" + +from src.chat.utils.prompt import Prompt + + +def init_prompts(): + """ + 初始化并向 Prompt 注册系统注册所有规划器相关的提示词。 + + 这个函数会在模块加载时自动调用,确保所有提示词在系统启动时都已准备就绪。 + """ + # 核心规划器提示词,用于在接收到新消息时决定如何回应。 + # 它构建了一个复杂的上下文,包括历史记录、可用动作、角色设定等, + # 并要求模型以 JSON 格式输出一个或多个动作组合。 + Prompt( + """ +{mood_block} +{time_block} +{identity_block} + +{users_in_chat} +{custom_prompt_block} +{chat_context_description},以下是具体的聊天内容。 + +## 📜 已读历史消息(仅供参考) +{read_history_block} + +## 📬 未读历史消息(动作执行对象) +{unread_history_block} + +{moderation_prompt} + +**任务: 构建一个完整的响应** +你的任务是根据当前的聊天内容,构建一个完整的、人性化的响应。一个完整的响应由两部分组成: +1. **主要动作**: 这是响应的核心,通常是 `reply`(如果有)。 +2. **辅助动作 (可选)**: 这是为了增强表达效果的附加动作,例如 `emoji`(发送表情包)或 `poke_user`(戳一戳)。 + +**决策流程:** +1. **重要:已读历史消息仅作为当前聊天情景的参考,帮助你理解对话上下文。** +2. **重要:所有动作的执行对象只能是未读历史消息中的消息,不能对已读消息执行动作。** +3. 在未读历史消息中,优先对兴趣值高的消息做出动作(兴趣值标注在消息末尾)。 +4. 首先,决定是否要对未读消息进行 `reply`(如果有)。 +5. 然后,评估当前的对话气氛和用户情绪,判断是否需要一个**辅助动作**来让你的回应更生动、更符合你的性格。 +6. 如果需要,选择一个最合适的辅助动作与 `reply`(如果有) 组合。 +7. 如果用户明确要求了某个动作,请务必优先满足。 + +**重要提醒:** +- **回复消息时必须遵循对话的流程,不要重复已经说过的话。** +- **确保回复与上下文紧密相关,回应要针对用户的消息内容。** +- **保持角色设定的一致性,使用符合你性格的语言风格。** +- **不要对表情包消息做出回应!** + +**输出格式:** +请严格按照以下 JSON 格式输出,包含 `thinking` 和 `actions` 字段: + +**重要概念:将“内心思考”作为思绪流的体现** +`thinking` 字段是本次决策的核心。它并非一个简单的“理由”,而是 **一个模拟人类在回应前,头脑中自然浮现的、未经修饰的思绪流**。你需要完全代入 {identity_block} 的角色,将那一刻的想法自然地记录下来。 + +**内心思考的要点:** +* **自然流露**: 不要使用“决定”、“所以”、“因此”等结论性或汇报式的词语。你的思考应该像日记一样,是给自己看的,充满了不确定性和情绪的自然流动。 +* **展现过程**: 重点在于展现 **思考的过程**,而不是 **决策的结果**。描述你看到了什么,想到了什么,感受到了什么。 +* **使用昵称**: 在你的思绪流中,请直接使用用户的昵称来指代他们,而不是``, ``这样的消息ID。 +* **严禁技术术语**: 严禁在思考中提及任何数字化的度量(如兴趣度、分数)或内部技术术语。请完全使用角色自身的感受和语言来描述思考过程。 + +## 可用动作列表 +{action_options_text} + +```json +{{ + "thinking": "在这里写下你的思绪流...", + "actions": [ + {{ + "action_type": "动作类型(如:reply, emoji等)", + "reasoning": "选择该动作的理由", + "action_data": {{ + "target_message_id": "目标消息ID", + "content": "回复内容或其他动作所需数据" + }} + }} + ] +}} +``` + +**强制规则**: +- 对于每一个需要目标消息的动作(如`reply`, `poke_user`, `set_emoji_like`),你 **必须** 在`action_data`中提供准确的`target_message_id`,这个ID来源于`## 未读历史消息`中消息前的``标签。 +- 当你选择的动作需要参数时(例如 `set_emoji_like` 需要 `emoji` 参数),你 **必须** 在 `action_data` 中提供所有必需的参数及其对应的值。 + +如果没有合适的回复对象或不需要回复,输出空的 actions 数组: +```json +{{ + "thinking": "说明为什么不需要回复", + "actions": [] +}} +``` +""", + "planner_prompt", + ) + + # 主动规划器提示词,用于主动场景和前瞻性规划 + Prompt( + """ +{mood_block} +{time_block} +{identity_block} + +{users_in_chat} +{custom_prompt_block} +{chat_context_description},以下是具体的聊天内容。 + +## 📜 已读历史消息(仅供参考) +{read_history_block} + +## 📬 未读历史消息(动作执行对象) +{unread_history_block} + +{moderation_prompt} + +**任务: 构建一个完整的响应** +你的任务是根据当前的聊天内容,构建一个完整的、人性化的响应。一个完整的响应由两部分组成: +1. **主要动作**: 这是响应的核心,通常是 `reply`(如果有)。 +2. **辅助动作 (可选)**: 这是为了增强表达效果的附加动作,例如 `emoji`(发送表情包)或 `poke_user`(戳一戳)。 + +**决策流程:** +1. **重要:已读历史消息仅作为当前聊天情景的参考,帮助你理解对话上下文。** +2. **重要:所有动作的执行对象只能是未读历史消息中的消息,不能对已读消息执行动作。** +3. 在未读历史消息中,优先对兴趣值高的消息做出动作(兴趣值标注在消息末尾)。 +4. 首先,决定是否要对未读消息进行 `reply`(如果有)。 +5. 然后,评估当前的对话气氛和用户情绪,判断是否需要一个**辅助动作**来让你的回应更生动、更符合你的性格。 +6. 如果需要,选择一个最合适的辅助动作与 `reply`(如果有) 组合。 +7. 如果用户明确要求了某个动作,请务必优先满足。 + +**动作限制:** +- 在私聊中,你只能使用 `reply` 动作。私聊中不允许使用任何其他动作。 +- 在群聊中,你可以自由选择是否使用辅助动作。 + +**重要提醒:** +- **回复消息时必须遵循对话的流程,不要重复已经说过的话。** +- **确保回复与上下文紧密相关,回应要针对用户的消息内容。** +- **保持角色设定的一致性,使用符合你性格的语言风格。** + +**输出格式:** +请严格按照以下 JSON 格式输出,包含 `thinking` 和 `actions` 字段: +```json +{{ + "thinking": "你的思考过程,分析当前情况并说明为什么选择这些动作", + "actions": [ + {{ + "action_type": "动作类型(如:reply, emoji等)", + "reasoning": "选择该动作的理由", + "action_data": {{ + "target_message_id": "目标消息ID", + "content": "回复内容或其他动作所需数据" + }} + }} + ] +}} +``` + +如果没有合适的回复对象或不需要回复,输出空的 actions 数组: +```json +{{ + "thinking": "说明为什么不需要回复", + "actions": [] +}} +``` +""", + "proactive_planner_prompt", + ) + + # 轻量级规划器提示词,用于快速决策和简单场景 + Prompt( + """ +{identity_block} + +## 当前聊天情景 +{chat_context_description} + +## 未读消息 +{unread_history_block} + +**任务:快速决策** +请根据当前聊天内容,快速决定是否需要回复。 + +**决策规则:** +1. 如果有人直接提到你或问你问题,优先回复 +2. 如果消息内容符合你的兴趣,考虑回复 +3. 如果只是群聊中的普通聊天且与你无关,可以不回复 + +**输出格式:** +```json +{{ + "thinking": "简要分析", + "actions": [ + {{ + "action_type": "reply", + "reasoning": "回复理由", + "action_data": {{ + "target_message_id": "目标消息ID", + "content": "回复内容" + }} + }} + ] +}} +``` +""", + "chatter_planner_lite", + ) + + # 动作筛选器提示词,用于筛选和优化规划器生成的动作 + Prompt( + """ +{identity_block} + +## 原始动作计划 +{original_plan} + +## 聊天上下文 +{chat_context} + +**任务:动作筛选优化** +请对原始动作计划进行筛选和优化,确保动作的合理性和有效性。 + +**筛选原则:** +1. 移除重复或不必要的动作 +2. 确保动作之间的逻辑顺序 +3. 优化动作的具体参数 +4. 考虑当前聊天环境和个人设定 + +**输出格式:** +```json +{{ + "thinking": "筛选优化思考", + "actions": [ + {{ + "action_type": "优化后的动作类型", + "reasoning": "优化理由", + "action_data": {{ + "target_message_id": "目标消息ID", + "content": "优化后的内容" + }} + }} + ] +}} +``` +""", + "chatter_plan_filter", + ) + + # 动作提示词,用于格式化动作选项 + Prompt( + """ +## 动作: {action_name} +**描述**: {action_description} + +**参数**: +{action_parameters} + +**要求**: +{action_require} + +**使用说明**: +请根据上述信息判断是否需要使用此动作。 +""", + "action_prompt", + ) + + # 带有完整JSON示例的动作提示词模板 + Prompt( + """ +动作: {action_name} +动作描述: {action_description} +动作使用场景: +{action_require} + +你应该像这样使用它: +{{ +{json_example} +}} +""", + "action_prompt_with_example", + ) + + +# 确保提示词在模块加载时初始化 +init_prompts() diff --git a/src/plugins/built_in/affinity_flow_chatter/plugin.py b/src/plugins/built_in/affinity_flow_chatter/plugin.py new file mode 100644 index 000000000..7c86d13fe --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/plugin.py @@ -0,0 +1,46 @@ +""" +亲和力聊天处理器插件 +""" + +from typing import List, Tuple, Type + +from src.plugin_system.apis.plugin_register_api import register_plugin +from src.plugin_system.base.base_plugin import BasePlugin +from src.plugin_system.base.component_types import ComponentInfo +from src.common.logger import get_logger + +logger = get_logger("affinity_chatter_plugin") + + +@register_plugin +class AffinityChatterPlugin(BasePlugin): + """亲和力聊天处理器插件 + + - 延迟导入 `AffinityChatter` 并通过组件注册器注册为聊天处理器 + - 提供 `get_plugin_components` 以兼容插件注册机制 + """ + + plugin_name: str = "affinity_chatter" + enable_plugin: bool = True + dependencies: list[str] = [] + python_dependencies: list[str] = [] + config_file_name: str = "" + + # 简单的 config_schema 占位(如果将来需要配置可扩展) + config_schema = {} + + def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: + """返回插件包含的组件列表(ChatterInfo, AffinityChatter) + + 这里采用延迟导入 AffinityChatter 来避免循环依赖和启动顺序问题。 + 如果导入失败则返回空列表以让注册过程继续而不崩溃。 + """ + try: + # 延迟导入以避免循环导入 + from .affinity_chatter import AffinityChatter + + return [(AffinityChatter.get_chatter_info(), AffinityChatter)] + + except Exception as e: + logger.error(f"加载 AffinityChatter 时出错: {e}") + return [] diff --git a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py new file mode 100644 index 000000000..c0050025e --- /dev/null +++ b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py @@ -0,0 +1,755 @@ +""" +用户关系追踪器 +负责追踪用户交互历史,并通过LLM分析更新用户关系分 +支持数据库持久化存储和回复后自动关系更新 +""" + +import time +from typing import Dict, List, Optional + +from src.common.logger import get_logger +from src.config.config import model_config, global_config +from src.llm_models.utils_model import LLMRequest +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import UserRelationships, Messages +from sqlalchemy import select, desc +from src.common.data_models.database_data_model import DatabaseMessages + +logger = get_logger("chatter_relationship_tracker") + + +class ChatterRelationshipTracker: + """用户关系追踪器""" + + def __init__(self, interest_scoring_system=None): + self.tracking_users: Dict[str, Dict] = {} # user_id -> interaction_data + self.max_tracking_users = 3 + self.update_interval_minutes = 30 + self.last_update_time = time.time() + self.relationship_history: List[Dict] = [] + self.interest_scoring_system = interest_scoring_system + + # 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float}) + self.user_relationship_cache: Dict[str, Dict] = {} + self.cache_expiry_hours = 1 # 缓存过期时间(小时) + + # 关系更新LLM + try: + self.relationship_llm = LLMRequest( + model_set=model_config.model_task_config.relationship_tracker, request_type="relationship_tracker" + ) + except AttributeError: + # 如果relationship_tracker配置不存在,尝试其他可用的模型配置 + available_models = [ + attr + for attr in dir(model_config.model_task_config) + if not attr.startswith("_") and attr != "model_dump" + ] + + if available_models: + # 使用第一个可用的模型配置 + fallback_model = available_models[0] + logger.warning(f"relationship_tracker model configuration not found, using fallback: {fallback_model}") + self.relationship_llm = LLMRequest( + model_set=getattr(model_config.model_task_config, fallback_model), + request_type="relationship_tracker", + ) + else: + # 如果没有任何模型配置,创建一个简单的LLMRequest + logger.warning("No model configurations found, creating basic LLMRequest") + self.relationship_llm = LLMRequest( + model_set="gpt-3.5-turbo", # 默认模型 + request_type="relationship_tracker", + ) + + def set_interest_scoring_system(self, interest_scoring_system): + """设置兴趣度评分系统引用""" + self.interest_scoring_system = interest_scoring_system + + def add_interaction(self, user_id: str, user_name: str, user_message: str, bot_reply: str, reply_timestamp: float): + """添加用户交互记录""" + if len(self.tracking_users) >= self.max_tracking_users: + # 移除最旧的记录 + oldest_user = min( + self.tracking_users.keys(), key=lambda k: self.tracking_users[k].get("reply_timestamp", 0) + ) + del self.tracking_users[oldest_user] + + # 获取当前关系分 + current_relationship_score = global_config.affinity_flow.base_relationship_score # 默认值 + if self.interest_scoring_system: + current_relationship_score = self.interest_scoring_system.get_user_relationship(user_id) + + self.tracking_users[user_id] = { + "user_id": user_id, + "user_name": user_name, + "user_message": user_message, + "bot_reply": bot_reply, + "reply_timestamp": reply_timestamp, + "current_relationship_score": current_relationship_score, + } + + logger.debug(f"添加用户交互追踪: {user_id}") + + async def check_and_update_relationships(self) -> List[Dict]: + """检查并更新用户关系""" + current_time = time.time() + if current_time - self.last_update_time < self.update_interval_minutes * 60: + return [] + + updates = [] + for user_id, interaction in list(self.tracking_users.items()): + if current_time - interaction["reply_timestamp"] > 60 * 5: # 5分钟 + update = await self._update_user_relationship(interaction) + if update: + updates.append(update) + del self.tracking_users[user_id] + + self.last_update_time = current_time + return updates + + async def _update_user_relationship(self, interaction: Dict) -> Optional[Dict]: + """更新单个用户的关系""" + try: + # 获取bot人设信息 + from src.individuality.individuality import Individuality + + individuality = Individuality() + bot_personality = await individuality.get_personality_block() + + prompt = f""" +你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality} + +请以你独特的性格视角,严格按现实逻辑分析以下用户交互,更新用户关系: + +用户ID: {interaction["user_id"]} +用户名: {interaction["user_name"]} +用户消息: {interaction["user_message"]} +你的回复: {interaction["bot_reply"]} +当前关系分: {interaction["current_relationship_score"]} + +【重要】关系分数档次定义: +- 0.0-0.2:陌生人/初次认识 - 仅礼貌性交流 +- 0.2-0.4:普通网友 - 有基本互动但不熟悉 +- 0.4-0.6:熟悉网友 - 经常交流,有一定了解 +- 0.6-0.8:朋友 - 可以分享心情,互相关心 +- 0.8-1.0:好朋友/知己 - 深度信任,亲密无间 + +【严格要求】: +1. 加分必须符合现实关系发展逻辑 - 不能因为对方态度好就盲目加分到不符合当前关系档次的分数 +2. 关系提升需要足够的互动积累和时间验证 +3. 即使是朋友关系,单次互动加分通常不超过0.05-0.1 +4. 关系描述要详细具体,包括: + - 用户性格特点观察 + - 印象深刻的互动记忆 + - 你们关系的具体状态描述 + +根据你的人设性格,思考: +1. 以你的性格,你会如何看待这次互动? +2. 用户的行为是否符合你性格的喜好? +3. 这次互动是否真的让你们的关系提升了一个档次?为什么? +4. 有什么特别值得记住的互动细节? + +请以JSON格式返回更新结果: +{{ + "new_relationship_score": 0.0~1.0的数值(必须符合现实逻辑), + "reasoning": "从你的性格角度说明更新理由,重点说明是否符合现实关系发展逻辑", + "interaction_summary": "基于你性格的交互总结,包含印象深刻的互动记忆" +}} +""" + + llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt) + if llm_response: + import json + + try: + # 清理LLM响应,移除可能的格式标记 + cleaned_response = self._clean_llm_json_response(llm_response) + response_data = json.loads(cleaned_response) + new_score = max( + 0.0, + min( + 1.0, + float( + response_data.get( + "new_relationship_score", global_config.affinity_flow.base_relationship_score + ) + ), + ), + ) + + if self.interest_scoring_system: + self.interest_scoring_system.update_user_relationship( + interaction["user_id"], new_score - interaction["current_relationship_score"] + ) + + return { + "user_id": interaction["user_id"], + "new_relationship_score": new_score, + "reasoning": response_data.get("reasoning", ""), + "interaction_summary": response_data.get("interaction_summary", ""), + } + + except json.JSONDecodeError as e: + logger.error(f"LLM响应JSON解析失败: {e}") + logger.debug(f"LLM原始响应: {llm_response}") + except Exception as e: + logger.error(f"处理关系更新数据失败: {e}") + + except Exception as e: + logger.error(f"更新用户关系时出错: {e}") + + return None + + def get_tracking_users(self) -> Dict[str, Dict]: + """获取正在追踪的用户""" + return self.tracking_users.copy() + + def get_user_interaction(self, user_id: str) -> Optional[Dict]: + """获取特定用户的交互记录""" + return self.tracking_users.get(user_id) + + def remove_user_tracking(self, user_id: str): + """移除用户追踪""" + if user_id in self.tracking_users: + del self.tracking_users[user_id] + logger.debug(f"移除用户追踪: {user_id}") + + def clear_all_tracking(self): + """清空所有追踪""" + self.tracking_users.clear() + logger.info("清空所有用户追踪") + + def get_relationship_history(self) -> List[Dict]: + """获取关系历史记录""" + return self.relationship_history.copy() + + def add_to_history(self, relationship_update: Dict): + """添加到关系历史""" + self.relationship_history.append({**relationship_update, "update_time": time.time()}) + + # 限制历史记录数量 + if len(self.relationship_history) > 100: + self.relationship_history = self.relationship_history[-100:] + + def get_tracker_stats(self) -> Dict: + """获取追踪器统计""" + return { + "tracking_users": len(self.tracking_users), + "max_tracking_users": self.max_tracking_users, + "update_interval_minutes": self.update_interval_minutes, + "relationship_history": len(self.relationship_history), + "last_update_time": self.last_update_time, + } + + def update_config(self, max_tracking_users: int = None, update_interval_minutes: int = None): + """更新配置""" + if max_tracking_users is not None: + self.max_tracking_users = max_tracking_users + logger.info(f"更新最大追踪用户数: {max_tracking_users}") + + if update_interval_minutes is not None: + self.update_interval_minutes = update_interval_minutes + logger.info(f"更新关系更新间隔: {update_interval_minutes} 分钟") + + def force_update_relationship(self, user_id: str, new_score: float, reasoning: str = ""): + """强制更新用户关系分""" + if user_id in self.tracking_users: + current_score = self.tracking_users[user_id]["current_relationship_score"] + if self.interest_scoring_system: + self.interest_scoring_system.update_user_relationship(user_id, new_score - current_score) + + update_info = { + "user_id": user_id, + "new_relationship_score": new_score, + "reasoning": reasoning or "手动更新", + "interaction_summary": "手动更新关系分", + } + self.add_to_history(update_info) + logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}") + + def get_user_summary(self, user_id: str) -> Dict: + """获取用户交互总结""" + if user_id not in self.tracking_users: + return {} + + interaction = self.tracking_users[user_id] + return { + "user_id": user_id, + "user_name": interaction["user_name"], + "current_relationship_score": interaction["current_relationship_score"], + "interaction_count": 1, # 简化版本,每次追踪只记录一次交互 + "last_interaction": interaction["reply_timestamp"], + "recent_message": interaction["user_message"][:100] + "..." + if len(interaction["user_message"]) > 100 + else interaction["user_message"], + } + + # ===== 数据库支持方法 ===== + + def get_user_relationship_score(self, user_id: str) -> float: + """获取用户关系分""" + # 先检查缓存 + if user_id in self.user_relationship_cache: + cache_data = self.user_relationship_cache[user_id] + # 检查缓存是否过期 + cache_time = cache_data.get("last_tracked", 0) + if time.time() - cache_time < self.cache_expiry_hours * 3600: + return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score) + + # 缓存过期或不存在,从数据库获取 + relationship_data = self._get_user_relationship_from_db(user_id) + if relationship_data: + # 更新缓存 + self.user_relationship_cache[user_id] = { + "relationship_text": relationship_data.get("relationship_text", ""), + "relationship_score": relationship_data.get( + "relationship_score", global_config.affinity_flow.base_relationship_score + ), + "last_tracked": time.time(), + } + return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score) + + # 数据库中也没有,返回默认值 + return global_config.affinity_flow.base_relationship_score + + def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]: + """从数据库获取用户关系数据""" + try: + with get_db_session() as session: + # 查询用户关系表 + stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) + result = session.execute(stmt).scalar_one_or_none() + + if result: + return { + "relationship_text": result.relationship_text or "", + "relationship_score": float(result.relationship_score) + if result.relationship_score is not None + else 0.3, + "last_updated": result.last_updated, + } + except Exception as e: + logger.error(f"从数据库获取用户关系失败: {e}") + + return None + + def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float): + """更新数据库中的用户关系""" + try: + current_time = time.time() + + with get_db_session() as session: + # 检查是否已存在关系记录 + existing = session.execute( + select(UserRelationships).where(UserRelationships.user_id == user_id) + ).scalar_one_or_none() + + if existing: + # 更新现有记录 + existing.relationship_text = relationship_text + existing.relationship_score = relationship_score + existing.last_updated = current_time + existing.user_name = existing.user_name or user_id # 更新用户名如果为空 + else: + # 插入新记录 + new_relationship = UserRelationships( + user_id=user_id, + user_name=user_id, + relationship_text=relationship_text, + relationship_score=relationship_score, + last_updated=current_time, + ) + session.add(new_relationship) + + session.commit() + logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}") + + except Exception as e: + logger.error(f"更新数据库用户关系失败: {e}") + + # ===== 回复后关系追踪方法 ===== + + async def track_reply_relationship( + self, user_id: str, user_name: str, bot_reply_content: str, reply_timestamp: float + ): + """回复后关系追踪 - 主要入口点""" + try: + logger.info(f"🔄 [RelationshipTracker] 开始回复后关系追踪: {user_id}") + + # 检查上次追踪时间 + last_tracked_time = self._get_last_tracked_time(user_id) + time_diff = reply_timestamp - last_tracked_time + + if time_diff < 5 * 60: # 5分钟内不重复追踪 + logger.debug( + f"⏱️ [RelationshipTracker] 用户 {user_id} 距离上次追踪时间不足5分钟 ({time_diff:.2f}s),跳过" + ) + return + + # 获取上次bot回复该用户的消息 + last_bot_reply = await self._get_last_bot_reply_to_user(user_id) + if not last_bot_reply: + logger.info(f"👋 [RelationshipTracker] 未找到用户 {user_id} 的历史回复记录,启动'初次见面'逻辑") + await self._handle_first_interaction(user_id, user_name, bot_reply_content) + return + + # 获取用户后续的反应消息 + user_reactions = await self._get_user_reactions_after_reply(user_id, last_bot_reply.time) + logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息") + + # 获取当前关系数据 + current_relationship = self._get_user_relationship_from_db(user_id) + current_score = ( + current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score) + if current_relationship + else global_config.affinity_flow.base_relationship_score + ) + current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户" + + # 使用LLM分析并更新关系 + logger.debug(f"🧠 [RelationshipTracker] 开始为用户 {user_id} 分析并更新关系") + await self._analyze_and_update_relationship( + user_id, user_name, last_bot_reply, user_reactions, current_text, current_score, bot_reply_content + ) + + except Exception as e: + logger.error(f"回复后关系追踪失败: {e}") + logger.debug("错误详情:", exc_info=True) + + def _get_last_tracked_time(self, user_id: str) -> float: + """获取上次追踪时间""" + # 先检查缓存 + if user_id in self.user_relationship_cache: + return self.user_relationship_cache[user_id].get("last_tracked", 0) + + # 从数据库获取 + relationship_data = self._get_user_relationship_from_db(user_id) + if relationship_data: + return relationship_data.get("last_updated", 0) + + return 0 + + async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]: + """获取上次bot回复该用户的消息""" + try: + with get_db_session() as session: + # 查询bot回复给该用户的最新消息 + stmt = ( + select(Messages) + .where(Messages.user_id == user_id) + .where(Messages.reply_to.isnot(None)) + .order_by(desc(Messages.time)) + .limit(1) + ) + + result = session.execute(stmt).scalar_one_or_none() + if result: + # 将SQLAlchemy模型转换为DatabaseMessages对象 + return self._sqlalchemy_to_database_messages(result) + + except Exception as e: + logger.error(f"获取上次回复消息失败: {e}") + + return None + + async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]: + """获取用户在bot回复后的反应消息""" + try: + with get_db_session() as session: + # 查询用户在回复时间之后的5分钟内的消息 + end_time = reply_time + 5 * 60 # 5分钟 + + stmt = ( + select(Messages) + .where(Messages.user_id == user_id) + .where(Messages.time > reply_time) + .where(Messages.time <= end_time) + .order_by(Messages.time) + ) + + results = session.execute(stmt).scalars().all() + if results: + return [self._sqlalchemy_to_database_messages(result) for result in results] + + except Exception as e: + logger.error(f"获取用户反应消息失败: {e}") + + return [] + + def _sqlalchemy_to_database_messages(self, sqlalchemy_message) -> DatabaseMessages: + """将SQLAlchemy消息模型转换为DatabaseMessages对象""" + try: + return DatabaseMessages( + message_id=sqlalchemy_message.message_id or "", + time=float(sqlalchemy_message.time) if sqlalchemy_message.time is not None else 0.0, + chat_id=sqlalchemy_message.chat_id or "", + reply_to=sqlalchemy_message.reply_to, + processed_plain_text=sqlalchemy_message.processed_plain_text or "", + user_id=sqlalchemy_message.user_id or "", + user_nickname=sqlalchemy_message.user_nickname or "", + user_platform=sqlalchemy_message.user_platform or "", + ) + except Exception as e: + logger.error(f"SQLAlchemy消息转换失败: {e}") + # 返回一个基本的消息对象 + return DatabaseMessages( + message_id="", + time=0.0, + chat_id="", + processed_plain_text="", + user_id="", + user_nickname="", + user_platform="", + ) + + async def _analyze_and_update_relationship( + self, + user_id: str, + user_name: str, + last_bot_reply: DatabaseMessages, + user_reactions: List[DatabaseMessages], + current_text: str, + current_score: float, + current_reply: str, + ): + """使用LLM分析并更新用户关系""" + try: + # 构建分析提示 + user_reactions_text = "\n".join([f"- {msg.processed_plain_text}" for msg in user_reactions]) + + # 获取bot人设信息 + from src.individuality.individuality import Individuality + + individuality = Individuality() + bot_personality = await individuality.get_personality_block() + + prompt = f""" +你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality} + +请以你独特的性格视角,严格按现实逻辑分析以下用户交互,更新用户关系印象和分数: + +用户信息: +- 用户ID: {user_id} +- 用户名: {user_name} + +你上次的回复: {last_bot_reply.processed_plain_text} + +用户反应消息: +{user_reactions_text} + +你当前的回复: {current_reply} + +当前关系印象: {current_text} +当前关系分数: {current_score:.3f} + +【重要】关系分数档次定义: +- 0.0-0.2:陌生人/初次认识 - 仅礼貌性交流 +- 0.2-0.4:普通网友 - 有基本互动但不熟悉 +- 0.4-0.6:熟悉网友 - 经常交流,有一定了解 +- 0.6-0.8:朋友 - 可以分享心情,互相关心 +- 0.8-1.0:好朋友/知己 - 深度信任,亲密无间 + +【严格要求】: +1. 加分必须符合现实关系发展逻辑 - 不能因为用户反应好就盲目加分 +2. 关系提升需要足够的互动积累和时间验证,单次互动加分通常不超过0.05-0.1 +3. 必须考虑当前关系档次,不能跳跃式提升(比如从0.3直接到0.7) +4. 关系印象描述要详细具体(100-200字),包括: + - 用户性格特点和交流风格观察 + - 印象深刻的互动记忆和对话片段 + - 你们关系的具体状态描述和发展阶段 + - 根据你的性格,你对用户的真实感受 + +性格视角深度分析: +1. 以你的性格特点,用户这次的反应给你什么感受? +2. 用户的情绪和行为符合你性格的喜好吗?具体哪些方面? +3. 从现实角度看,这次互动是否足以让关系提升到下一个档次?为什么? +4. 有什么特别值得记住的互动细节或对话内容? +5. 基于你们的互动历史,用户给你留下了哪些深刻印象? + +请以JSON格式返回更新结果: +{{ + "relationship_text": "详细的关系印象描述(100-200字),包含用户性格观察、印象深刻记忆、关系状态描述", + "relationship_score": 0.0~1.0的新分数(必须严格符合现实逻辑), + "analysis_reasoning": "从你性格角度的深度分析,重点说明分数调整的现实合理性", + "interaction_quality": "high/medium/low" +}} +""" + + # 调用LLM进行分析 + llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt) + + if llm_response: + import json + + try: + # 清理LLM响应,移除可能的格式标记 + cleaned_response = self._clean_llm_json_response(llm_response) + response_data = json.loads(cleaned_response) + + new_text = response_data.get("relationship_text", current_text) + new_score = max(0.0, min(1.0, float(response_data.get("relationship_score", current_score)))) + reasoning = response_data.get("analysis_reasoning", "") + quality = response_data.get("interaction_quality", "medium") + + # 更新数据库 + self._update_user_relationship_in_db(user_id, new_text, new_score) + + # 更新缓存 + self.user_relationship_cache[user_id] = { + "relationship_text": new_text, + "relationship_score": new_score, + "last_tracked": time.time(), + } + + # 如果有兴趣度评分系统,也更新内存中的关系分 + if self.interest_scoring_system: + self.interest_scoring_system.update_user_relationship(user_id, new_score - current_score) + + # 记录分析历史 + analysis_record = { + "user_id": user_id, + "timestamp": time.time(), + "old_score": current_score, + "new_score": new_score, + "old_text": current_text, + "new_text": new_text, + "reasoning": reasoning, + "quality": quality, + "user_reactions_count": len(user_reactions), + } + self.relationship_history.append(analysis_record) + + # 限制历史记录数量 + if len(self.relationship_history) > 100: + self.relationship_history = self.relationship_history[-100:] + + logger.info(f"✅ 关系分析完成: {user_id}") + logger.info(f" 📝 印象: '{current_text}' -> '{new_text}'") + logger.info(f" 💝 分数: {current_score:.3f} -> {new_score:.3f}") + logger.info(f" 🎯 质量: {quality}") + + except json.JSONDecodeError as e: + logger.error(f"LLM响应JSON解析失败: {e}") + logger.debug(f"LLM原始响应: {llm_response}") + else: + logger.warning("LLM未返回有效响应") + + except Exception as e: + logger.error(f"关系分析失败: {e}") + logger.debug("错误详情:", exc_info=True) + + async def _handle_first_interaction(self, user_id: str, user_name: str, bot_reply_content: str): + """处理与用户的初次交互""" + try: + logger.info(f"✨ [RelationshipTracker] 正在处理与用户 {user_id} 的初次交互") + + # 获取bot人设信息 + from src.individuality.individuality import Individuality + + individuality = Individuality() + bot_personality = await individuality.get_personality_block() + + prompt = f""" +你现在是:{bot_personality} + +你正在与一个新用户进行初次有效互动。请根据你对TA的第一印象,建立初始关系档案。 + +用户信息: +- 用户ID: {user_id} +- 用户名: {user_name} + +你的首次回复: {bot_reply_content} + +【严格要求】: +1. 建立一个初始关系分数,通常在0.2-0.4之间(普通网友)。 +2. 关系印象描述要简洁地记录你对用户的初步看法(50-100字)。 + - 用户名给你的感觉? + - 你的回复是基于什么考虑? + - 你对接下来与TA的互动有什么期待? + +请以JSON格式返回结果: +{{ + "relationship_text": "简洁的初始关系印象描述(50-100字)", + "relationship_score": 0.2~0.4的新分数, + "analysis_reasoning": "从你性格角度说明建立此初始印象的理由" +}} +""" + # 调用LLM进行分析 + llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt) + if not llm_response: + logger.warning(f"初次交互分析时LLM未返回有效响应: {user_id}") + return + + import json + + cleaned_response = self._clean_llm_json_response(llm_response) + response_data = json.loads(cleaned_response) + + new_text = response_data.get("relationship_text", "初次见面") + new_score = max( + 0.0, + min( + 1.0, + float(response_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)), + ), + ) + + # 更新数据库和缓存 + self._update_user_relationship_in_db(user_id, new_text, new_score) + self.user_relationship_cache[user_id] = { + "relationship_text": new_text, + "relationship_score": new_score, + "last_tracked": time.time(), + } + + logger.info(f"✅ [RelationshipTracker] 已成功为新用户 {user_id} 建立初始关系档案,分数为 {new_score:.3f}") + + except Exception as e: + logger.error(f"处理初次交互失败: {user_id}, 错误: {e}") + logger.debug("错误详情:", exc_info=True) + + def _clean_llm_json_response(self, response: str) -> str: + """ + 清理LLM响应,移除可能的JSON格式标记 + + Args: + response: LLM原始响应 + + Returns: + 清理后的JSON字符串 + """ + try: + import re + + # 移除常见的JSON格式标记 + cleaned = response.strip() + + # 移除 ```json 或 ``` 等标记 + cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE) + cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE) + + # 移除可能的Markdown代码块标记 + cleaned = re.sub(r"^`|`$", "", cleaned, flags=re.MULTILINE) + + # 尝试找到JSON对象的开始和结束 + json_start = cleaned.find("{") + json_end = cleaned.rfind("}") + + if json_start != -1 and json_end != -1 and json_end > json_start: + # 提取JSON部分 + cleaned = cleaned[json_start : json_end + 1] + + # 移除多余的空白字符 + cleaned = cleaned.strip() + + logger.debug(f"LLM响应清理: 原始长度={len(response)}, 清理后长度={len(cleaned)}") + if cleaned != response: + logger.debug(f"清理前: {response[:200]}...") + logger.debug(f"清理后: {cleaned[:200]}...") + + return cleaned + + except Exception as e: + logger.warning(f"清理LLM响应失败: {e}") + return response # 清理失败时返回原始响应 diff --git a/src/plugins/built_in/at_user_plugin/plugin.py b/src/plugins/built_in/at_user_plugin/plugin.py index ba40903cd..820b37a27 100644 --- a/src/plugins/built_in/at_user_plugin/plugin.py +++ b/src/plugins/built_in/at_user_plugin/plugin.py @@ -64,50 +64,50 @@ class AtAction(BaseAction): # 使用回复器生成艾特回复,而不是直接发送命令 from src.chat.replyer.default_generator import DefaultReplyer from src.chat.message_receive.chat_stream import get_chat_manager - + # 获取当前聊天流 chat_manager = get_chat_manager() chat_stream = self.chat_stream or chat_manager.get_stream(self.chat_id) - + if not chat_stream: logger.error(f"找不到聊天流: {self.chat_stream}") return False, "聊天流不存在" - + # 创建回复器实例 replyer = DefaultReplyer(chat_stream) - + # 构建回复对象,将艾特消息作为回复目标 reply_to = f"{user_name}:{at_message}" extra_info = f"你需要艾特用户 {user_name} 并回复他们说: {at_message}" - + # 使用回复器生成回复 success, llm_response, prompt = await replyer.generate_reply_with_context( reply_to=reply_to, extra_info=extra_info, enable_tool=False, # 艾特回复通常不需要工具调用 - from_plugin=False + from_plugin=False, ) - + if success and llm_response: # 获取生成的回复内容 reply_content = llm_response.get("content", "") if reply_content: # 获取用户QQ号,发送真正的艾特消息 user_id = user_info.get("user_id") - + # 发送真正的艾特命令,使用回复器生成的智能内容 await self.send_command( "SEND_AT_MESSAGE", args={"qq_id": user_id, "text": reply_content}, display_message=f"艾特用户 {user_name} 并发送智能回复: {reply_content}", ) - + await self.store_action_info( action_build_into_prompt=True, action_prompt_display=f"执行了艾特用户动作:艾特用户 {user_name} 并发送智能回复: {reply_content}", action_done=True, ) - + logger.info(f"成功通过回复器生成智能内容并发送真正的艾特消息给 {user_name}: {reply_content}") return True, "智能艾特消息发送成功" else: @@ -116,7 +116,7 @@ class AtAction(BaseAction): else: logger.error("回复器生成回复失败") return False, "回复生成失败" - + except Exception as e: logger.error(f"执行艾特用户动作时发生异常: {e}", exc_info=True) await self.store_action_info( diff --git a/src/plugins/built_in/core_actions/_manifest.json b/src/plugins/built_in/core_actions/_manifest.json index 48ba76378..ae70035df 100644 --- a/src/plugins/built_in/core_actions/_manifest.json +++ b/src/plugins/built_in/core_actions/_manifest.json @@ -26,8 +26,8 @@ "components": [ { "type": "action", - "name": "emoji", - "description": "发送表情包辅助表达情绪" + "name": "emoji", + "description": "作为一条全新的消息,发送一个符合当前情景的表情包来生动地表达情绪。" } ] } diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index e8ffba68e..f7b9c231a 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -33,7 +33,7 @@ class EmojiAction(BaseAction): # 动作基本信息 action_name = "emoji" - action_description = "发送表情包辅助表达情绪" + action_description = "作为一条全新的消息,发送一个符合当前情景的表情包来生动地表达情绪。" # LLM判断提示词 llm_judge_prompt = """ @@ -70,7 +70,9 @@ class EmojiAction(BaseAction): # 2. 获取所有有效的表情包对象 emoji_manager = get_emoji_manager() - all_emojis_obj: list[MaiEmoji] = [e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description] + all_emojis_obj: list[MaiEmoji] = [ + e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description + ] if not all_emojis_obj: logger.warning(f"{self.log_prefix} 无法获取任何带有描述的有效表情包") return False, "无法获取任何带有描述的有效表情包" @@ -91,12 +93,12 @@ class EmojiAction(BaseAction): # 4. 准备情感数据和后备列表 emotion_map = {} all_emojis_data = [] - + for emoji in all_emojis_obj: b64 = image_path_to_base64(emoji.full_path) if not b64: continue - + desc = emoji.description emotions = emoji.emotion all_emojis_data.append((b64, desc)) @@ -122,10 +124,10 @@ class EmojiAction(BaseAction): emoji_base64, emoji_description = random.choice(all_emojis_data) else: # 获取最近的5条消息内容用于判断 - recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5) + recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5) messages_text = "" if recent_messages: - messages_text = await message_api.build_readable_messages( + messages_text = message_api.build_readable_messages( messages=recent_messages, timestamp_mode="normal_no_YMD", truncate=False, @@ -150,10 +152,10 @@ class EmojiAction(BaseAction): # 调用LLM models = llm_api.get_available_models() - chat_model_config = models.get("planner") + chat_model_config = models.get("utils") if not chat_model_config: - logger.error(f"{self.log_prefix} 未找到'planner'模型配置,无法调用LLM") - return False, "未找到'planner'模型配置" + logger.error(f"{self.log_prefix} 未找到'utils'模型配置,无法调用LLM") + return False, "未找到'utils'模型配置" success, chosen_emotion, _, _ = await llm_api.generate_with_model( prompt, model_config=chat_model_config, request_type="emoji" @@ -168,23 +170,25 @@ class EmojiAction(BaseAction): # 使用模糊匹配来查找最相关的情感标签 matched_key = next((key for key in emotion_map if chosen_emotion in key), None) - + if matched_key: emoji_base64, emoji_description = random.choice(emotion_map[matched_key]) - logger.info(f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}") + logger.info( + f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}" + ) else: logger.warning( f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包" ) emoji_base64, emoji_description = random.choice(all_emojis_data) - + elif global_config.emoji.emoji_selection_mode == "description": # --- 详细描述选择模式 --- # 获取最近的5条消息内容用于判断 - recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5) + recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5) messages_text = "" if recent_messages: - messages_text = await message_api.build_readable_messages( + messages_text = message_api.build_readable_messages( messages=recent_messages, timestamp_mode="normal_no_YMD", truncate=False, @@ -208,10 +212,10 @@ class EmojiAction(BaseAction): # 调用LLM models = llm_api.get_available_models() - chat_model_config = models.get("planner") + chat_model_config = models.get("utils") if not chat_model_config: - logger.error(f"{self.log_prefix} 未找到'planner'模型配置,无法调用LLM") - return False, "未找到'planner'模型配置" + logger.error(f"{self.log_prefix} 未找到'utils'模型配置,无法调用LLM") + return False, "未找到'utils'模型配置" success, chosen_description, _, _ = await llm_api.generate_with_model( prompt, model_config=chat_model_config, request_type="emoji" @@ -226,15 +230,23 @@ class EmojiAction(BaseAction): logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}") # 简单关键词匹配 - matched_emoji = next((item for item in all_emojis_data if chosen_description.lower() in item[1].lower() or item[1].lower() in chosen_description.lower()), None) - + matched_emoji = next( + ( + item + for item in all_emojis_data + if chosen_description.lower() in item[1].lower() + or item[1].lower() in chosen_description.lower() + ), + None, + ) + # 如果包含匹配失败,尝试关键词匹配 if not matched_emoji: - keywords = ['惊讶', '困惑', '呆滞', '震惊', '懵', '无语', '萌', '可爱'] + keywords = ["惊讶", "困惑", "呆滞", "震惊", "懵", "无语", "萌", "可爱"] for keyword in keywords: if keyword in chosen_description: for item in all_emojis_data: - if any(k in item[1] for k in ['呆', '萌', '惊', '困惑', '无语']): + if any(k in item[1] for k in ["呆", "萌", "惊", "困惑", "无语"]): matched_emoji = item break if matched_emoji: @@ -255,7 +267,9 @@ class EmojiAction(BaseAction): if not success: logger.error(f"{self.log_prefix} 表情包发送失败") - await self.store_action_info(action_build_into_prompt = True,action_prompt_display ="发送了一个表情包,但失败了",action_done= False) + await self.store_action_info( + action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包,但失败了", action_done=False + ) return False, "表情包发送失败" # 发送成功后,记录到历史 @@ -263,8 +277,10 @@ class EmojiAction(BaseAction): add_emoji_to_history(self.chat_id, emoji_description) except Exception as e: logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}") - - await self.store_action_info(action_build_into_prompt = True,action_prompt_display ="发送了一个表情包",action_done= True) + + await self.store_action_info( + action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包", action_done=True + ) return True, f"发送表情包: {emoji_description}" diff --git a/src/plugins/built_in/napcat_adapter_plugin/_manifest.json b/src/plugins/built_in/napcat_adapter_plugin/_manifest.json index 676aa3121..1c8c0686f 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/_manifest.json +++ b/src/plugins/built_in/napcat_adapter_plugin/_manifest.json @@ -11,7 +11,7 @@ "host_application": { "min_version": "0.10.0", - "max_version": "0.10.0" + "max_version": "0.11.0" }, "homepage_url": "https://github.com/Windpicker-owo/InternetSearchPlugin", "repository_url": "https://github.com/Windpicker-owo/InternetSearchPlugin", diff --git a/src/plugins/built_in/napcat_adapter_plugin/event_handlers.py b/src/plugins/built_in/napcat_adapter_plugin/event_handlers.py index c4f889712..9fe6f8096 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/event_handlers.py +++ b/src/plugins/built_in/napcat_adapter_plugin/event_handlers.py @@ -1,4 +1,3 @@ - from src.plugin_system import BaseEventHandler from src.plugin_system.base.base_event import HandlerResult @@ -1748,6 +1747,7 @@ class SetGroupSignHandler(BaseEventHandler): logger.error("事件 napcat_set_group_sign 请求失败!") return HandlerResult(False, False, {"status": "error"}) + # ===PERSONAL=== class SetInputStatusHandler(BaseEventHandler): handler_name: str = "napcat_set_input_status_handler" diff --git a/src/plugins/built_in/napcat_adapter_plugin/plugin.py b/src/plugins/built_in/napcat_adapter_plugin/plugin.py index 3cd7973e7..fa0eeed23 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/plugin.py +++ b/src/plugins/built_in/napcat_adapter_plugin/plugin.py @@ -233,7 +233,7 @@ class LauchNapcatAdapterHandler(BaseEventHandler): await reassembler.start_cleanup_task() logger.info("开始启动Napcat Adapter") - + # 创建单独的异步任务,防止阻塞主线程 asyncio.create_task(self._start_maibot_connection()) asyncio.create_task(napcat_server(self.plugin_config)) @@ -244,10 +244,10 @@ class LauchNapcatAdapterHandler(BaseEventHandler): """非阻塞方式启动MaiBot连接,等待主服务启动后再连接""" # 等待一段时间让MaiBot主服务完全启动 await asyncio.sleep(5) - + max_attempts = 10 attempt = 0 - + while attempt < max_attempts: try: logger.info(f"尝试连接MaiBot (第{attempt + 1}次)") @@ -291,7 +291,7 @@ class NapcatAdapterPlugin(BasePlugin): def enable_plugin(self) -> bool: """通过配置文件动态控制插件启用状态""" # 如果已经通过配置加载了状态,使用配置中的值 - if hasattr(self, '_is_enabled'): + if hasattr(self, "_is_enabled"): return self._is_enabled # 否则使用默认值(禁用状态) return False @@ -305,7 +305,7 @@ class NapcatAdapterPlugin(BasePlugin): "name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"), "version": ConfigField(type=str, default="1.1.0", description="插件版本"), "config_version": ConfigField(type=str, default="1.3.1", description="配置文件版本"), - "enabled": ConfigField(type=bool, default=False, description="是否启用插件"), + "enabled": ConfigField(type=bool, default=True, description="是否启用插件"), }, "inner": { "version": ConfigField(type=str, default="0.2.1", description="配置版本号,请勿修改"), @@ -314,60 +314,88 @@ class NapcatAdapterPlugin(BasePlugin): "nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"), }, "napcat_server": { - "mode": ConfigField(type=str, default="reverse", description="连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]), + "mode": ConfigField( + type=str, + default="reverse", + description="连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", + choices=["reverse", "forward"], + ), "host": ConfigField(type=str, default="localhost", description="主机地址"), "port": ConfigField(type=int, default=8095, description="端口号"), - "url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)"), - "access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"), + "url": ConfigField( + type=str, + default="", + description="正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)", + ), + "access_token": ConfigField( + type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)" + ), "heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"), }, "maibot_server": { - "host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址,即HOST字段"), + "host": ConfigField( + type=str, default="localhost", description="麦麦在.env文件中设置的主机地址,即HOST字段" + ), "port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口,即PORT字段"), "platform_name": ConfigField(type=str, default="qq", description="平台名称,用于消息路由"), }, "voice": { - "use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音(请确保你配置了tts并有对应的adapter)"), + "use_tts": ConfigField( + type=bool, default=False, description="是否使用tts语音(请确保你配置了tts并有对应的adapter)" + ), }, "slicing": { - "max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小,单位为字节,默认64KB"), + "max_frame_size": ConfigField( + type=int, default=64, description="WebSocket帧的最大大小,单位为字节,默认64KB" + ), "delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"), }, "debug": { - "level": ConfigField(type=str, default="INFO", description="日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), + "level": ConfigField( + type=str, + default="INFO", + description="日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + ), }, "features": { # 权限设置 - "group_list_type": ConfigField(type=str, default="blacklist", description="群聊列表类型:whitelist(白名单)或 blacklist(黑名单)", choices=["whitelist", "blacklist"]), + "group_list_type": ConfigField( + type=str, + default="blacklist", + description="群聊列表类型:whitelist(白名单)或 blacklist(黑名单)", + choices=["whitelist", "blacklist"], + ), "group_list": ConfigField(type=list, default=[], description="群聊ID列表"), - "private_list_type": ConfigField(type=str, default="blacklist", description="私聊列表类型:whitelist(白名单)或 blacklist(黑名单)", choices=["whitelist", "blacklist"]), + "private_list_type": ConfigField( + type=str, + default="blacklist", + description="私聊列表类型:whitelist(白名单)或 blacklist(黑名单)", + choices=["whitelist", "blacklist"], + ), "private_list": ConfigField(type=list, default=[], description="用户ID列表"), - "ban_user_id": ConfigField(type=list, default=[], description="全局禁止用户ID列表,这些用户无法在任何地方使用机器人"), + "ban_user_id": ConfigField( + type=list, default=[], description="全局禁止用户ID列表,这些用户无法在任何地方使用机器人" + ), "ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"), - # 聊天功能设置 "enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"), "ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"), - "poke_debounce_seconds": ConfigField(type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"), + "poke_debounce_seconds": ConfigField( + type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略" + ), "enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"), "reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"), "enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复功能"), - # 视频处理设置 "enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"), "max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制(MB)"), "download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"), - "supported_formats": ConfigField(type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"), - - # 消息缓冲设置 - "enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"), - "message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"), - "message_buffer_enable_private": ConfigField(type=bool, default=True, description="是否启用私聊消息缓冲合并"), - "message_buffer_interval": ConfigField(type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"), - "message_buffer_initial_delay": ConfigField(type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"), - "message_buffer_max_components": ConfigField(type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"), - "message_buffer_block_prefixes": ConfigField(type=list, default=["/", "!", "!", ".", "。", "#", "%"], description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"), - } + "supported_formats": ConfigField( + type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式" + ), + # 消息缓冲功能已移除 + }, } # 配置节描述 @@ -380,7 +408,7 @@ class NapcatAdapterPlugin(BasePlugin): "voice": "发送语音设置", "slicing": "WebSocket消息切片设置", "debug": "调试设置", - "features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)" + "features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)", } def register_events(self): @@ -414,6 +442,7 @@ class NapcatAdapterPlugin(BasePlugin): chunker.set_plugin_config(self.config) # 设置response_pool的插件配置 from .src.response_pool import set_plugin_config as set_response_pool_config + set_response_pool_config(self.config) # 设置send_handler的插件配置 send_handler.set_plugin_config(self.config) @@ -423,4 +452,4 @@ class NapcatAdapterPlugin(BasePlugin): notice_handler.set_plugin_config(self.config) # 设置meta_event_handler的插件配置 meta_event_handler.set_plugin_config(self.config) - # 设置其他handler的插件配置(现在由component_registry在注册时自动设置) \ No newline at end of file + # 设置其他handler的插件配置(现在由component_registry在注册时自动设置) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py b/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py deleted file mode 100644 index 73216942e..000000000 --- a/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py +++ /dev/null @@ -1,317 +0,0 @@ -import asyncio -import time -from typing import Dict, List, Any, Optional -from dataclasses import dataclass, field - -from src.common.logger import get_logger - -logger = get_logger("napcat_adapter") - -from src.plugin_system.apis import config_api -from .recv_handler import RealMessageType - - -@dataclass -class TextMessage: - """文本消息""" - - text: str - timestamp: float = field(default_factory=time.time) - - -@dataclass -class BufferedSession: - """缓冲会话数据""" - - session_id: str - messages: List[TextMessage] = field(default_factory=list) - timer_task: Optional[asyncio.Task] = None - delay_task: Optional[asyncio.Task] = None - original_event: Any = None - created_at: float = field(default_factory=time.time) - - -class SimpleMessageBuffer: - def __init__(self, merge_callback=None): - """ - 初始化消息缓冲器 - - Args: - merge_callback: 消息合并后的回调函数,接收(session_id, merged_text, original_event)参数 - """ - self.buffer_pool: Dict[str, BufferedSession] = {} - self.lock = asyncio.Lock() - self.merge_callback = merge_callback - self._shutdown = False - self.plugin_config = None - - def set_plugin_config(self, plugin_config: dict): - """设置插件配置""" - self.plugin_config = plugin_config - - @staticmethod - def get_session_id(event_data: Dict[str, Any]) -> str: - """根据事件数据生成会话ID""" - message_type = event_data.get("message_type", "unknown") - user_id = event_data.get("user_id", "unknown") - - if message_type == "private": - return f"private_{user_id}" - elif message_type == "group": - group_id = event_data.get("group_id", "unknown") - return f"group_{group_id}_{user_id}" - else: - return f"{message_type}_{user_id}" - - @staticmethod - def extract_text_from_message(message: List[Dict[str, Any]]) -> Optional[str]: - """从OneBot消息中提取纯文本,如果包含非文本内容则返回None""" - text_parts = [] - has_non_text = False - - logger.debug(f"正在提取消息文本,消息段数量: {len(message)}") - - for msg_seg in message: - msg_type = msg_seg.get("type", "") - logger.debug(f"处理消息段类型: {msg_type}") - - if msg_type == RealMessageType.text: - text = msg_seg.get("data", {}).get("text", "").strip() - if text: - text_parts.append(text) - logger.debug(f"提取到文本: {text[:50]}...") - else: - # 发现非文本消息段,标记为包含非文本内容 - has_non_text = True - logger.debug(f"发现非文本消息段: {msg_type},跳过缓冲") - - # 如果包含非文本内容,则不进行缓冲 - if has_non_text: - logger.debug("消息包含非文本内容,不进行缓冲") - return None - - if text_parts: - combined_text = " ".join(text_parts).strip() - logger.debug(f"成功提取纯文本: {combined_text[:50]}...") - return combined_text - - logger.debug("没有找到有效的文本内容") - return None - - def should_skip_message(self, text: str) -> bool: - """判断消息是否应该跳过缓冲""" - if not text or not text.strip(): - return True - - # 检查屏蔽前缀 - block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", [])) - - text = text.strip() - if text.startswith(block_prefixes): - logger.debug(f"消息以屏蔽前缀开头,跳过缓冲: {text[:20]}...") - return True - - return False - - async def add_text_message( - self, event_data: Dict[str, Any], message: List[Dict[str, Any]], original_event: Any = None - ) -> bool: - """ - 添加文本消息到缓冲区 - - Args: - event_data: 事件数据 - message: OneBot消息数组 - original_event: 原始事件对象 - - Returns: - 是否成功添加到缓冲区 - """ - if self._shutdown: - return False - - # 检查是否启用消息缓冲 - if not config_api.get_plugin_config(self.plugin_config, "features.enable_message_buffer", False): - return False - - # 检查是否启用对应类型的缓冲 - message_type = event_data.get("message_type", "") - if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False): - return False - elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False): - return False - - # 提取文本 - text = self.extract_text_from_message(message) - if not text: - return False - - # 检查是否应该跳过 - if self.should_skip_message(text): - return False - - session_id = self.get_session_id(event_data) - - async with self.lock: - # 获取或创建会话 - if session_id not in self.buffer_pool: - self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event) - - session = self.buffer_pool[session_id] - - # 检查是否超过最大组件数量 - if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5): - logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并") - asyncio.create_task(self._force_merge_session(session_id)) - self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event) - session = self.buffer_pool[session_id] - - # 添加文本消息 - session.messages.append(TextMessage(text=text)) - session.original_event = original_event # 更新事件 - - # 取消之前的定时器 - await self._cancel_session_timers(session) - - # 设置新的延迟任务 - session.delay_task = asyncio.create_task(self._wait_and_start_merge(session_id)) - - logger.debug(f"文本消息已添加到缓冲器 {session_id}: {text[:50]}...") - return True - - @staticmethod - async def _cancel_session_timers(session: BufferedSession): - """取消会话的所有定时器""" - for task_name in ["timer_task", "delay_task"]: - task = getattr(session, task_name) - if task and not task.done(): - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - setattr(session, task_name, None) - - async def _wait_and_start_merge(self, session_id: str): - """等待初始延迟后开始合并定时器""" - initial_delay = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_initial_delay", 0.5) - await asyncio.sleep(initial_delay) - - async with self.lock: - session = self.buffer_pool.get(session_id) - if session and session.messages: - # 取消旧的定时器 - if session.timer_task and not session.timer_task.done(): - session.timer_task.cancel() - try: - await session.timer_task - except asyncio.CancelledError: - pass - - # 设置合并定时器 - session.timer_task = asyncio.create_task(self._wait_and_merge(session_id)) - - async def _wait_and_merge(self, session_id: str): - """等待合并间隔后执行合并""" - interval = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_interval", 2.0) - await asyncio.sleep(interval) - await self._merge_session(session_id) - - async def _force_merge_session(self, session_id: str): - """强制合并会话(不等待定时器)""" - await self._merge_session(session_id, force=True) - - async def _merge_session(self, session_id: str, force: bool = False): - """合并会话中的消息""" - async with self.lock: - session = self.buffer_pool.get(session_id) - if not session or not session.messages: - self.buffer_pool.pop(session_id, None) - return - - try: - # 合并文本消息 - text_parts = [] - for msg in session.messages: - if msg.text.strip(): - text_parts.append(msg.text.strip()) - - if not text_parts: - self.buffer_pool.pop(session_id, None) - return - - merged_text = ",".join(text_parts) # 使用中文逗号连接 - message_count = len(session.messages) - - logger.debug(f"合并会话 {session_id} 的 {message_count} 条文本消息: {merged_text[:100]}...") - - # 调用回调函数 - if self.merge_callback: - try: - if asyncio.iscoroutinefunction(self.merge_callback): - await self.merge_callback(session_id, merged_text, session.original_event) - else: - self.merge_callback(session_id, merged_text, session.original_event) - except Exception as e: - logger.error(f"消息合并回调执行失败: {e}") - - except Exception as e: - logger.error(f"合并会话 {session_id} 时出错: {e}") - finally: - # 清理会话 - await self._cancel_session_timers(session) - self.buffer_pool.pop(session_id, None) - - async def flush_session(self, session_id: str): - """强制刷新指定会话的缓冲区""" - await self._force_merge_session(session_id) - - async def flush_all(self): - """强制刷新所有会话的缓冲区""" - session_ids = list(self.buffer_pool.keys()) - for session_id in session_ids: - await self._force_merge_session(session_id) - - async def get_buffer_stats(self) -> Dict[str, Any]: - """获取缓冲区统计信息""" - async with self.lock: - stats = {"total_sessions": len(self.buffer_pool), "sessions": {}} - - for session_id, session in self.buffer_pool.items(): - stats["sessions"][session_id] = { - "message_count": len(session.messages), - "created_at": session.created_at, - "age": time.time() - session.created_at, - } - - return stats - - async def clear_expired_sessions(self, max_age: float = 300.0): - """清理过期的会话""" - current_time = time.time() - expired_sessions = [] - - async with self.lock: - for session_id, session in self.buffer_pool.items(): - if current_time - session.created_at > max_age: - expired_sessions.append(session_id) - - for session_id in expired_sessions: - logger.debug(f"清理过期会话: {session_id}") - await self._force_merge_session(session_id) - - async def shutdown(self): - """关闭消息缓冲器""" - self._shutdown = True - logger.debug("正在关闭简化消息缓冲器...") - - # 刷新所有缓冲区 - await self.flush_all() - - # 确保所有任务都被取消 - async with self.lock: - for session in list(self.buffer_pool.values()): - await self._cancel_session_timers(session) - self.buffer_pool.clear() - - logger.debug("简化消息缓冲器已关闭") diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py b/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py index c735d63cf..acd12fe01 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py @@ -11,10 +11,10 @@ router = None def create_router(plugin_config: dict): """创建路由器实例""" global router - platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "napcat") + platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "qq") host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost") port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000) - + route_config = RouteConfig( route_config={ platform_name: TargetConfig( @@ -32,7 +32,7 @@ async def mmc_start_com(plugin_config: dict = None): logger.info("正在连接MaiBot") if plugin_config: create_router(plugin_config) - + if router: router.register_class_handler(send_handler.handle_message) await router.run() diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/__init__.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/__init__.py index 48561ffbe..231c0ce39 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/__init__.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/__init__.py @@ -32,7 +32,7 @@ class NoticeType: # 通知事件 group_recall = "group_recall" # 群聊消息撤回 notify = "notify" group_ban = "group_ban" # 群禁言 - group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复 + group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复 class Notify: poke = "poke" # 戳一戳 diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py index a19ca85e5..ab0dac46b 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -6,7 +6,6 @@ from ...CONSTS import PLUGIN_NAME logger = get_logger("napcat_adapter") from src.plugin_system.apis import config_api -from ..message_buffer import SimpleMessageBuffer from ..utils import ( get_group_info, get_member_info, @@ -48,20 +47,18 @@ class MessageHandler: self.server_connection: Server.ServerConnection = None self.bot_id_list: Dict[int, bool] = {} self.plugin_config = None - # 初始化简化消息缓冲器,传入回调函数 - self.message_buffer = SimpleMessageBuffer(merge_callback=self._send_buffered_message) + # 消息缓冲功能已移除 def set_plugin_config(self, plugin_config: dict): """设置插件配置""" self.plugin_config = plugin_config - # 将配置传递给消息缓冲器 - if self.message_buffer: - self.message_buffer.set_plugin_config(plugin_config) + # 消息缓冲功能已移除 async def shutdown(self): """关闭消息处理器,清理资源""" - if self.message_buffer: - await self.message_buffer.shutdown() + # 消息缓冲功能已移除 + + # 消息缓冲功能已移除 async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: """设置Napcat连接""" @@ -100,7 +97,7 @@ class MessageHandler: # 检查群聊黑白名单 group_list_type = config_api.get_plugin_config(self.plugin_config, "features.group_list_type", "blacklist") group_list = config_api.get_plugin_config(self.plugin_config, "features.group_list", []) - + if group_list_type == "whitelist": if group_id not in group_list: logger.warning("群聊不在白名单中,消息被丢弃") @@ -111,9 +108,11 @@ class MessageHandler: return False else: # 检查私聊黑白名单 - private_list_type = config_api.get_plugin_config(self.plugin_config, "features.private_list_type", "blacklist") + private_list_type = config_api.get_plugin_config( + self.plugin_config, "features.private_list_type", "blacklist" + ) private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", []) - + if private_list_type == "whitelist": if user_id not in private_list: logger.warning("私聊不在白名单中,消息被丢弃") @@ -156,21 +155,23 @@ class MessageHandler: Parameters: raw_message: dict: 原始消息 """ - + # 添加原始消息调试日志,特别关注message字段 - logger.debug(f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}") + logger.debug( + f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}" + ) logger.debug(f"原始消息内容: {raw_message.get('message', [])}") - + # 检查是否包含@或video消息段 - message_segments = raw_message.get('message', []) + message_segments = raw_message.get("message", []) if message_segments: for i, seg in enumerate(message_segments): - seg_type = seg.get('type') - if seg_type in ['at', 'video']: + seg_type = seg.get("type") + if seg_type in ["at", "video"]: logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}") - elif seg_type not in ['text', 'face', 'image']: + elif seg_type not in ["text", "face", "image"]: logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}") - + message_type: str = raw_message.get("message_type") message_id: int = raw_message.get("message_id") # message_time: int = raw_message.get("time") @@ -301,38 +302,7 @@ class MessageHandler: logger.warning("处理后消息内容为空") return None - # 检查是否需要使用消息缓冲 - enable_message_buffer = config_api.get_plugin_config(self.plugin_config, "features.enable_message_buffer", True) - if enable_message_buffer: - # 检查消息类型是否启用缓冲 - message_type = raw_message.get("message_type") - should_use_buffer = False - - if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", True): - should_use_buffer = True - elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", True): - should_use_buffer = True - - if should_use_buffer: - logger.debug(f"尝试缓冲消息,消息类型: {message_type}, 用户: {user_info.user_id}") - - # 尝试添加到缓冲器 - buffered = await self.message_buffer.add_text_message( - event_data={ - "message_type": message_type, - "user_id": user_info.user_id, - "group_id": group_info.group_id if group_info else None, - }, - message=raw_message.get("message", []), - original_event={"message_info": message_info, "raw_message": raw_message}, - ) - - if buffered: - logger.debug(f"✅ 文本消息已成功缓冲: {user_info.user_id}") - return None # 缓冲成功,不立即发送 - # 如果缓冲失败(消息包含非文本元素),走正常处理流程 - logger.debug(f"❌ 消息缓冲失败,包含非文本元素,走正常处理流程: {user_info.user_id}") - # 缓冲失败时继续执行后面的正常处理流程,不要直接返回 + # 消息缓冲功能已移除,直接处理消息 logger.debug(f"准备发送消息到MaiBot,消息段数量: {len(seg_message)}") for i, seg in enumerate(seg_message): @@ -351,7 +321,6 @@ class MessageHandler: logger.debug("发送到Maibot处理信息") await message_send_instance.message_send(message_base) - return None async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None: # sourcery skip: low-code-quality @@ -369,10 +338,10 @@ class MessageHandler: for sub_message in real_message: sub_message: dict sub_message_type = sub_message.get("type") - + # 添加详细的消息类型调试信息 logger.debug(f"处理消息段: type={sub_message_type}, data={sub_message.get('data', {})}") - + # 特别关注 at 和 video 消息的识别 if sub_message_type == "at": logger.debug(f"检测到@消息: {sub_message}") @@ -380,7 +349,7 @@ class MessageHandler: logger.debug(f"检测到VIDEO消息: {sub_message}") elif sub_message_type not in ["text", "face", "image", "record"]: logger.warning(f"检测到特殊消息类型: {sub_message_type}, 完整消息: {sub_message}") - + match sub_message_type: case RealMessageType.text: ret_seg = await self.handle_text_message(sub_message) @@ -519,8 +488,7 @@ class MessageHandler: logger.debug(f"handle_real_message完成,处理了{len(real_message)}个消息段,生成了{len(seg_message)}个seg") return seg_message - @staticmethod - async def handle_text_message(raw_message: dict) -> Seg: + async def handle_text_message(self, raw_message: dict) -> Seg: """ 处理纯文本信息 Parameters: @@ -532,8 +500,7 @@ class MessageHandler: plain_text: str = message_data.get("text") return Seg(type="text", data=plain_text) - @staticmethod - async def handle_face_message(raw_message: dict) -> Seg | None: + async def handle_face_message(self, raw_message: dict) -> Seg | None: """ 处理表情消息 Parameters: @@ -550,8 +517,7 @@ class MessageHandler: logger.warning(f"不支持的表情:{face_raw_id}") return None - @staticmethod - async def handle_image_message(raw_message: dict) -> Seg | None: + async def handle_image_message(self, raw_message: dict) -> Seg | None: """ 处理图片消息与表情包消息 Parameters: @@ -607,7 +573,6 @@ class MessageHandler: return Seg(type="at", data=f"{member_info.get('nickname')}:{member_info.get('user_id')}") else: return None - return None async def handle_record_message(self, raw_message: dict) -> Seg | None: """ @@ -636,8 +601,7 @@ class MessageHandler: return None return Seg(type="voice", data=audio_base64) - @staticmethod - async def handle_video_message(raw_message: dict) -> Seg | None: + async def handle_video_message(self, raw_message: dict) -> Seg | None: """ 处理视频消息 Parameters: @@ -744,7 +708,6 @@ class MessageHandler: reply_message = [Seg(type="text", data="(获取发言内容失败)")] sender_info: dict = message_detail.get("sender") sender_nickname: str = sender_info.get("nickname") - sender_id: str = sender_info.get("user_id") seg_message: List[Seg] = [] if not sender_nickname: logger.warning("无法获取被引用的人的昵称,返回默认值") @@ -768,7 +731,7 @@ class MessageHandler: return None processed_message: Seg - if 5 > image_count > 0: + if image_count < 5 and image_count > 0: # 处理图片数量小于5的情况,此时解析图片为base64 logger.debug("图片数量小于5,开始解析图片为base64") processed_message = await self._recursive_parse_image_seg(handled_message, True) @@ -785,18 +748,15 @@ class MessageHandler: forward_hint = Seg(type="text", data="这是一条转发消息:\n") return Seg(type="seglist", data=[forward_hint, processed_message]) - @staticmethod - async def handle_dice_message(raw_message: dict) -> Seg: + async def handle_dice_message(self, raw_message: dict) -> Seg: message_data: dict = raw_message.get("data", {}) res = message_data.get("result", "") return Seg(type="text", data=f"[扔了一个骰子,点数是{res}]") - @staticmethod - async def handle_shake_message(raw_message: dict) -> Seg: + async def handle_shake_message(self, raw_message: dict) -> Seg: return Seg(type="text", data="[向你发送了窗口抖动,现在你的屏幕猛烈地震了一下!]") - @staticmethod - async def handle_json_message(raw_message: dict) -> Seg | None: + async def handle_json_message(self, raw_message: dict) -> Seg: """ 处理JSON消息 Parameters: @@ -868,43 +828,6 @@ class MessageHandler: data=f"这是一条小程序分享消息,可以根据来源,考虑使用对应解析工具\n{formatted_content}", ) - # 检查是否是音乐分享 - elif nested_data.get("view") == "music" and "music" in nested_data.get("meta", {}): - logger.debug("检测到音乐分享消息,开始提取信息") - music_info = nested_data["meta"]["music"] - title = music_info.get("title", "未知歌曲") - desc = music_info.get("desc", "未知艺术家") - jump_url = music_info.get("jumpUrl", "") - preview_url = music_info.get("preview", "") - source = music_info.get("tag", "未知来源") - - # 优化文本结构,使其更像卡片 - text_parts = [ - "--- 音乐分享 ---", - f"歌曲:{title}", - f"歌手:{desc}", - f"来源:{source}" - ] - if jump_url: - text_parts.append(f"链接:{jump_url}") - text_parts.append("----------------") - - text_content = "\n".join(text_parts) - - # 如果有预览图,创建一个seglist包含文本和图片 - if preview_url: - try: - image_base64 = await get_image_base64(preview_url) - if image_base64: - return Seg(type="seglist", data=[ - Seg(type="text", data=text_content + "\n"), - Seg(type="image", data=image_base64) - ]) - except Exception as e: - logger.error(f"下载音乐预览图失败: {e}") - - return Seg(type="text", data=text_content) - # 如果没有提取到关键信息,返回None return None @@ -915,8 +838,7 @@ class MessageHandler: logger.error(f"处理JSON消息时出错: {e}") return None - @staticmethod - async def handle_rps_message(raw_message: dict) -> Seg: + async def handle_rps_message(self, raw_message: dict) -> Seg: message_data: dict = raw_message.get("data", {}) res = message_data.get("result", "") if res == "1": @@ -1099,55 +1021,7 @@ class MessageHandler: return None return response_data.get("messages") - @staticmethod - async def _send_buffered_message(session_id: str, merged_text: str, original_event: Dict[str, Any]): - """发送缓冲的合并消息""" - try: - # 从原始事件数据中提取信息 - message_info = original_event.get("message_info") - raw_message = original_event.get("raw_message") - - if not message_info or not raw_message: - logger.error("缓冲消息缺少必要信息") - return - - # 创建合并后的消息段 - 将合并的文本转换为Seg格式 - from maim_message import Seg - - merged_seg = Seg(type="text", data=merged_text) - submit_seg = Seg(type="seglist", data=[merged_seg]) - - # 创建新的消息ID - import time - - new_message_id = f"buffered-{message_info.message_id}-{int(time.time() * 1000)}" - - # 更新消息信息 - from maim_message import BaseMessageInfo, MessageBase - - buffered_message_info = BaseMessageInfo( - platform=message_info.platform, - message_id=new_message_id, - time=time.time(), - user_info=message_info.user_info, - group_info=message_info.group_info, - template_info=message_info.template_info, - format_info=message_info.format_info, - additional_config=message_info.additional_config, - ) - - # 创建MessageBase - message_base = MessageBase( - message_info=buffered_message_info, - message_segment=submit_seg, - raw_message=raw_message.get("raw_message", ""), - ) - - logger.debug(f"发送缓冲合并消息到Maibot处理: {session_id}") - await message_send_instance.message_send(message_base) - - except Exception as e: - logger.error(f"发送缓冲消息失败: {e}", exc_info=True) + # 消息缓冲功能已移除 message_handler = MessageHandler() diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py index b7ca408d9..ade4c7193 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py @@ -33,6 +33,7 @@ class MessageSending: try: # 重新导入router from ..mmc_com_layer import router + self.maibot_router = router if self.maibot_router is not None: logger.info("MaiBot router重连成功") @@ -73,14 +74,14 @@ class MessageSending: # 获取对应的客户端并发送切片 platform = message_base.message_info.platform - + # 再次检查router状态(防止运行时被重置) - if self.maibot_router is None or not hasattr(self.maibot_router, 'clients'): + if self.maibot_router is None or not hasattr(self.maibot_router, "clients"): logger.warning("MaiBot router连接已断开,尝试重新连接") if not await self._attempt_reconnect(): logger.error("MaiBot router重连失败,切片发送中止") return False - + if platform not in self.maibot_router.clients: logger.error(f"平台 {platform} 未连接") return False diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py index 83d19a1d7..7ae743c41 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py @@ -23,7 +23,9 @@ class MetaEventHandler: """设置插件配置""" self.plugin_config = plugin_config # 更新interval值 - self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000 + self.interval = ( + config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000 + ) async def handle_meta_event(self, message: dict) -> None: event_type = message.get("meta_event_type") diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index 58b7f23b9..5ea018f4d 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -9,7 +9,7 @@ from src.common.logger import get_logger logger = get_logger("napcat_adapter") from src.plugin_system.apis import config_api -from ..database import BanUser, napcat_db, is_identical +from ..database import BanUser, db_manager, is_identical from . import NoticeType, ACCEPT_FORMAT from .message_sending import message_send_instance from .message_handler import message_handler @@ -62,7 +62,7 @@ class NoticeHandler: return self.server_connection return websocket_manager.get_connection() - async def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None: + def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None: """ 将用户禁言记录添加到self.banned_list中 如果是全体禁言,则user_id为0 @@ -71,16 +71,16 @@ class NoticeHandler: user_id = 0 # 使用0表示全体禁言 lift_time = -1 ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time) - for record in list(self.banned_list): + for record in self.banned_list: if is_identical(record, ban_record): self.banned_list.remove(record) self.banned_list.append(ban_record) - await napcat_db.create_ban_record(ban_record) # 更新 + db_manager.create_ban_record(ban_record) # 作为更新 return self.banned_list.append(ban_record) - await napcat_db.create_ban_record(ban_record) # 新建 + db_manager.create_ban_record(ban_record) # 添加到数据库 - async def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: + def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: """ 从self.lifted_group_list中移除已经解除全体禁言的群 """ @@ -88,12 +88,7 @@ class NoticeHandler: user_id = 0 # 使用0表示全体禁言 ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1) self.lifted_list.append(ban_record) - # 从被禁言列表里移除对应记录 - for record in list(self.banned_list): - if is_identical(record, ban_record): - self.banned_list.remove(record) - break - await napcat_db.delete_ban_record(ban_record) + db_manager.delete_ban_record(ban_record) # 删除数据库中的记录 async def handle_notice(self, raw_message: dict) -> None: notice_type = raw_message.get("notice_type") @@ -121,9 +116,9 @@ class NoticeHandler: sub_type = raw_message.get("sub_type") match sub_type: case NoticeType.Notify.poke: - if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat( - user_id, group_id, False, False - ): + if config_api.get_plugin_config( + self.plugin_config, "features.enable_poke", True + ) and await message_handler.check_allow_to_chat(user_id, group_id, False, False): logger.debug("处理戳一戳消息") handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id) else: @@ -132,14 +127,18 @@ class NoticeHandler: from src.plugin_system.core.event_manager import event_manager from ...event_types import NapcatEvent - await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME) + await event_manager.trigger_event( + NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME + ) case _: logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") - case NoticeType.group_msg_emoji_like: + case NoticeType.group_msg_emoji_like: # 该事件转移到 handle_group_emoji_like_notify函数内触发 if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True): logger.debug("处理群聊表情回复") - handled_message, user_info = await self.handle_group_emoji_like_notify(raw_message,group_id,user_id) + handled_message, user_info = await self.handle_group_emoji_like_notify( + raw_message, group_id, user_id + ) else: logger.warning("群聊表情回复被禁用,取消群聊表情回复处理") case NoticeType.group_ban: @@ -202,11 +201,9 @@ class NoticeHandler: if system_notice: await self.put_notice(message_base) - return None else: logger.debug("发送到Maibot处理通知信息") await message_send_instance.message_send(message_base) - return None async def handle_poke_notify( self, raw_message: dict, group_id: int, user_id: int @@ -301,7 +298,7 @@ class NoticeHandler: async def handle_group_emoji_like_notify(self, raw_message: dict, group_id: int, user_id: int): if not group_id: logger.error("群ID不能为空,无法处理群聊表情回复通知") - return None, None + return None, None user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id) if user_qq_info: @@ -311,37 +308,42 @@ class NoticeHandler: user_name = "QQ用户" user_cardname = "QQ用户" logger.debug("无法获取表情回复对方的用户昵称") - + from src.plugin_system.core.event_manager import event_manager from ...event_types import NapcatEvent - target_message = await event_manager.trigger_event(NapcatEvent.MESSAGE.GET_MSG,message_id=raw_message.get("message_id","")) - target_message_text = target_message.get_message_result().get("data",{}).get("raw_message","") + target_message = await event_manager.trigger_event( + NapcatEvent.MESSAGE.GET_MSG, message_id=raw_message.get("message_id", "") + ) + target_message_text = target_message.get_message_result().get("data", {}).get("raw_message", "") if not target_message: logger.error("未找到对应消息") return None, None if len(target_message_text) > 15: target_message_text = target_message_text[:15] + "..." - + user_info: UserInfo = UserInfo( platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), user_id=user_id, user_nickname=user_name, user_cardname=user_cardname, ) - + like_emoji_id = raw_message.get("likes")[0].get("emoji_id") await event_manager.trigger_event( - NapcatEvent.ON_RECEIVED.EMOJI_LIEK, - permission_group=PLUGIN_NAME, - group_id=group_id, - user_id=user_id, - message_id=raw_message.get("message_id",""), - emoji_id=like_emoji_id - ) - seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]") + NapcatEvent.ON_RECEIVED.EMOJI_LIEK, + permission_group=PLUGIN_NAME, + group_id=group_id, + user_id=user_id, + message_id=raw_message.get("message_id", ""), + emoji_id=like_emoji_id, + ) + seg_data = Seg( + type="text", + data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]", + ) return seg_data, user_info - + async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: if not group_id: logger.error("群ID不能为空,无法处理禁言通知") @@ -381,7 +383,7 @@ class NoticeHandler: if user_id == 0: # 为全体禁言 sub_type: str = "whole_ban" - await self._ban_operation(group_id) + self._ban_operation(group_id) else: # 为单人禁言 # 获取被禁言人的信息 sub_type: str = "ban" @@ -395,7 +397,7 @@ class NoticeHandler: user_nickname=user_nickname, user_cardname=user_cardname, ) - await self._ban_operation(group_id, user_id, int(time.time() + duration)) + self._ban_operation(group_id, user_id, int(time.time() + duration)) seg_data: Seg = Seg( type="notify", @@ -444,7 +446,7 @@ class NoticeHandler: user_id = raw_message.get("user_id") if user_id == 0: # 全体禁言解除 sub_type = "whole_lift_ban" - await self._lift_operation(group_id) + self._lift_operation(group_id) else: # 单人禁言解除 sub_type = "lift_ban" # 获取被解除禁言人的信息 @@ -460,7 +462,7 @@ class NoticeHandler: user_nickname=user_nickname, user_cardname=user_cardname, ) - await self._lift_operation(group_id, user_id) + self._lift_operation(group_id, user_id) seg_data: Seg = Seg( type="notify", @@ -471,8 +473,7 @@ class NoticeHandler: ) return seg_data, operator_info - @staticmethod - async def put_notice(message_base: MessageBase) -> None: + async def put_notice(self, message_base: MessageBase) -> None: """ 将处理后的通知消息放入通知队列 """ @@ -488,7 +489,7 @@ class NoticeHandler: group_id = lift_record.group_id user_id = lift_record.user_id - asyncio.create_task(napcat_db.delete_ban_record(lift_record)) # 从数据库中删除禁言记录 + db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录 seg_message: Seg = await self.natural_lift(group_id, user_id) @@ -585,8 +586,7 @@ class NoticeHandler: self.banned_list.remove(ban_record) await asyncio.sleep(5) - @staticmethod - async def send_notice() -> None: + async def send_notice(self) -> None: """ 发送通知消息到Napcat """ diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/response_pool.py b/src/plugins/built_in/napcat_adapter_plugin/src/response_pool.py index 3e8e5c4a4..7ba313af5 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/response_pool.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/response_pool.py @@ -45,12 +45,12 @@ async def check_timeout_response() -> None: while True: cleaned_message_count: int = 0 now_time = time.time() - + # 获取心跳间隔配置 heartbeat_interval = 30 # 默认值 if plugin_config: heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30) - + for echo_id, response_time in list(response_time_dict.items()): if now_time - response_time > heartbeat_interval: cleaned_message_count += 1 diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py index ef380c82f..ec4fbe75e 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py @@ -96,6 +96,7 @@ class SendHandler: logger.error("无法识别的消息类型") return None logger.info("尝试发送到napcat") + logger.debug(f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'") response = await self.send_message_to_napcat( action, { @@ -228,8 +229,10 @@ class SendHandler: new_payload = payload if seg.type == "reply": target_id = seg.data + target_id = str(target_id) if target_id == "notice": return payload + logger.info(target_id if isinstance(target_id, str) else "") new_payload = self.build_payload( payload, await self.handle_reply_message(target_id if isinstance(target_id, str) else "", user_info), @@ -294,15 +297,17 @@ class SendHandler: async def handle_reply_message(self, id: str, user_info: UserInfo) -> dict | list: """处理回复消息""" + logger.debug(f"开始处理回复消息,消息ID: {id}") reply_seg = {"type": "reply", "data": {"id": id}} # 检查是否启用引用艾特功能 if not config_api.get_plugin_config(self.plugin_config, "features.enable_reply_at", False): + logger.info("引用艾特功能未启用,仅发送普通回复") return reply_seg try: - # 尝试通过 message_id 获取消息详情 - msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)}) + msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": id}) + logger.debug(f"获取消息 {id} 的详情响应: {msg_info_response}") replied_user_id = None if msg_info_response and msg_info_response.get("status") == "ok": @@ -313,6 +318,7 @@ class SendHandler: # 如果没有获取到被回复者的ID,则直接返回,不进行@ if not replied_user_id: logger.warning(f"无法获取消息 {id} 的发送者信息,跳过 @") + logger.info(f"最终返回的回复段: {reply_seg}") return reply_seg # 根据概率决定是否艾特用户 @@ -320,13 +326,17 @@ class SendHandler: at_seg = {"type": "at", "data": {"qq": str(replied_user_id)}} # 在艾特后面添加一个空格 text_seg = {"type": "text", "data": {"text": " "}} - return [reply_seg, at_seg, text_seg] + result_seg = [reply_seg, at_seg, text_seg] + logger.info(f"最终返回的回复段: {result_seg}") + return result_seg except Exception as e: logger.error(f"处理引用回复并尝试@时出错: {e}") # 出现异常时,只发送普通的回复,避免程序崩溃 + logger.info(f"最终返回的回复段: {reply_seg}") return reply_seg + logger.info(f"最终返回的回复段: {reply_seg}") return reply_seg @staticmethod @@ -366,7 +376,7 @@ class SendHandler: use_tts = False if self.plugin_config: use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False) - + if not use_tts: logger.warning("未启用语音消息处理") return {} diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/websocket_manager.py b/src/plugins/built_in/napcat_adapter_plugin/src/websocket_manager.py index 484b9b59e..0ef55a70f 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/websocket_manager.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/websocket_manager.py @@ -18,7 +18,9 @@ class WebSocketManager: self.max_reconnect_attempts = 10 # 最大重连次数 self.plugin_config = None - async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None: + async def start_connection( + self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict + ) -> None: """根据配置启动 WebSocket 连接""" self.plugin_config = plugin_config mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode") @@ -72,9 +74,7 @@ class WebSocketManager: # 如果配置了访问令牌,添加到请求头 access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token") if access_token: - connect_kwargs["additional_headers"] = { - "Authorization": f"Bearer {access_token}" - } + connect_kwargs["additional_headers"] = {"Authorization": f"Bearer {access_token}"} logger.info("已添加访问令牌到连接请求头") async with Server.connect(url, **connect_kwargs) as websocket: diff --git a/src/plugins/built_in/napcat_adapter_plugin/template/features_template.toml b/src/plugins/built_in/napcat_adapter_plugin/template/features_template.toml deleted file mode 100644 index 679267ab2..000000000 --- a/src/plugins/built_in/napcat_adapter_plugin/template/features_template.toml +++ /dev/null @@ -1,43 +0,0 @@ -# 权限配置文件 -# 此文件用于管理群聊和私聊的黑白名单设置,以及聊天相关功能 -# 支持热重载,修改后会自动生效 - -# 群聊权限设置 -group_list_type = "whitelist" # 群聊列表类型:whitelist(白名单)或 blacklist(黑名单) -group_list = [] # 群聊ID列表 -# 当 group_list_type 为 whitelist 时,只有列表中的群聊可以使用机器人 -# 当 group_list_type 为 blacklist 时,列表中的群聊无法使用机器人 -# 示例:group_list = [123456789, 987654321] - -# 私聊权限设置 -private_list_type = "whitelist" # 私聊列表类型:whitelist(白名单)或 blacklist(黑名单) -private_list = [] # 用户ID列表 -# 当 private_list_type 为 whitelist 时,只有列表中的用户可以私聊机器人 -# 当 private_list_type 为 blacklist 时,列表中的用户无法私聊机器人 -# 示例:private_list = [123456789, 987654321] - -# 全局禁止设置 -ban_user_id = [] # 全局禁止用户ID列表,这些用户无法在任何地方使用机器人 -ban_qq_bot = false # 是否屏蔽QQ官方机器人消息 - -# 聊天功能设置 -enable_poke = true # 是否启用戳一戳功能 -ignore_non_self_poke = false # 是否无视不是针对自己的戳一戳 -poke_debounce_seconds = 3 # 戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略 -enable_reply_at = true # 是否启用引用回复时艾特用户的功能 -reply_at_rate = 0.5 # 引用回复时艾特用户的几率 (0.0 ~ 1.0) - -# 视频处理设置 -enable_video_analysis = true # 是否启用视频识别功能 -max_video_size_mb = 100 # 视频文件最大大小限制(MB) -download_timeout = 60 # 视频下载超时时间(秒) -supported_formats = ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"] # 支持的视频格式 - -# 消息缓冲设置 -enable_message_buffer = true # 是否启用消息缓冲合并功能 -message_buffer_enable_group = true # 是否启用群聊消息缓冲合并 -message_buffer_enable_private = true # 是否启用私聊消息缓冲合并 -message_buffer_interval = 3.0 # 消息合并间隔时间(秒),在此时间内的连续消息将被合并 -message_buffer_initial_delay = 0.5 # 消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并 -message_buffer_max_components = 50 # 单个会话最大缓冲消息组件数量,超过此数量将强制合并 -message_buffer_block_prefixes = ["/"] # 消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲 \ No newline at end of file diff --git a/src/plugins/built_in/napcat_adapter_plugin/template/template_config.toml b/src/plugins/built_in/napcat_adapter_plugin/template/template_config.toml deleted file mode 100644 index a06906ad3..000000000 --- a/src/plugins/built_in/napcat_adapter_plugin/template/template_config.toml +++ /dev/null @@ -1,29 +0,0 @@ -[inner] -version = "0.2.1" # 版本号 -# 请勿修改版本号,除非你知道自己在做什么 - -[nickname] # 现在没用 -nickname = "" - -[napcat_server] # Napcat连接的ws服务设置 -mode = "reverse" # 连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端) -host = "localhost" # 主机地址 -port = 8095 # 端口号 -url = "" # 正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用) -access_token = "" # WebSocket 连接的访问令牌,用于身份验证(可选) -heartbeat_interval = 30 # 心跳间隔时间(按秒计) - -[maibot_server] # 连接麦麦的ws服务设置 -host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段 -port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段 - -[voice] # 发送语音设置 -use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter) - -[slicing] # WebSocket消息切片设置 -max_frame_size = 64 # WebSocket帧的最大大小,单位为字节,默认64KB -delay_ms = 10 # 切片发送间隔时间,单位为毫秒 - -[debug] -level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL) - diff --git a/src/plugins/built_in/poke_plugin/plugin.py b/src/plugins/built_in/poke_plugin/plugin.py index 13cf33ca0..a37c45dd1 100644 --- a/src/plugins/built_in/poke_plugin/plugin.py +++ b/src/plugins/built_in/poke_plugin/plugin.py @@ -30,7 +30,8 @@ class PokeAction(BaseAction): # === 功能描述(必须填写)=== action_parameters = { - "user_name": "需要戳一戳的用户的名字", + "user_name": "需要戳一戳的用户的名字 (可选)", + "user_id": "需要戳一戳的用户的ID (可选,优先级更高)", "times": "需要戳一戳的次数 (默认为 1)", } action_require = ["当需要戳某个用户时使用", "当你想提醒特定用户时使用"] @@ -46,32 +47,38 @@ class PokeAction(BaseAction): async def execute(self) -> Tuple[bool, str]: """执行戳一戳的动作""" + user_id = self.action_data.get("user_id") user_name = self.action_data.get("user_name") + try: times = int(self.action_data.get("times", 1)) except (ValueError, TypeError): times = 1 - if not user_name: - logger.warning("戳一戳动作缺少 'user_name' 参数。") - return False, "缺少 'user_name' 参数" - - user_info = await get_person_info_manager().get_person_info_by_name(user_name) - if not user_info or not user_info.get("user_id"): - logger.info(f"找不到名为 '{user_name}' 的用户。") - return False, f"找不到名为 '{user_name}' 的用户" - - user_id = user_info.get("user_id") + # 优先使用 user_id + if not user_id: + if not user_name: + logger.warning("戳一戳动作缺少 'user_id' 或 'user_name' 参数。") + return False, "缺少用户标识参数" + + # 备用方案:通过 user_name 查找 + user_info = await get_person_info_manager().get_person_info_by_name(user_name) + if not user_info or not user_info.get("user_id"): + logger.info(f"找不到名为 '{user_name}' 的用户。") + return False, f"找不到名为 '{user_name}' 的用户" + user_id = user_info.get("user_id") + + display_name = user_name or user_id for i in range(times): - logger.info(f"正在向 {user_name} ({user_id}) 发送第 {i + 1}/{times} 次戳一戳...") + logger.info(f"正在向 {display_name} ({user_id}) 发送第 {i + 1}/{times} 次戳一戳...") await self.send_command( - "SEND_POKE", args={"qq_id": user_id}, display_message=f"戳了戳 {user_name} ({i + 1}/{times})" + "SEND_POKE", args={"qq_id": user_id}, display_message=f"戳了戳 {display_name} ({i + 1}/{times})" ) # 添加一个小的延迟,以避免发送过快 await asyncio.sleep(0.5) - success_message = f"已向 {user_name} 发送 {times} 次戳一戳。" + success_message = f"已向 {display_name} 发送 {times} 次戳一戳。" await self.store_action_info( action_build_into_prompt=True, action_prompt_display=success_message, action_done=True ) diff --git a/src/plugins/built_in/web_search_tool/engines/base.py b/src/plugins/built_in/web_search_tool/engines/base.py index f7641aa2f..30d20a540 100644 --- a/src/plugins/built_in/web_search_tool/engines/base.py +++ b/src/plugins/built_in/web_search_tool/engines/base.py @@ -1,6 +1,7 @@ """ Base search engine interface """ + from abc import ABC, abstractmethod from typing import Dict, List, Any @@ -9,20 +10,20 @@ class BaseSearchEngine(ABC): """ 搜索引擎基类 """ - + @abstractmethod async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: """ 执行搜索 - + Args: args: 搜索参数,包含 query、num_results、time_range 等 - + Returns: 搜索结果列表,每个结果包含 title、url、snippet、provider 字段 """ pass - + @abstractmethod def is_available(self) -> bool: """ diff --git a/src/plugins/built_in/web_search_tool/engines/bing_engine.py b/src/plugins/built_in/web_search_tool/engines/bing_engine.py index 6d32492ad..ece747fbd 100644 --- a/src/plugins/built_in/web_search_tool/engines/bing_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/bing_engine.py @@ -1,6 +1,7 @@ """ Bing search engine implementation """ + import asyncio import functools import random @@ -58,21 +59,21 @@ class BingSearchEngine(BaseSearchEngine): """ Bing搜索引擎实现 """ - + def __init__(self): self.session = requests.Session() self.session.headers = HEADERS - + def is_available(self) -> bool: """检查Bing搜索引擎是否可用""" return True # Bing是免费搜索引擎,总是可用 - + async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: """执行Bing搜索""" query = args["query"] num_results = args.get("num_results", 3) time_range = args.get("time_range", "any") - + try: loop = asyncio.get_running_loop() func = functools.partial(self._search_sync, query, num_results, time_range) @@ -81,17 +82,17 @@ class BingSearchEngine(BaseSearchEngine): except Exception as e: logger.error(f"Bing 搜索失败: {e}") return [] - + def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]: """同步执行Bing搜索""" if not keyword: return [] list_result = [] - + # 构建搜索URL search_url = bing_search_url + keyword - + # 如果指定了时间范围,添加时间过滤参数 if time_range == "week": search_url += "&qft=+filterui:date-range-7" @@ -182,34 +183,29 @@ class BingSearchEngine(BaseSearchEngine): # 尝试提取搜索结果 # 方法1: 查找标准的搜索结果容器 results = root.select("ol#b_results li.b_algo") - + if results: for _rank, result in enumerate(results, 1): # 提取标题和链接 title_link = result.select_one("h2 a") if not title_link: continue - + title = title_link.get_text().strip() url = title_link.get("href", "") - + # 提取摘要 abstract = "" abstract_elem = result.select_one("div.b_caption p") if abstract_elem: abstract = abstract_elem.get_text().strip() - + # 限制摘要长度 if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH: abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..." - - list_data.append({ - "title": title, - "url": url, - "snippet": abstract, - "provider": "Bing" - }) - + + list_data.append({"title": title, "url": url, "snippet": abstract, "provider": "Bing"}) + if len(list_data) >= 10: # 限制结果数量 break @@ -217,22 +213,34 @@ class BingSearchEngine(BaseSearchEngine): if not list_data: # 查找所有可能的搜索结果链接 all_links = root.find_all("a") - + for link in all_links: href = link.get("href", "") text = link.get_text().strip() - + # 过滤有效的搜索结果链接 - if (href and text and len(text) > 10 + if ( + href + and text + and len(text) > 10 and not href.startswith("javascript:") and not href.startswith("#") and "http" in href - and not any(x in href for x in [ - "bing.com/search", "bing.com/images", "bing.com/videos", - "bing.com/maps", "bing.com/news", "login", "account", - "microsoft", "javascript" - ])): - + and not any( + x in href + for x in [ + "bing.com/search", + "bing.com/images", + "bing.com/videos", + "bing.com/maps", + "bing.com/news", + "login", + "account", + "microsoft", + "javascript", + ] + ) + ): # 尝试获取摘要 abstract = "" parent = link.parent @@ -240,18 +248,13 @@ class BingSearchEngine(BaseSearchEngine): full_text = parent.get_text().strip() if len(full_text) > len(text): abstract = full_text.replace(text, "", 1).strip() - + # 限制摘要长度 if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH: abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..." - - list_data.append({ - "title": text, - "url": href, - "snippet": abstract, - "provider": "Bing" - }) - + + list_data.append({"title": text, "url": href, "snippet": abstract, "provider": "Bing"}) + if len(list_data) >= 10: break diff --git a/src/plugins/built_in/web_search_tool/engines/ddg_engine.py b/src/plugins/built_in/web_search_tool/engines/ddg_engine.py index 011935e27..29f03b31a 100644 --- a/src/plugins/built_in/web_search_tool/engines/ddg_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/ddg_engine.py @@ -1,6 +1,7 @@ """ DuckDuckGo search engine implementation """ + from typing import Dict, List, Any from asyncddgs import aDDGS @@ -14,27 +15,22 @@ class DDGSearchEngine(BaseSearchEngine): """ DuckDuckGo搜索引擎实现 """ - + def is_available(self) -> bool: """检查DuckDuckGo搜索引擎是否可用""" return True # DuckDuckGo不需要API密钥,总是可用 - + async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: """执行DuckDuckGo搜索""" query = args["query"] num_results = args.get("num_results", 3) - + try: async with aDDGS() as ddgs: search_response = await ddgs.text(query, max_results=num_results) - + return [ - { - "title": r.get("title"), - "url": r.get("href"), - "snippet": r.get("body"), - "provider": "DuckDuckGo" - } + {"title": r.get("title"), "url": r.get("href"), "snippet": r.get("body"), "provider": "DuckDuckGo"} for r in search_response ] except Exception as e: diff --git a/src/plugins/built_in/web_search_tool/engines/exa_engine.py b/src/plugins/built_in/web_search_tool/engines/exa_engine.py index 7327afaeb..269e32bd1 100644 --- a/src/plugins/built_in/web_search_tool/engines/exa_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/exa_engine.py @@ -1,6 +1,7 @@ """ Exa search engine implementation """ + import asyncio import functools from datetime import datetime, timedelta @@ -19,31 +20,27 @@ class ExaSearchEngine(BaseSearchEngine): """ Exa搜索引擎实现 """ - + def __init__(self): self._initialize_clients() - + def _initialize_clients(self): """初始化Exa客户端""" # 从主配置文件读取API密钥 exa_api_keys = config_api.get_global_config("web_search.exa_api_keys", None) - + # 创建API密钥管理器 - self.api_manager = create_api_key_manager_from_config( - exa_api_keys, - lambda key: Exa(api_key=key), - "Exa" - ) - + self.api_manager = create_api_key_manager_from_config(exa_api_keys, lambda key: Exa(api_key=key), "Exa") + def is_available(self) -> bool: """检查Exa搜索引擎是否可用""" return self.api_manager.is_available() - + async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: """执行Exa搜索""" if not self.is_available(): return [] - + query = args["query"] num_results = args.get("num_results", 3) time_range = args.get("time_range", "any") @@ -52,7 +49,7 @@ class ExaSearchEngine(BaseSearchEngine): if time_range != "any": today = datetime.now() start_date = today - timedelta(days=7 if time_range == "week" else 30) - exa_args["start_published_date"] = start_date.strftime('%Y-%m-%d') + exa_args["start_published_date"] = start_date.strftime("%Y-%m-%d") try: # 使用API密钥管理器获取下一个客户端 @@ -60,17 +57,17 @@ class ExaSearchEngine(BaseSearchEngine): if not exa_client: logger.error("无法获取Exa客户端") return [] - + loop = asyncio.get_running_loop() func = functools.partial(exa_client.search_and_contents, query, **exa_args) search_response = await loop.run_in_executor(None, func) - + return [ { "title": res.title, "url": res.url, - "snippet": " ".join(getattr(res, 'highlights', [])) or (getattr(res, 'text', '')[:250] + '...'), - "provider": "Exa" + "snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."), + "provider": "Exa", } for res in search_response.results ] diff --git a/src/plugins/built_in/web_search_tool/engines/tavily_engine.py b/src/plugins/built_in/web_search_tool/engines/tavily_engine.py index d7cf61d6c..2f929284f 100644 --- a/src/plugins/built_in/web_search_tool/engines/tavily_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/tavily_engine.py @@ -1,6 +1,7 @@ """ Tavily search engine implementation """ + import asyncio import functools from typing import Dict, List, Any @@ -18,31 +19,29 @@ class TavilySearchEngine(BaseSearchEngine): """ Tavily搜索引擎实现 """ - + def __init__(self): self._initialize_clients() - + def _initialize_clients(self): """初始化Tavily客户端""" # 从主配置文件读取API密钥 tavily_api_keys = config_api.get_global_config("web_search.tavily_api_keys", None) - + # 创建API密钥管理器 self.api_manager = create_api_key_manager_from_config( - tavily_api_keys, - lambda key: TavilyClient(api_key=key), - "Tavily" + tavily_api_keys, lambda key: TavilyClient(api_key=key), "Tavily" ) - + def is_available(self) -> bool: """检查Tavily搜索引擎是否可用""" return self.api_manager.is_available() - + async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: """执行Tavily搜索""" if not self.is_available(): return [] - + query = args["query"] num_results = args.get("num_results", 3) time_range = args.get("time_range", "any") @@ -53,38 +52,40 @@ class TavilySearchEngine(BaseSearchEngine): if not tavily_client: logger.error("无法获取Tavily客户端") return [] - + # 构建Tavily搜索参数 search_params = { "query": query, "max_results": num_results, "search_depth": "basic", "include_answer": False, - "include_raw_content": False + "include_raw_content": False, } - + # 根据时间范围调整搜索参数 if time_range == "week": search_params["days"] = 7 elif time_range == "month": search_params["days"] = 30 - + loop = asyncio.get_running_loop() func = functools.partial(tavily_client.search, **search_params) search_response = await loop.run_in_executor(None, func) - + results = [] if search_response and "results" in search_response: for res in search_response["results"]: - results.append({ - "title": res.get("title", "无标题"), - "url": res.get("url", ""), - "snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要", - "provider": "Tavily" - }) - + results.append( + { + "title": res.get("title", "无标题"), + "url": res.get("url", ""), + "snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要", + "provider": "Tavily", + } + ) + return results - + except Exception as e: logger.error(f"Tavily 搜索失败: {e}") return [] diff --git a/src/plugins/built_in/web_search_tool/plugin.py b/src/plugins/built_in/web_search_tool/plugin.py index 1789062ae..fadc02a88 100644 --- a/src/plugins/built_in/web_search_tool/plugin.py +++ b/src/plugins/built_in/web_search_tool/plugin.py @@ -3,15 +3,10 @@ Web Search Tool Plugin 一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。 """ + from typing import List, Tuple, Type -from src.plugin_system import ( - BasePlugin, - register_plugin, - ComponentInfo, - ConfigField, - PythonDependency -) +from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency from src.plugin_system.apis import config_api from src.common.logger import get_logger @@ -25,7 +20,7 @@ logger = get_logger("web_search_plugin") class WEBSEARCHPLUGIN(BasePlugin): """ 网络搜索工具插件 - + 提供网络搜索和URL解析功能,支持多种搜索引擎: - Exa (需要API密钥) - Tavily (需要API密钥) @@ -37,11 +32,11 @@ class WEBSEARCHPLUGIN(BasePlugin): plugin_name: str = "web_search_tool" # 内部标识符 enable_plugin: bool = True dependencies: List[str] = [] # 插件依赖列表 - + def __init__(self, *args, **kwargs): """初始化插件,立即加载所有搜索引擎""" super().__init__(*args, **kwargs) - + # 立即初始化所有搜索引擎,触发API密钥管理器的日志输出 logger.info("🚀 正在初始化所有搜索引擎...") try: @@ -49,65 +44,58 @@ class WEBSEARCHPLUGIN(BasePlugin): from .engines.tavily_engine import TavilySearchEngine from .engines.ddg_engine import DDGSearchEngine from .engines.bing_engine import BingSearchEngine - + # 实例化所有搜索引擎,这会触发API密钥管理器的初始化 exa_engine = ExaSearchEngine() tavily_engine = TavilySearchEngine() ddg_engine = DDGSearchEngine() bing_engine = BingSearchEngine() - + # 报告每个引擎的状态 engines_status = { "Exa": exa_engine.is_available(), "Tavily": tavily_engine.is_available(), "DuckDuckGo": ddg_engine.is_available(), - "Bing": bing_engine.is_available() + "Bing": bing_engine.is_available(), } - + available_engines = [name for name, available in engines_status.items() if available] unavailable_engines = [name for name, available in engines_status.items() if not available] - + if available_engines: logger.info(f"✅ 可用搜索引擎: {', '.join(available_engines)}") if unavailable_engines: logger.info(f"❌ 不可用搜索引擎: {', '.join(unavailable_engines)}") - + except Exception as e: logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True) - + # Python包依赖列表 python_dependencies: List[PythonDependency] = [ - PythonDependency( - package_name="asyncddgs", - description="异步DuckDuckGo搜索库", - optional=False - ), + PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False), PythonDependency( package_name="exa_py", description="Exa搜索API客户端库", - optional=True # 如果没有API密钥,这个是可选的 + optional=True, # 如果没有API密钥,这个是可选的 ), PythonDependency( package_name="tavily", install_name="tavily-python", # 安装时使用这个名称 description="Tavily搜索API客户端库", - optional=True # 如果没有API密钥,这个是可选的 + optional=True, # 如果没有API密钥,这个是可选的 ), PythonDependency( package_name="httpx", version=">=0.20.0", install_name="httpx[socks]", # 安装时使用这个名称(包含可选依赖) description="支持SOCKS代理的HTTP客户端库", - optional=False - ) + optional=False, + ), ] config_file_name: str = "config.toml" # 配置文件名 # 配置节描述 - config_section_descriptions = { - "plugin": "插件基本信息", - "proxy": "链接本地解析代理配置" - } + config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"} # 配置Schema定义 # 注意:EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分 @@ -119,42 +107,32 @@ class WEBSEARCHPLUGIN(BasePlugin): }, "proxy": { "http_proxy": ConfigField( - type=str, - default=None, - description="HTTP代理地址,格式如: http://proxy.example.com:8080" + type=str, default=None, description="HTTP代理地址,格式如: http://proxy.example.com:8080" ), "https_proxy": ConfigField( - type=str, - default=None, - description="HTTPS代理地址,格式如: http://proxy.example.com:8080" + type=str, default=None, description="HTTPS代理地址,格式如: http://proxy.example.com:8080" ), "socks5_proxy": ConfigField( - type=str, - default=None, - description="SOCKS5代理地址,格式如: socks5://proxy.example.com:1080" + type=str, default=None, description="SOCKS5代理地址,格式如: socks5://proxy.example.com:1080" ), - "enable_proxy": ConfigField( - type=bool, - default=False, - description="是否启用代理" - ) + "enable_proxy": ConfigField(type=bool, default=False, description="是否启用代理"), }, } - + def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: """ 获取插件组件列表 - + Returns: 组件信息和类型的元组列表 """ enable_tool = [] - + # 从主配置文件读取组件启用配置 if config_api.get_global_config("web_search.enable_web_search_tool", True): enable_tool.append((WebSurfingTool.get_tool_info(), WebSurfingTool)) - + if config_api.get_global_config("web_search.enable_url_tool", True): enable_tool.append((URLParserTool.get_tool_info(), URLParserTool)) - + return enable_tool diff --git a/src/plugins/built_in/web_search_tool/tools/url_parser.py b/src/plugins/built_in/web_search_tool/tools/url_parser.py index 3a05423a7..25338c35c 100644 --- a/src/plugins/built_in/web_search_tool/tools/url_parser.py +++ b/src/plugins/built_in/web_search_tool/tools/url_parser.py @@ -1,6 +1,7 @@ """ URL parser tool implementation """ + import asyncio import functools from typing import Any, Dict @@ -24,17 +25,18 @@ class URLParserTool(BaseTool): """ 一个用于解析和总结一个或多个网页URL内容的工具。 """ + name: str = "parse_url" description: str = "当需要理解一个或多个特定网页链接的内容时,使用此工具。例如:'这些网页讲了什么?[https://example.com, https://example2.com]' 或 '帮我总结一下这些文章'" available_for_llm: bool = True parameters = [ ("urls", ToolParamType.STRING, "要理解的网站", True, None), ] - + def __init__(self, plugin_config=None): super().__init__(plugin_config) self._initialize_exa_clients() - + def _initialize_exa_clients(self): """初始化Exa客户端""" # 优先从主配置文件读取,如果没有则从插件配置文件读取 @@ -42,12 +44,10 @@ class URLParserTool(BaseTool): if exa_api_keys is None: # 从插件配置文件读取 exa_api_keys = self.get_config("exa.api_keys", []) - + # 创建API密钥管理器 self.api_manager = create_api_key_manager_from_config( - exa_api_keys, - lambda key: Exa(api_key=key), - "Exa URL Parser" + exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser" ) async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]: @@ -58,12 +58,12 @@ class URLParserTool(BaseTool): # 读取代理配置 enable_proxy = self.get_config("proxy.enable_proxy", False) proxies = None - + if enable_proxy: socks5_proxy = self.get_config("proxy.socks5_proxy", None) http_proxy = self.get_config("proxy.http_proxy", None) https_proxy = self.get_config("proxy.https_proxy", None) - + # 优先使用SOCKS5代理(全协议代理) if socks5_proxy: proxies = socks5_proxy @@ -75,17 +75,17 @@ class URLParserTool(BaseTool): if https_proxy: proxies["https://"] = https_proxy logger.info(f"使用HTTP/HTTPS代理配置: {proxies}") - + client_kwargs = {"timeout": 15.0, "follow_redirects": True} if proxies: client_kwargs["proxies"] = proxies - + async with httpx.AsyncClient(**client_kwargs) as client: response = await client.get(url) response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") - + title = soup.title.string if soup.title else "无标题" for script in soup(["script", "style"]): script.extract() @@ -104,12 +104,12 @@ class URLParserTool(BaseTool): return {"error": "未配置LLM模型"} success, summary, reasoning, model_name = await llm_api.generate_with_model( - prompt=summary_prompt, - model_config=model_config, - request_type="story.generate", - temperature=0.3, - max_tokens=1000 - ) + prompt=summary_prompt, + model_config=model_config, + request_type="story.generate", + temperature=0.3, + max_tokens=1000, + ) if not success: logger.info(f"生成摘要失败: {summary}") @@ -117,12 +117,7 @@ class URLParserTool(BaseTool): logger.info(f"成功生成摘要内容:'{summary}'") - return { - "title": title, - "url": url, - "snippet": summary, - "source": "local" - } + return {"title": title, "url": url, "snippet": summary, "source": "local"} except httpx.HTTPStatusError as e: logger.warning(f"本地解析URL '{url}' 失败 (HTTP {e.response.status_code})") @@ -137,6 +132,7 @@ class URLParserTool(BaseTool): """ # 获取当前文件路径用于缓存键 import os + current_file_path = os.path.abspath(__file__) # 检查缓存 @@ -144,7 +140,7 @@ class URLParserTool(BaseTool): if cached_result: logger.info(f"缓存命中: {self.name} -> {function_args}") return cached_result - + urls_input = function_args.get("urls") if not urls_input: return {"error": "URL列表不能为空。"} @@ -158,14 +154,14 @@ class URLParserTool(BaseTool): valid_urls = validate_urls(urls) if not valid_urls: return {"error": "未找到有效的URL。"} - + urls = valid_urls logger.info(f"准备解析 {len(urls)} 个URL: {urls}") successful_results = [] error_messages = [] urls_to_retry_locally = [] - + # 步骤 1: 尝试使用 Exa API 进行解析 contents_response = None if self.api_manager.is_available(): @@ -182,41 +178,45 @@ class URLParserTool(BaseTool): contents_response = await loop.run_in_executor(None, func) except Exception as e: logger.error(f"执行 Exa URL解析时发生严重异常: {e}", exc_info=True) - contents_response = None # 确保异常后为None + contents_response = None # 确保异常后为None # 步骤 2: 处理Exa的响应 - if contents_response and hasattr(contents_response, 'statuses'): - results_map = {res.url: res for res in contents_response.results} if hasattr(contents_response, 'results') else {} + if contents_response and hasattr(contents_response, "statuses"): + results_map = ( + {res.url: res for res in contents_response.results} if hasattr(contents_response, "results") else {} + ) if contents_response.statuses: for status in contents_response.statuses: - if status.status == 'success': + if status.status == "success": res = results_map.get(status.id) if res: - summary = getattr(res, 'summary', '') - highlights = " ".join(getattr(res, 'highlights', [])) - text_snippet = (getattr(res, 'text', '')[:300] + '...') if getattr(res, 'text', '') else '' - snippet = summary or highlights or text_snippet or '无摘要' - - successful_results.append({ - "title": getattr(res, 'title', '无标题'), - "url": getattr(res, 'url', status.id), - "snippet": snippet, - "source": "exa" - }) + summary = getattr(res, "summary", "") + highlights = " ".join(getattr(res, "highlights", [])) + text_snippet = (getattr(res, "text", "")[:300] + "...") if getattr(res, "text", "") else "" + snippet = summary or highlights or text_snippet or "无摘要" + + successful_results.append( + { + "title": getattr(res, "title", "无标题"), + "url": getattr(res, "url", status.id), + "snippet": snippet, + "source": "exa", + } + ) else: - error_tag = getattr(status, 'error', '未知错误') + error_tag = getattr(status, "error", "未知错误") logger.warning(f"Exa解析URL '{status.id}' 失败: {error_tag}。准备本地重试。") urls_to_retry_locally.append(status.id) else: # 如果Exa未配置、API调用失败或返回无效响应,则所有URL都进入本地重试 - urls_to_retry_locally.extend(url for url in urls if url not in [res['url'] for res in successful_results]) + urls_to_retry_locally.extend(url for url in urls if url not in [res["url"] for res in successful_results]) # 步骤 3: 对失败的URL进行本地解析 if urls_to_retry_locally: logger.info(f"开始本地解析以下URL: {urls_to_retry_locally}") local_tasks = [self._local_parse_and_summarize(url) for url in urls_to_retry_locally] local_results = await asyncio.gather(*local_tasks) - + for i, res in enumerate(local_results): url = urls_to_retry_locally[i] if "error" in res: @@ -228,13 +228,9 @@ class URLParserTool(BaseTool): return {"error": "无法从所有给定的URL获取内容。", "details": error_messages} formatted_content = format_url_parse_results(successful_results) - - result = { - "type": "url_parse_result", - "content": formatted_content, - "errors": error_messages - } - + + result = {"type": "url_parse_result", "content": formatted_content, "errors": error_messages} + # 保存到缓存 if "error" not in result: await tool_cache.set(self.name, function_args, current_file_path, result) diff --git a/src/plugins/built_in/web_search_tool/tools/web_search.py b/src/plugins/built_in/web_search_tool/tools/web_search.py index c09ad5e92..3e4039cb8 100644 --- a/src/plugins/built_in/web_search_tool/tools/web_search.py +++ b/src/plugins/built_in/web_search_tool/tools/web_search.py @@ -1,6 +1,7 @@ """ Web search tool implementation """ + import asyncio from typing import Any, Dict, List @@ -22,14 +23,23 @@ class WebSurfingTool(BaseTool): """ 网络搜索工具 """ + name: str = "web_search" - description: str = "用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具" + description: str = ( + "用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具" + ) available_for_llm: bool = True parameters = [ ("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None), ("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量,默认为5。", False, None), - ("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。", False, ["any", "week", "month"]) - ] # type: ignore + ( + "time_range", + ToolParamType.STRING, + "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。", + False, + ["any", "week", "month"], + ), + ] # type: ignore def __init__(self, plugin_config=None): super().__init__(plugin_config) @@ -38,7 +48,7 @@ class WebSurfingTool(BaseTool): "exa": ExaSearchEngine(), "tavily": TavilySearchEngine(), "ddg": DDGSearchEngine(), - "bing": BingSearchEngine() + "bing": BingSearchEngine(), } async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: @@ -48,6 +58,7 @@ class WebSurfingTool(BaseTool): # 获取当前文件路径用于缓存键 import os + current_file_path = os.path.abspath(__file__) # 检查缓存 @@ -59,7 +70,7 @@ class WebSurfingTool(BaseTool): # 读取搜索配置 enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"]) search_strategy = config_api.get_global_config("web_search.search_strategy", "single") - + logger.info(f"开始搜索,策略: {search_strategy}, 启用引擎: {enabled_engines}, 参数: '{function_args}'") # 根据策略执行搜索 @@ -69,17 +80,19 @@ class WebSurfingTool(BaseTool): result = await self._execute_fallback_search(function_args, enabled_engines) else: # single result = await self._execute_single_search(function_args, enabled_engines) - + # 保存到缓存 if "error" not in result: await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query) - + return result - async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: + async def _execute_parallel_search( + self, function_args: Dict[str, Any], enabled_engines: List[str] + ) -> Dict[str, Any]: """并行搜索策略:同时使用所有启用的搜索引擎""" search_tasks = [] - + for engine_name in enabled_engines: engine = self.engines.get(engine_name) if engine and engine.is_available(): @@ -92,7 +105,7 @@ class WebSurfingTool(BaseTool): try: search_results_lists = await asyncio.gather(*search_tasks, return_exceptions=True) - + all_results = [] for result in search_results_lists: if isinstance(result, list): @@ -103,7 +116,7 @@ class WebSurfingTool(BaseTool): # 去重并格式化 unique_results = deduplicate_results(all_results) formatted_content = format_search_results(unique_results) - + return { "type": "web_search_result", "content": formatted_content, @@ -113,30 +126,32 @@ class WebSurfingTool(BaseTool): logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True) return {"error": f"执行网络搜索时发生严重错误: {str(e)}"} - async def _execute_fallback_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: + async def _execute_fallback_search( + self, function_args: Dict[str, Any], enabled_engines: List[str] + ) -> Dict[str, Any]: """回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个""" for engine_name in enabled_engines: engine = self.engines.get(engine_name) if not engine or not engine.is_available(): continue - + try: custom_args = function_args.copy() custom_args["num_results"] = custom_args.get("num_results", 5) - + results = await engine.search(custom_args) - + if results: # 如果有结果,直接返回 formatted_content = format_search_results(results) return { "type": "web_search_result", "content": formatted_content, } - + except Exception as e: logger.warning(f"{engine_name} 搜索失败,尝试下一个引擎: {e}") continue - + return {"error": "所有搜索引擎都失败了。"} async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: @@ -145,20 +160,20 @@ class WebSurfingTool(BaseTool): engine = self.engines.get(engine_name) if not engine or not engine.is_available(): continue - + try: custom_args = function_args.copy() custom_args["num_results"] = custom_args.get("num_results", 5) - + results = await engine.search(custom_args) formatted_content = format_search_results(results) return { "type": "web_search_result", "content": formatted_content, } - + except Exception as e: logger.error(f"{engine_name} 搜索失败: {e}") return {"error": f"{engine_name} 搜索失败: {str(e)}"} - + return {"error": "没有可用的搜索引擎。"} diff --git a/src/plugins/built_in/web_search_tool/utils/api_key_manager.py b/src/plugins/built_in/web_search_tool/utils/api_key_manager.py index f8e0afa71..07757cdb1 100644 --- a/src/plugins/built_in/web_search_tool/utils/api_key_manager.py +++ b/src/plugins/built_in/web_search_tool/utils/api_key_manager.py @@ -1,24 +1,25 @@ """ API密钥管理器,提供轮询机制 """ + import itertools from typing import List, Optional, TypeVar, Generic, Callable from src.common.logger import get_logger logger = get_logger("api_key_manager") -T = TypeVar('T') +T = TypeVar("T") class APIKeyManager(Generic[T]): """ API密钥管理器,支持轮询机制 """ - + def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"): """ 初始化API密钥管理器 - + Args: api_keys: API密钥列表 client_factory: 客户端工厂函数,接受API密钥参数并返回客户端实例 @@ -27,14 +28,14 @@ class APIKeyManager(Generic[T]): self.service_name = service_name self.clients: List[T] = [] self.client_cycle: Optional[itertools.cycle] = None - + if api_keys: # 过滤有效的API密钥,排除None、空字符串、"None"字符串等 valid_keys = [] for key in api_keys: if isinstance(key, str) and key.strip() and key.strip().lower() not in ("none", "null", ""): valid_keys.append(key.strip()) - + if valid_keys: try: self.clients = [client_factory(key) for key in valid_keys] @@ -48,35 +49,33 @@ class APIKeyManager(Generic[T]): logger.warning(f"⚠️ {service_name} API Keys 配置无效(包含None或空值),{service_name} 功能将不可用") else: logger.warning(f"⚠️ {service_name} API Keys 未配置,{service_name} 功能将不可用") - + def is_available(self) -> bool: """检查是否有可用的客户端""" return bool(self.clients and self.client_cycle) - + def get_next_client(self) -> Optional[T]: """获取下一个客户端(轮询)""" if not self.is_available(): return None return next(self.client_cycle) - + def get_client_count(self) -> int: """获取可用客户端数量""" return len(self.clients) def create_api_key_manager_from_config( - config_keys: Optional[List[str]], - client_factory: Callable[[str], T], - service_name: str + config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str ) -> APIKeyManager[T]: """ 从配置创建API密钥管理器的便捷函数 - + Args: config_keys: 从配置读取的API密钥列表 client_factory: 客户端工厂函数 service_name: 服务名称 - + Returns: API密钥管理器实例 """ diff --git a/src/plugins/built_in/web_search_tool/utils/formatters.py b/src/plugins/built_in/web_search_tool/utils/formatters.py index 434f6f3c8..df1e4ea18 100644 --- a/src/plugins/built_in/web_search_tool/utils/formatters.py +++ b/src/plugins/built_in/web_search_tool/utils/formatters.py @@ -1,6 +1,7 @@ """ Formatters for web search results """ + from typing import List, Dict, Any @@ -13,15 +14,15 @@ def format_search_results(results: List[Dict[str, Any]]) -> str: formatted_string = "根据网络搜索结果:\n\n" for i, res in enumerate(results, 1): - title = res.get("title", '无标题') - url = res.get("url", '#') - snippet = res.get("snippet", '无摘要') + title = res.get("title", "无标题") + url = res.get("url", "#") + snippet = res.get("snippet", "无摘要") provider = res.get("provider", "未知来源") - + formatted_string += f"{i}. **{title}** (来自: {provider})\n" formatted_string += f" - 摘要: {snippet}\n" formatted_string += f" - 来源: {url}\n\n" - + return formatted_string @@ -31,10 +32,10 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str: """ formatted_parts = [] for res in results: - title = res.get('title', '无标题') - url = res.get('url', '#') - snippet = res.get('snippet', '无摘要') - source = res.get('source', '未知') + title = res.get("title", "无标题") + url = res.get("url", "#") + snippet = res.get("snippet", "无摘要") + source = res.get("source", "未知") formatted_string = f"**{title}**\n" formatted_string += f"**内容摘要**:\n{snippet}\n" diff --git a/src/plugins/built_in/web_search_tool/utils/url_utils.py b/src/plugins/built_in/web_search_tool/utils/url_utils.py index 74afbc819..5bdde0a55 100644 --- a/src/plugins/built_in/web_search_tool/utils/url_utils.py +++ b/src/plugins/built_in/web_search_tool/utils/url_utils.py @@ -1,6 +1,7 @@ """ URL processing utilities """ + import re from typing import List @@ -12,11 +13,11 @@ def parse_urls_from_input(urls_input) -> List[str]: if isinstance(urls_input, str): # 如果是字符串,尝试解析为URL列表 # 提取所有HTTP/HTTPS URL - url_pattern = r'https?://[^\s\],]+' + url_pattern = r"https?://[^\s\],]+" urls = re.findall(url_pattern, urls_input) if not urls: # 如果没有找到标准URL,将整个字符串作为单个URL - if urls_input.strip().startswith(('http://', 'https://')): + if urls_input.strip().startswith(("http://", "https://")): urls = [urls_input.strip()] else: return [] @@ -24,7 +25,7 @@ def parse_urls_from_input(urls_input) -> List[str]: urls = [url.strip() for url in urls_input if isinstance(url, str) and url.strip()] else: return [] - + return urls @@ -34,6 +35,6 @@ def validate_urls(urls: List[str]) -> List[str]: """ valid_urls = [] for url in urls: - if url.startswith(('http://', 'https://')): + if url.startswith(("http://", "https://")): valid_urls.append(url) return valid_urls diff --git a/src/plugins/reminder_plugin/plugin.py b/src/plugins/reminder_plugin/plugin.py new file mode 100644 index 000000000..31ea899df --- /dev/null +++ b/src/plugins/reminder_plugin/plugin.py @@ -0,0 +1,216 @@ +import asyncio +from datetime import datetime +from typing import List, Tuple, Type +from dateutil.parser import parse as parse_datetime + +from src.common.logger import get_logger +from src.manager.async_task_manager import AsyncTask, async_task_manager +from src.person_info.person_info import get_person_info_manager +from src.plugin_system import ( + BaseAction, + ActionInfo, + BasePlugin, + register_plugin, + ActionActivationType, +) +from src.plugin_system.apis import send_api +from src.plugin_system.base.component_types import ChatType + +logger = get_logger(__name__) + + +# ============================ AsyncTask ============================ + + +class ReminderTask(AsyncTask): + def __init__( + self, + delay: float, + stream_id: str, + is_group: bool, + target_user_id: str, + target_user_name: str, + event_details: str, + creator_name: str, + ): + super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}") + self.delay = delay + self.stream_id = stream_id + self.is_group = is_group + self.target_user_id = target_user_id + self.target_user_name = target_user_name + self.event_details = event_details + self.creator_name = creator_name + + async def run(self): + try: + if self.delay > 0: + logger.info(f"等待 {self.delay:.2f} 秒后执行提醒...") + await asyncio.sleep(self.delay) + + logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒") + + reminder_text = f"叮咚!这是 {self.creator_name} 让我准时提醒你的事情:\n\n{self.event_details}" + + if self.is_group: + # 在群聊中,构造 @ 消息段并发送 + group_id = self.stream_id.split("_")[-1] if "_" in self.stream_id else self.stream_id + message_payload = [ + {"type": "at", "data": {"qq": self.target_user_id}}, + {"type": "text", "data": {"text": f" {reminder_text}"}}, + ] + await send_api.adapter_command_to_stream( + action="send_group_msg", + params={"group_id": group_id, "message": message_payload}, + stream_id=self.stream_id, + ) + else: + # 在私聊中,直接发送文本 + await send_api.text_to_stream(text=reminder_text, stream_id=self.stream_id) + + logger.info(f"提醒任务 {self.task_name} 成功完成。") + + except Exception as e: + logger.error(f"执行提醒任务 {self.task_name} 时出错: {e}", exc_info=True) + + +# =============================== Actions =============================== + + +class RemindAction(BaseAction): + """一个能从对话中智能识别并设置定时提醒的动作。""" + + # === 基本信息 === + action_name = "set_reminder" + action_description = "根据用户的对话内容,智能地设置一个未来的提醒事项。" + activation_type = ActionActivationType.LLM_JUDGE + chat_type_allow = ChatType.ALL + + # === LLM 判断与参数提取 === + llm_judge_prompt = """ + 判断用户是否意图设置一个未来的提醒。 + - 必须包含明确的时间点或时间段(如“十分钟后”、“明天下午3点”、“周五”)。 + - 必须包含一个需要被提醒的事件。 + - 可能会包含需要提醒的特定人物。 + - 如果只是普通的聊天或询问时间,则不应触发。 + + 示例: + - "半小时后提醒我开会" -> 是 + - "明天下午三点叫张三来一下" -> 是 + - "别忘了周五把报告交了" -> 是 + - "现在几点了?" -> 否 + - "我明天下午有空" -> 否 + + 请只回答"是"或"否"。 + """ + action_parameters = { + "user_name": "需要被提醒的人的称呼或名字,如果没有明确指定给某人,则默认为'自己'", + "remind_time": "描述提醒时间的自然语言字符串,例如'十分钟后'或'明天下午3点'", + "event_details": "需要提醒的具体事件内容", + } + action_require = [ + "当用户请求在未来的某个时间点提醒他/她或别人某件事时使用", + "适用于包含明确时间信息和事件描述的对话", + "例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'", + ] + + async def execute(self) -> Tuple[bool, str]: + """执行设置提醒的动作""" + user_name = self.action_data.get("user_name") + remind_time_str = self.action_data.get("remind_time") + event_details = self.action_data.get("event_details") + + if not all([user_name, remind_time_str, event_details]): + missing_params = [ + p + for p, v in { + "user_name": user_name, + "remind_time": remind_time_str, + "event_details": event_details, + }.items() + if not v + ] + error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}" + logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}") + return False, error_msg + + # 1. 解析时间 + try: + assert isinstance(remind_time_str, str) + target_time = parse_datetime(remind_time_str, fuzzy=True) + except Exception as e: + logger.error(f"[ReminderPlugin] 无法解析时间字符串 '{remind_time_str}': {e}") + await self.send_text(f"抱歉,我无法理解您说的时间 '{remind_time_str}',提醒设置失败。") + return False, f"无法解析时间 '{remind_time_str}'" + + now = datetime.now() + if target_time <= now: + await self.send_text("提醒时间必须是一个未来的时间点哦,提醒设置失败。") + return False, "提醒时间必须在未来" + + delay_seconds = (target_time - now).total_seconds() + + # 2. 解析用户 + person_manager = get_person_info_manager() + user_id_to_remind = None + user_name_to_remind = "" + + assert isinstance(user_name, str) + + if user_name.strip() in ["自己", "我", "me"]: + user_id_to_remind = self.user_id + user_name_to_remind = self.user_nickname + else: + user_info = await person_manager.get_person_info_by_name(user_name) + if not user_info or not user_info.get("user_id"): + logger.warning(f"[ReminderPlugin] 找不到名为 '{user_name}' 的用户") + await self.send_text(f"抱歉,我的联系人里找不到叫做 '{user_name}' 的人,提醒设置失败。") + return False, f"用户 '{user_name}' 不存在" + user_id_to_remind = user_info.get("user_id") + user_name_to_remind = user_name + + # 3. 创建并调度异步任务 + try: + assert user_id_to_remind is not None + assert event_details is not None + + reminder_task = ReminderTask( + delay=delay_seconds, + stream_id=self.chat_id, + is_group=self.is_group, + target_user_id=str(user_id_to_remind), + target_user_name=str(user_name_to_remind), + event_details=str(event_details), + creator_name=str(self.user_nickname), + ) + await async_task_manager.add_task(reminder_task) + + # 4. 发送确认消息 + confirm_message = f"好的,我记下了。\n将在 {target_time.strftime('%Y-%m-%d %H:%M:%S')} 提醒 {user_name_to_remind}:\n{event_details}" + await self.send_text(confirm_message) + + return True, "提醒设置成功" + except Exception as e: + logger.error(f"[ReminderPlugin] 创建提醒任务时出错: {e}", exc_info=True) + await self.send_text("抱歉,设置提醒时发生了一点内部错误。") + return False, "设置提醒时发生内部错误" + + +# =============================== Plugin =============================== + + +@register_plugin +class ReminderPlugin(BasePlugin): + """一个能从对话中智能识别并设置定时提醒的插件。""" + + # --- 插件基础信息 --- + plugin_name = "reminder_plugin" + enable_plugin = True + dependencies = [] + python_dependencies = [] + config_file_name = "config.toml" + config_schema = {} + + def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]: + """注册插件的所有功能组件。""" + return [(RemindAction.get_action_info(), RemindAction)] diff --git a/src/schedule/database.py b/src/schedule/database.py index 5025c1fa3..9117b9586 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -301,4 +301,4 @@ async def has_active_plans(month: str) -> bool: return result.scalar_one() > 0 except Exception as e: logger.error(f"检查 {month} 的有效月度计划时发生错误: {e}") - return False \ No newline at end of file + return False diff --git a/src/schedule/llm_generator.py b/src/schedule/llm_generator.py index 5c1464c71..d3ec56bb6 100644 --- a/src/schedule/llm_generator.py +++ b/src/schedule/llm_generator.py @@ -226,4 +226,4 @@ class MonthlyPlanLLMGenerator: return plans except Exception as e: logger.error(f"解析月度计划响应时发生错误: {e}") - return [] \ No newline at end of file + return [] diff --git a/src/schedule/plan_manager.py b/src/schedule/plan_manager.py index 82f8a8e04..d72f55275 100644 --- a/src/schedule/plan_manager.py +++ b/src/schedule/plan_manager.py @@ -28,20 +28,20 @@ class PlanManager: if target_month is None: target_month = datetime.now().strftime("%Y-%m") - if not await has_active_plans(target_month): + if not has_active_plans(target_month): logger.info(f" {target_month} 没有任何有效的月度计划,将触发同步生成。") generation_successful = await self._generate_monthly_plans_logic(target_month) return generation_successful else: logger.info(f"{target_month} 已存在有效的月度计划。") - plans = await get_active_plans_for_month(target_month) + plans = get_active_plans_for_month(target_month) max_plans = global_config.planning_system.max_plans_per_month if len(plans) > max_plans: logger.warning(f"当前月度计划数量 ({len(plans)}) 超出上限 ({max_plans}),将自动删除多余的计划。") plans_to_delete = plans[: len(plans) - max_plans] delete_ids = [p.id for p in plans_to_delete] - await delete_plans_by_ids(delete_ids) # type: ignore - plans = await get_active_plans_for_month(target_month) + delete_plans_by_ids(delete_ids) # type: ignore + plans = get_active_plans_for_month(target_month) if plans: plan_texts = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans)]) @@ -64,11 +64,11 @@ class PlanManager: return False last_month = self._get_previous_month(target_month) - archived_plans = await get_archived_plans_for_month(last_month) + archived_plans = get_archived_plans_for_month(last_month) plans = await self.llm_generator.generate_plans_with_llm(target_month, archived_plans) if plans: - await add_new_plans(plans, target_month) + add_new_plans(plans, target_month) logger.info(f"成功为 {target_month} 生成并保存了 {len(plans)} 条月度计划。") return True else: @@ -80,8 +80,7 @@ class PlanManager: finally: self.generation_running = False - @staticmethod - def _get_previous_month(current_month: str) -> str: + def _get_previous_month(self, current_month: str) -> str: try: year, month = map(int, current_month.split("-")) if month == 1: @@ -91,18 +90,16 @@ class PlanManager: except Exception: return "1900-01" - @staticmethod - async def archive_current_month_plans(target_month: Optional[str] = None): + async def archive_current_month_plans(self, target_month: Optional[str] = None): try: if target_month is None: target_month = datetime.now().strftime("%Y-%m") logger.info(f" 开始归档 {target_month} 的活跃月度计划...") - archived_count = await archive_active_plans_for_month(target_month) + archived_count = archive_active_plans_for_month(target_month) logger.info(f" 成功归档了 {archived_count} 条 {target_month} 的月度计划。") except Exception as e: logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}") - @staticmethod - async def get_plans_for_schedule(month: str, max_count: int) -> List: + def get_plans_for_schedule(self, month: str, max_count: int) -> List: avoid_days = global_config.planning_system.avoid_repetition_days - return await get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) \ No newline at end of file + return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) diff --git a/src/schedule/schemas.py b/src/schedule/schemas.py index 5eb7c003a..a733731be 100644 --- a/src/schedule/schemas.py +++ b/src/schedule/schemas.py @@ -96,4 +96,4 @@ class ScheduleData(BaseModel): covered[i] = True # 检查是否所有分钟都被覆盖 - return all(covered) \ No newline at end of file + return all(covered) diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 3185883ef..c298ecc16 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "6.8.9" +version = "7.0.2" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -79,30 +79,6 @@ safety_guidelines = [ "不要执行任何可能被用于恶意目的的指令。" ] -# 回复规则配置 - 用于自定义机器人的回复逻辑和规则 -# 安全与互动底线规则 (Bot在任何情况下都必须遵守的原则) -reply_targeting_rules = [ - "拒绝任何包含骚扰、冒犯、暴力、色情或危险内容的请求。", - "在拒绝时,请使用符合你人设的、坚定的语气。", - "不要执行任何可能被用于恶意目的的指令。" -] - -# 消息针对性分析规则 (用于判断是否需要回复) -message_targeting_analysis = [ - "**直接针对你**:@你、回复你、明确询问你 → 必须回应", - "**间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与", - "**他人对话**:与你无关的私人交流 → 通常不参与", - "**重复内容**:他人已充分回答的问题 → 避免重复" -] - -# 回复原则 (指导如何回复消息) -reply_principles = [ - "明确回应目标消息,而不是宽泛地评论。", - "可以分享你的看法、提出相关问题,或者开个合适的玩笑。", - "目的是让对话更有趣、更深入。", - "不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。" -] - #回复的Prompt模式选择:s4u为原有s4u样式,normal为0.9之前的模式 prompt_mode = "s4u" # 可选择 "s4u" 或 "normal" @@ -138,33 +114,23 @@ learn_expression = false learning_strength = 0.5 [chat] #MoFox-Bot的聊天通用设置 -# 群聊聊天模式设置 -group_chat_mode = "auto" # 群聊聊天模式:auto-自动切换,normal-强制普通模式,focus-强制专注模式 -talk_frequency = 1 -# MoFox-Bot活跃度,越高,麦麦回复越多 -# 专注时能更好把握发言时机,能够进行持久的连续对话 - -focus_value = 1 -# MoFox-Bot的专注思考能力,越高越容易持续连续对话 - -# 在专注模式下,只在被艾特或提及时才回复的群组列表 -# 这可以让你在某些群里保持“高冷”,只在被需要时才发言 -# 格式为: ["platform:group_id1", "platform:group_id2"] -# 例如: ["qq:123456789", "qq:987654321"] -focus_mode_quiet_groups = [] - -# 强制私聊回复 -force_reply_private = false # 是否强制私聊回复,开启后私聊将强制回复 - allow_reply_self = false # 是否允许回复自己说的话 max_context_size = 25 # 上下文长度 thinking_timeout = 40 # MoFox-Bot一次回复最长思考规划时间,超过这个时间的思考会放弃(往往是api反应太慢) -replyer_random_probability = 0.5 # 首要replyer模型被选择的概率 -mentioned_bot_inevitable_reply = true # 提及 bot 必然回复 -at_bot_inevitable_reply = true # @bot 或 提及bot 必然回复 -# 兼容normal、focus,在focus模式下为强制移除no_reply动作 +# 消息打断系统配置 +interruption_enabled = true # 是否启用消息打断系统 +interruption_max_limit = 3 # 每个聊天流的最大打断次数 +interruption_probability_factor = 0.8 # 打断概率因子,当前打断次数/最大打断次数超过此值时触发概率下降 +interruption_afc_reduction = 0.05 # 每次连续打断降低的afc阈值数值 + +# 动态消息分发系统配置 +dynamic_distribution_enabled = true # 是否启用动态消息分发周期调整 +dynamic_distribution_base_interval = 5.0 # 基础分发间隔(秒) +dynamic_distribution_min_interval = 1.0 # 最小分发间隔(秒) +dynamic_distribution_max_interval = 30.0 # 最大分发间隔(秒) +dynamic_distribution_jitter_factor = 0.2 # 分发间隔随机扰动因子 talk_frequency_adjust = [ ["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"], @@ -309,7 +275,7 @@ enable_vector_instant_memory = true # 是否启用基于向量的瞬时记忆 memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ] [voice] -enable_asr = false # 是否启用语音识别,启用后MoFox-Bot可以识别语音消息,启用该功能需要配置语音识别模型[model.voice] +enable_asr = true # 是否启用语音识别,启用后MoFox-Bot可以识别语音消息,启用该功能需要配置语音识别模型[model.voice] [lpmm_knowledge] # lpmm知识库配置 enable = false # 是否启用lpmm知识库 @@ -362,7 +328,7 @@ enable = true # 是否启用回复分割器 split_mode = "punctuation" # 分割模式: "llm" - 由语言模型决定, "punctuation" - 基于标点符号 max_length = 512 # 回复允许的最大长度 max_sentence_num = 8 # 回复允许的最大句子数 -enable_kaomoji_protection = false # 是否启用颜文字保护 +enable_kaomoji_protection = true # 是否启用颜文字保护 [log] date_style = "m-d H:i:s" # 日期格式 @@ -541,4 +507,32 @@ name = "Maizone默认互通组" chat_ids = [ ["group", "111111"], # 示例群聊1 ["private", "222222"] # 示例私聊2 -] \ No newline at end of file +] + +[affinity_flow] +# 兴趣评分系统参数 +reply_action_interest_threshold = 0.62 # 回复动作兴趣阈值 +non_reply_action_interest_threshold = 0.48 # 非回复动作兴趣阈值 +high_match_interest_threshold = 0.65 # 高匹配兴趣阈值 +medium_match_interest_threshold = 0.5 # 中匹配兴趣阈值 +low_match_interest_threshold = 0.2 # 低匹配兴趣阈值 +high_match_keyword_multiplier = 1.8 # 高匹配关键词兴趣倍率 +medium_match_keyword_multiplier = 1.4 # 中匹配关键词兴趣倍率 +low_match_keyword_multiplier = 1.15 # 低匹配关键词兴趣倍率 +match_count_bonus = 0.05 # 匹配数关键词加成值 +max_match_bonus = 0.3 # 最大匹配数加成值 + +# 回复决策系统参数 +no_reply_threshold_adjustment = 0.1 # 不回复兴趣阈值调整值 +reply_cooldown_reduction = 2 # 回复后减少的不回复计数 +max_no_reply_count = 5 # 最大不回复计数次数 + +# 综合评分权重 +keyword_match_weight = 0.4 # 兴趣关键词匹配度权重 +mention_bot_weight = 0.3 # 提及bot分数权重 +relationship_weight = 0.3 # 人物关系分数权重 + +# 提及bot相关参数 +mention_bot_adjustment_threshold = 0.3 # 提及bot后的调整阈值 +mention_bot_interest_score = 0.6 # 提及bot的兴趣分 +base_relationship_score = 0.3 # 基础人物关系分 \ No newline at end of file diff --git a/template/model_config_template.toml b/template/model_config_template.toml index 8c9763c2f..7a08d362a 100644 --- a/template/model_config_template.toml +++ b/template/model_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "1.3.4" +version = "1.3.5" # 配置文件版本号迭代规则同bot_config.toml @@ -195,6 +195,11 @@ model_list = ["siliconflow-deepseek-v3"] temperature = 0.7 max_tokens = 1000 +[model_task_config.relationship_tracker] # 用户关系追踪模型 +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.7 +max_tokens = 1000 + #嵌入模型 [model_task_config.embedding] model_list = ["bge-m3"]