fix: 恢复template_info功能

This commit is contained in:
tcmofashi
2025-05-23 11:04:49 +08:00
parent 75eeea8d92
commit ff9efb1c5e
8 changed files with 149 additions and 53 deletions

View File

@@ -7,6 +7,7 @@ from typing import List, Optional, Dict, Any, Deque, Callable, Coroutine
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_stream import chat_manager from src.chat.message_receive.chat_stream import chat_manager
from rich.traceback import install from rich.traceback import install
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.chat.heart_flow.observation.observation import Observation from src.chat.heart_flow.observation.observation import Observation
@@ -228,8 +229,9 @@ class HeartFChatting:
thinking_id = "tid" + str(round(time.time(), 2)) thinking_id = "tid" + str(round(time.time(), 2))
self._current_cycle.set_thinking_id(thinking_id) self._current_cycle.set_thinking_id(thinking_id)
# 主循环:思考->决策->执行 # 主循环:思考->决策->执行
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
loop_info = await self._observe_process_plan_action_loop(cycle_timers, thinking_id) logger.debug(f"模板 {self.chat_stream.context.get_template_name()}")
loop_info = await self._observe_process_plan_action_loop(cycle_timers, thinking_id)
self._current_cycle.set_loop_info(loop_info) self._current_cycle.set_loop_info(loop_info)

View File

@@ -125,7 +125,6 @@ class PromptBuilder:
relation_prompt += await relationship_manager.build_relationship_info(person) relation_prompt += await relationship_manager.build_relationship_info(person)
else: else:
logger.warning(f"Invalid person tuple encountered for relationship prompt: {person}") logger.warning(f"Invalid person tuple encountered for relationship prompt: {person}")
mood_prompt = mood_manager.get_mood_prompt() mood_prompt = mood_manager.get_mood_prompt()
reply_styles1 = [ reply_styles1 = [
("然后给出日常且口语化的回复,平淡一些", 0.4), ("然后给出日常且口语化的回复,平淡一些", 0.4),
@@ -146,9 +145,11 @@ class PromptBuilder:
[style[0] for style in reply_styles2], weights=[style[1] for style in reply_styles2], k=1 [style[0] for style in reply_styles2], weights=[style[1] for style in reply_styles2], k=1
)[0] )[0]
memory_prompt = "" memory_prompt = ""
related_memory = await HippocampusManager.get_instance().get_memory_from_text( related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
) )
related_memory_info = "" related_memory_info = ""
if related_memory: if related_memory:
for memory in related_memory: for memory in related_memory:

View File

@@ -1,7 +1,7 @@
import time import time
from typing import Optional from typing import Optional
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream, chat_manager
from src.chat.message_receive.message import UserInfo from src.chat.message_receive.message import UserInfo
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
import json import json

View File

@@ -77,6 +77,7 @@ class ChatBot:
message = MessageRecv(message_data) message = MessageRecv(message_data)
group_info = message.message_info.group_info group_info = message.message_info.group_info
user_info = message.message_info.user_info user_info = message.message_info.user_info
chat_manager.register_message(message)
# 确认从接口发来的message是否有自定义的prompt模板信息 # 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default: if message.message_info.template_info and not message.message_info.template_info.template_default:
@@ -86,7 +87,7 @@ class ChatBot:
if isinstance(template_items, dict): if isinstance(template_items, dict):
for k in template_items.keys(): for k in template_items.keys():
await Prompt.create_async(template_items[k], k) await Prompt.create_async(template_items[k], k)
print(f"注册{template_items[k]},{k}") logger.debug(f"注册{template_items[k]},{k}")
else: else:
template_group_name = None template_group_name = None

View File

