diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index 4f17f9bdf..c6205ce47 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -7,6 +7,7 @@ from typing import List, Optional, Dict, Any, Deque, Callable, Coroutine from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import chat_manager from rich.traceback import install +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.common.logger_manager import get_logger from src.chat.utils.timer_calculator import Timer from src.chat.heart_flow.observation.observation import Observation @@ -228,8 +229,9 @@ class HeartFChatting: thinking_id = "tid" + str(round(time.time(), 2)) self._current_cycle.set_thinking_id(thinking_id) # 主循环:思考->决策->执行 - - loop_info = await self._observe_process_plan_action_loop(cycle_timers, thinking_id) + async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): + logger.debug(f"模板 {self.chat_stream.context.get_template_name()}") + loop_info = await self._observe_process_plan_action_loop(cycle_timers, thinking_id) self._current_cycle.set_loop_info(loop_info) diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index 532ceccd1..60b30cfa2 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -125,7 +125,6 @@ class PromptBuilder: relation_prompt += await relationship_manager.build_relationship_info(person) else: logger.warning(f"Invalid person tuple encountered for relationship prompt: {person}") - mood_prompt = mood_manager.get_mood_prompt() reply_styles1 = [ ("然后给出日常且口语化的回复,平淡一些", 0.4), @@ -146,9 +145,11 @@ class PromptBuilder: [style[0] for style in reply_styles2], weights=[style[1] for style in reply_styles2], k=1 )[0] memory_prompt = "" + related_memory = await HippocampusManager.get_instance().get_memory_from_text( text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False ) + related_memory_info = "" if related_memory: for memory in related_memory: diff --git a/src/chat/focus_chat/hfc_utils.py b/src/chat/focus_chat/hfc_utils.py index 36907c4c0..050684743 100644 --- a/src/chat/focus_chat/hfc_utils.py +++ b/src/chat/focus_chat/hfc_utils.py @@ -1,7 +1,7 @@ import time from typing import Optional from src.chat.message_receive.message import MessageRecv, BaseMessageInfo -from src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.chat_stream import ChatStream, chat_manager from src.chat.message_receive.message import UserInfo from src.common.logger_manager import get_logger import json diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 3b9a6f929..3e776f9a1 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -77,6 +77,7 @@ class ChatBot: message = MessageRecv(message_data) group_info = message.message_info.group_info user_info = message.message_info.user_info + chat_manager.register_message(message) # 确认从接口发来的message是否有自定义的prompt模板信息 if message.message_info.template_info and not message.message_info.template_info.template_default: @@ -86,7 +87,7 @@ class ChatBot: if isinstance(template_items, dict): for k in template_items.keys(): await Prompt.create_async(template_items[k], k) - print(f"注册{template_items[k]},{k}") + logger.debug(f"注册{template_items[k]},{k}") else: template_group_name = None diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index e00fc7370..ef7998752 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -2,13 +2,17 @@ import asyncio import hashlib import time import copy -from typing import Dict, Optional +from typing import Dict, Optional, TYPE_CHECKING from ...common.database.database import db from ...common.database.database_model import ChatStreams # 新增导入 from maim_message import GroupInfo, UserInfo +# 避免循环导入,使用TYPE_CHECKING进行类型提示 +if TYPE_CHECKING: + from .message import MessageRecv + from src.common.logger_manager import get_logger from rich.traceback import install @@ -18,6 +22,23 @@ install(extra_lines=3) logger = get_logger("chat_stream") +class ChatMessageContext: + """聊天消息上下文,存储消息的上下文信息""" + + def __init__(self, message: "MessageRecv"): + self.message = message + + def get_template_name(self) -> str: + """获取模板名称""" + if self.message.message_info.template_info and not self.message.message_info.template_info.template_default: + return self.message.message_info.template_info.template_name + return None + + def get_last_message(self) -> "MessageRecv": + """获取最后一条消息""" + return self.message + + class ChatStream: """聊天流对象,存储一个完整的聊天上下文""" @@ -36,6 +57,7 @@ class ChatStream: self.create_time = data.get("create_time", time.time()) if data else time.time() self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time self.saved = False + self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息 def to_dict(self) -> dict: """转换为字典格式""" @@ -67,6 +89,10 @@ class ChatStream: self.last_active_time = time.time() self.saved = False + def set_context(self, message: "MessageRecv"): + """设置聊天消息上下文""" + self.context = ChatMessageContext(message) + class ChatManager: """聊天管理器,管理所有聊天流""" @@ -82,6 +108,7 @@ class ChatManager: def __init__(self): if not self._initialized: self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream + self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message try: db.connect(reuse_if_open=True) # 确保 ChatStreams 表存在 @@ -113,6 +140,16 @@ class ChatManager: except Exception as e: logger.error(f"聊天流自动保存失败: {str(e)}") + def register_message(self, message: "MessageRecv"): + """注册消息到聊天流""" + stream_id = self._generate_stream_id( + message.message_info.platform, + message.message_info.user_info, + message.message_info.group_info, + ) + self.last_messages[stream_id] = message + logger.debug(f"注册消息到聊天流: {stream_id}") + @staticmethod def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: """生成聊天流唯一ID""" @@ -146,12 +183,19 @@ class ChatManager: # 检查内存中是否存在 if stream_id in self.streams: stream = self.streams[stream_id] + # 更新用户信息和群组信息 stream.update_active_time() stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存 stream.user_info = user_info if group_info: stream.group_info = group_info + from .message import MessageRecv # 延迟导入,避免循环引用 + + if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): + stream.set_context(self.last_messages[stream_id]) + else: + logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") return stream # 检查数据库中是否存在 @@ -202,14 +246,24 @@ class ChatManager: logger.error(f"获取或创建聊天流失败: {e}", exc_info=True) raise e + stream = copy.deepcopy(stream) + from .message import MessageRecv # 延迟导入,避免循环引用 + + if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): + stream.set_context(self.last_messages[stream_id]) + else: + logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") # 保存到内存和数据库 self.streams[stream_id] = stream await self._save_stream(stream) - return copy.deepcopy(stream) + return stream def get_stream(self, stream_id: str) -> Optional[ChatStream]: """通过stream_id获取聊天流""" - return self.streams.get(stream_id) + stream = self.streams.get(stream_id) + if stream_id in self.last_messages: + stream.set_context(self.last_messages[stream_id]) + return stream def get_stream_by_info( self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None @@ -306,6 +360,8 @@ class ChatManager: stream = ChatStream.from_dict(data) stream.saved = True self.streams[stream.stream_id] = stream + if stream.stream_id in self.last_messages: + stream.set_context(self.last_messages[stream.stream_id]) except Exception as e: logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index a42a11a82..20691ce1a 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,12 +1,14 @@ import time from abc import abstractmethod from dataclasses import dataclass -from typing import Optional, Any +from typing import Optional, Any, TYPE_CHECKING import urllib3 from src.common.logger_manager import get_logger -from .chat_stream import ChatStream + +if TYPE_CHECKING: + from .chat_stream import ChatStream from ..utils.utils_image import image_manager from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase from rich.traceback import install @@ -25,7 +27,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @dataclass class Message(MessageBase): - chat_stream: ChatStream = None + chat_stream: "ChatStream" = None reply: Optional["Message"] = None detailed_plain_text: str = "" processed_plain_text: str = "" @@ -34,7 +36,7 @@ class Message(MessageBase): def __init__( self, message_id: str, - chat_stream: ChatStream, + chat_stream: "ChatStream", user_info: UserInfo, message_segment: Optional[Seg] = None, timestamp: Optional[float] = None, @@ -111,7 +113,7 @@ class MessageRecv(Message): self.detailed_plain_text = "" # 初始化为空字符串 self.is_emoji = False - def update_chat_stream(self, chat_stream: ChatStream): + def update_chat_stream(self, chat_stream: "ChatStream"): self.chat_stream = chat_stream async def process(self) -> None: @@ -165,7 +167,7 @@ class MessageProcessBase(Message): def __init__( self, message_id: str, - chat_stream: ChatStream, + chat_stream: "ChatStream", bot_user_info: UserInfo, message_segment: Optional[Seg] = None, reply: Optional["MessageRecv"] = None, @@ -241,7 +243,7 @@ class MessageThinking(MessageProcessBase): def __init__( self, message_id: str, - chat_stream: ChatStream, + chat_stream: "ChatStream", bot_user_info: UserInfo, reply: Optional["MessageRecv"] = None, thinking_start_time: float = 0, @@ -269,7 +271,7 @@ class MessageSending(MessageProcessBase): def __init__( self, message_id: str, - chat_stream: ChatStream, + chat_stream: "ChatStream", bot_user_info: UserInfo, sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复 message_segment: Seg, @@ -353,7 +355,7 @@ class MessageSending(MessageProcessBase): class MessageSet: """消息集合类,可以存储多个发送消息""" - def __init__(self, chat_stream: ChatStream, message_id: str): + def __init__(self, chat_stream: "ChatStream", message_id: str): self.chat_stream = chat_stream self.message_id = message_id self.messages: list[MessageSending] = [] diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index bd5322137..d38c77947 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -14,6 +14,7 @@ from src.chat.message_receive.chat_stream import ChatStream, chat_manager from src.chat.person_info.relationship_manager import relationship_manager from src.chat.utils.info_catcher import info_catcher_manager from src.chat.utils.timer_calculator import Timer +from src.chat.utils.prompt_builder import global_prompt_manager from .normal_chat_generator import NormalChatGenerator from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet from src.chat.message_receive.message_sender import message_manager @@ -194,31 +195,31 @@ class NormalChat: 通常由start_monitoring_interest()启动 """ while True: - await asyncio.sleep(0.5) # 每秒检查一次 - # 检查任务是否已被取消 - if self._chat_task is None or self._chat_task.cancelled(): - logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出") - break + async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): + await asyncio.sleep(0.5) # 每秒检查一次 + # 检查任务是否已被取消 + if self._chat_task is None or self._chat_task.cancelled(): + logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出") + break + items_to_process = list(self.interest_dict.items()) + if not items_to_process: + continue - items_to_process = list(self.interest_dict.items()) - if not items_to_process: - continue - - # 处理每条兴趣消息 - for msg_id, (message, interest_value, is_mentioned) in items_to_process: - try: - # 处理消息 - await self.normal_response( - message=message, - is_mentioned=is_mentioned, - interested_rate=interest_value, - rewind_response=False, - ) - except Exception as e: - logger.error(f"[{self.stream_name}] 处理兴趣消息{msg_id}时出错: {e}\n{traceback.format_exc()}") - finally: - self.interest_dict.pop(msg_id, None) + # 处理每条兴趣消息 + for msg_id, (message, interest_value, is_mentioned) in items_to_process: + try: + # 处理消息 + await self.normal_response( + message=message, + is_mentioned=is_mentioned, + interested_rate=interest_value, + rewind_response=False, + ) + except Exception as e: + logger.error(f"[{self.stream_name}] 处理兴趣消息{msg_id}时出错: {e}\n{traceback.format_exc()}") + finally: + self.interest_dict.pop(msg_id, None) # 改为实例方法, 移除 chat 参数 async def normal_response( diff --git a/src/chat/utils/prompt_builder.py b/src/chat/utils/prompt_builder.py index 4a226a022..ced5adc53 100644 --- a/src/chat/utils/prompt_builder.py +++ b/src/chat/utils/prompt_builder.py @@ -2,6 +2,7 @@ from typing import Dict, Any, Optional, List, Union import re from contextlib import asynccontextmanager import asyncio +import contextvars from src.common.logger import get_module_logger # import traceback @@ -15,29 +16,59 @@ logger = get_module_logger("prompt_build") class PromptContext: def __init__(self): self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} - self._current_context: Optional[str] = None - self._context_lock = asyncio.Lock() # 添加异步锁 + # 使用contextvars创建协程上下文变量 + self._current_context_var = contextvars.ContextVar("current_context", default=None) + self._context_lock = asyncio.Lock() # 保留锁用于其他操作 + + @property + def _current_context(self) -> Optional[str]: + """获取当前协程的上下文ID""" + return self._current_context_var.get() + + @_current_context.setter + def _current_context(self, value: Optional[str]): + """设置当前协程的上下文ID""" + self._current_context_var.set(value) @asynccontextmanager - async def async_scope(self, context_id: str): + async def async_scope(self, context_id: Optional[str] = None): """创建一个异步的临时提示模板作用域""" - async with self._context_lock: - if context_id not in self._context_prompts: - self._context_prompts[context_id] = {} + # 保存当前上下文并设置新上下文 + if context_id is not None: + async with self._context_lock: + if context_id not in self._context_prompts: + self._context_prompts[context_id] = {} + # 保存当前协程的上下文值,不影响其他协程 previous_context = self._current_context - self._current_context = context_id + # 设置当前协程的新上下文 + token = self._current_context_var.set(context_id) + else: + # 如果没有提供新上下文,保持当前上下文不变 + previous_context = self._current_context + token = None + try: yield self finally: - async with self._context_lock: - self._current_context = previous_context + # 恢复之前的上下文 + if context_id is not None: + if token: + self._current_context_var.reset(token) + else: + self._current_context = previous_context async def get_prompt_async(self, name: str) -> Optional["Prompt"]: """异步获取当前作用域中的提示模板""" async with self._context_lock: - if self._current_context and name in self._context_prompts[self._current_context]: - return self._context_prompts[self._current_context][name] + current_context = self._current_context + logger.debug(f"获取提示词: {name} 当前上下文: {current_context}") + if ( + current_context + and current_context in self._context_prompts + and name in self._context_prompts[current_context] + ): + return self._context_prompts[current_context][name] return None async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None: @@ -56,8 +87,8 @@ class PromptManager: self._lock = asyncio.Lock() @asynccontextmanager - async def async_message_scope(self, message_id: str): - """为消息处理创建异步临时作用域""" + async def async_message_scope(self, message_id: Optional[str] = None): + """为消息处理创建异步临时作用域,支持 message_id 为 None 的情况""" async with self._context.async_scope(message_id): yield self @@ -65,9 +96,11 @@ class PromptManager: # 首先尝试从当前上下文获取 context_prompt = await self._context.get_prompt_async(name) if context_prompt is not None: + logger.debug(f"从上下文中获取提示词: {name} {context_prompt}") return context_prompt # 如果上下文中不存在,则使用全局提示模板 async with self._lock: + logger.debug(f"从全局获取提示词: {name}") if name not in self._prompts: raise KeyError(f"Prompt '{name}' not found") return self._prompts[name]