@@ -2,13 +2,17 @@ import asyncio
import hashlib import hashlib
import time import time
import copy import copy
from typing import Dict, Optional from typing import Dict, Optional, TYPE_CHECKING
from ...common.database.database import db from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入 from ...common.database.database_model import ChatStreams # 新增导入
from maim_message import GroupInfo, UserInfo from maim_message import GroupInfo, UserInfo
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from rich.traceback import install from rich.traceback import install
@@ -18,6 +22,23 @@ install(extra_lines=3)
logger = get_logger("chat_stream") logger = get_logger("chat_stream")
class ChatMessageContext:
"""聊天消息上下文,存储消息的上下文信息"""
def __init__(self, message: "MessageRecv"):
self.message = message
def get_template_name(self) -> str:
"""获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name
return None
def get_last_message(self) -> "MessageRecv":
"""获取最后一条消息"""
return self.message
class ChatStream: class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文""" """聊天流对象,存储一个完整的聊天上下文"""
@@ -36,6 +57,7 @@ class ChatStream:
self.create_time = data.get("create_time", time.time()) if data else time.time() self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False self.saved = False
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""转换为字典格式""" """转换为字典格式"""
@@ -67,6 +89,10 @@ class ChatStream:
self.last_active_time = time.time() self.last_active_time = time.time()
self.saved = False self.saved = False
def set_context(self, message: "MessageRecv"):
"""设置聊天消息上下文"""
self.context = ChatMessageContext(message)
class ChatManager: class ChatManager:
"""聊天管理器,管理所有聊天流""" """聊天管理器,管理所有聊天流"""
@@ -82,6 +108,7 @@ class ChatManager:
def __init__(self): def __init__(self):
if not self._initialized: if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
try: try:
db.connect(reuse_if_open=True) db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在 # 确保 ChatStreams 表存在
@@ -113,6 +140,16 @@ class ChatManager:
except Exception as e: except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}") logger.error(f"聊天流自动保存失败: {str(e)}")
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
stream_id = self._generate_stream_id(
message.message_info.platform,
message.message_info.user_info,
message.message_info.group_info,
)
self.last_messages[stream_id] = message
logger.debug(f"注册消息到聊天流: {stream_id}")
@staticmethod @staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID""" """生成聊天流唯一ID"""
@@ -146,12 +183,19 @@ class ChatManager:
# 检查内存中是否存在 # 检查内存中是否存在
if stream_id in self.streams: if stream_id in self.streams:
stream = self.streams[stream_id] stream = self.streams[stream_id]
# 更新用户信息和群组信息 # 更新用户信息和群组信息
stream.update_active_time() stream.update_active_time()
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存 stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
stream.user_info = user_info stream.user_info = user_info
if group_info: if group_info:
stream.group_info = group_info stream.group_info = group_info
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
return stream return stream
# 检查数据库中是否存在 # 检查数据库中是否存在
@@ -202,14 +246,24 @@ class ChatManager:
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True) logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e raise e
stream = copy.deepcopy(stream)
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
# 保存到内存和数据库 # 保存到内存和数据库
self.streams[stream_id] = stream self.streams[stream_id] = stream
await self._save_stream(stream) await self._save_stream(stream)
return copy.deepcopy(stream) return stream
def get_stream(self, stream_id: str) -> Optional[ChatStream]: def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流""" """通过stream_id获取聊天流"""
return self.streams.get(stream_id) stream = self.streams.get(stream_id)
if stream_id in self.last_messages:
stream.set_context(self.last_messages[stream_id])
return stream
def get_stream_by_info( def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
@@ -306,6 +360,8 @@ class ChatManager:
stream = ChatStream.from_dict(data) stream = ChatStream.from_dict(data)
stream.saved = True stream.saved = True
self.streams[stream.stream_id] = stream self.streams[stream.stream_id] = stream
if stream.stream_id in self.last_messages:
stream.set_context(self.last_messages[stream.stream_id])
except Exception as e: except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True) logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)

View File

@@ -1,12 +1,14 @@
import time import time
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Any from typing import Optional, Any, TYPE_CHECKING
import urllib3 import urllib3
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from .chat_stream import ChatStream
if TYPE_CHECKING:
from .chat_stream import ChatStream
from ..utils.utils_image import image_manager from ..utils.utils_image import image_manager
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from rich.traceback import install from rich.traceback import install
@@ -25,7 +27,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@dataclass @dataclass
class Message(MessageBase): class Message(MessageBase):
chat_stream: ChatStream = None chat_stream: "ChatStream" = None
reply: Optional["Message"] = None reply: Optional["Message"] = None
detailed_plain_text: str = "" detailed_plain_text: str = ""
processed_plain_text: str = "" processed_plain_text: str = ""
@@ -34,7 +36,7 @@ class Message(MessageBase):
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
chat_stream: ChatStream, chat_stream: "ChatStream",
user_info: UserInfo, user_info: UserInfo,
message_segment: Optional[Seg] = None, message_segment: Optional[Seg] = None,
timestamp: Optional[float] = None, timestamp: Optional[float] = None,
@@ -111,7 +113,7 @@ class MessageRecv(Message):
self.detailed_plain_text = "" # 初始化为空字符串 self.detailed_plain_text = "" # 初始化为空字符串
self.is_emoji = False self.is_emoji = False
def update_chat_stream(self, chat_stream: ChatStream): def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream self.chat_stream = chat_stream
async def process(self) -> None: async def process(self) -> None:
@@ -165,7 +167,7 @@ class MessageProcessBase(Message):
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
chat_stream: ChatStream, chat_stream: "ChatStream",
bot_user_info: UserInfo, bot_user_info: UserInfo,
message_segment: Optional[Seg] = None, message_segment: Optional[Seg] = None,
reply: Optional["MessageRecv"] = None, reply: Optional["MessageRecv"] = None,
@@ -241,7 +243,7 @@ class MessageThinking(MessageProcessBase):
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
chat_stream: ChatStream, chat_stream: "ChatStream",
bot_user_info: UserInfo, bot_user_info: UserInfo,
reply: Optional["MessageRecv"] = None, reply: Optional["MessageRecv"] = None,
thinking_start_time: float = 0, thinking_start_time: float = 0,
@@ -269,7 +271,7 @@ class MessageSending(MessageProcessBase):
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
chat_stream: ChatStream, chat_stream: "ChatStream",
bot_user_info: UserInfo, bot_user_info: UserInfo,
sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复 sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复
message_segment: Seg, message_segment: Seg,
@@ -353,7 +355,7 @@ class MessageSending(MessageProcessBase):
class MessageSet: class MessageSet:
"""消息集合类,可以存储多个发送消息""" """消息集合类,可以存储多个发送消息"""
def __init__(self, chat_stream: ChatStream, message_id: str): def __init__(self, chat_stream: "ChatStream", message_id: str):
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.message_id = message_id self.message_id = message_id
self.messages: list[MessageSending] = [] self.messages: list[MessageSending] = []

View File

@@ -14,6 +14,7 @@ from src.chat.message_receive.chat_stream import ChatStream, chat_manager
from src.chat.person_info.relationship_manager import relationship_manager from src.chat.person_info.relationship_manager import relationship_manager
from src.chat.utils.info_catcher import info_catcher_manager from src.chat.utils.info_catcher import info_catcher_manager
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.chat.utils.prompt_builder import global_prompt_manager
from .normal_chat_generator import NormalChatGenerator from .normal_chat_generator import NormalChatGenerator
from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
from src.chat.message_receive.message_sender import message_manager from src.chat.message_receive.message_sender import message_manager
@@ -194,31 +195,31 @@ class NormalChat:
通常由start_monitoring_interest()启动 通常由start_monitoring_interest()启动
""" """
while True: while True:
await asyncio.sleep(0.5) # 每秒检查一次 async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
# 检查任务是否已被取消 await asyncio.sleep(0.5) # 每秒检查一次
if self._chat_task is None or self._chat_task.cancelled(): # 检查任务是否已被取消
logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出") if self._chat_task is None or self._chat_task.cancelled():
break logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出")
break
items_to_process = list(self.interest_dict.items())
if not items_to_process:
continue
items_to_process = list(self.interest_dict.items()) # 处理每条兴趣消息
if not items_to_process: for msg_id, (message, interest_value, is_mentioned) in items_to_process:
continue try:
# 处理消息
# 处理每条兴趣消息 await self.normal_response(
for msg_id, (message, interest_value, is_mentioned) in items_to_process: message=message,
try: is_mentioned=is_mentioned,
# 处理消息 interested_rate=interest_value,
await self.normal_response( rewind_response=False,
message=message, )
is_mentioned=is_mentioned, except Exception as e:
interested_rate=interest_value, logger.error(f"[{self.stream_name}] 处理兴趣消息{msg_id}时出错: {e}\n{traceback.format_exc()}")
rewind_response=False, finally:
) self.interest_dict.pop(msg_id, None)
except Exception as e:
logger.error(f"[{self.stream_name}] 处理兴趣消息{msg_id}时出错: {e}\n{traceback.format_exc()}")
finally:
self.interest_dict.pop(msg_id, None)
# 改为实例方法, 移除 chat 参数 # 改为实例方法, 移除 chat 参数
async def normal_response( async def normal_response(

View File

@@ -2,6 +2,7 @@ from typing import Dict, Any, Optional, List, Union
import re import re
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import asyncio import asyncio
import contextvars
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
# import traceback # import traceback
@@ -15,29 +16,59 @@ logger = get_module_logger("prompt_build")
class PromptContext: class PromptContext:
def __init__(self): def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._current_context: Optional[str] = None # 使用contextvars创建协程上下文变量
self._context_lock = asyncio.Lock() # 添加异步锁 self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock() # 保留锁用于其他操作
@property
def _current_context(self) -> Optional[str]:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value)
@asynccontextmanager @asynccontextmanager
async def async_scope(self, context_id: str): async def async_scope(self, context_id: Optional[str] = None):
"""创建一个异步的临时提示模板作用域""" """创建一个异步的临时提示模板作用域"""
async with self._context_lock: # 保存当前上下文并设置新上下文
if context_id not in self._context_prompts: if context_id is not None:
self._context_prompts[context_id] = {} async with self._context_lock:
if context_id not in self._context_prompts:
self._context_prompts[context_id] = {}
# 保存当前协程的上下文值,不影响其他协程
previous_context = self._current_context previous_context = self._current_context
self._current_context = context_id # 设置当前协程的新上下文
token = self._current_context_var.set(context_id)
else:
# 如果没有提供新上下文,保持当前上下文不变
previous_context = self._current_context
token = None
try: try:
yield self yield self
finally: finally:
async with self._context_lock: # 恢复之前的上下文
self._current_context = previous_context if context_id is not None:
if token:
self._current_context_var.reset(token)
else:
self._current_context = previous_context
async def get_prompt_async(self, name: str) -> Optional["Prompt"]: async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板""" """异步获取当前作用域中的提示模板"""
async with self._context_lock: async with self._context_lock:
if self._current_context and name in self._context_prompts[self._current_context]: current_context = self._current_context
return self._context_prompts[self._current_context][name] logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
if (
current_context
and current_context in self._context_prompts
and name in self._context_prompts[current_context]
):
return self._context_prompts[current_context][name]
return None return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None: async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
@@ -56,8 +87,8 @@ class PromptManager:
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
@asynccontextmanager @asynccontextmanager
async def async_message_scope(self, message_id: str): async def async_message_scope(self, message_id: Optional[str] = None):
"""为消息处理创建异步临时作用域""" """为消息处理创建异步临时作用域,支持 message_id 为 None 的情况"""
async with self._context.async_scope(message_id): async with self._context.async_scope(message_id):
yield self yield self
@@ -65,9 +96,11 @@ class PromptManager:
# 首先尝试从当前上下文获取 # 首先尝试从当前上下文获取
context_prompt = await self._context.get_prompt_async(name) context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None: if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt return context_prompt
# 如果上下文中不存在,则使用全局提示模板 # 如果上下文中不存在,则使用全局提示模板
async with self._lock: async with self._lock:
logger.debug(f"从全局获取提示词: {name}")
if name not in self._prompts: if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found") raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name] return self._prompts[name]