fix: 恢复template_info功能

This commit is contained in:
tcmofashi
2025-05-23 11:04:49 +08:00
parent 75eeea8d92
commit ff9efb1c5e
8 changed files with 149 additions and 53 deletions

View File

@@ -7,6 +7,7 @@ from typing import List, Optional, Dict, Any, Deque, Callable, Coroutine
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_stream import chat_manager
from rich.traceback import install
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.common.logger_manager import get_logger
from src.chat.utils.timer_calculator import Timer
from src.chat.heart_flow.observation.observation import Observation
@@ -228,7 +229,8 @@ class HeartFChatting:
thinking_id = "tid" + str(round(time.time(), 2))
self._current_cycle.set_thinking_id(thinking_id)
# 主循环:思考->决策->执行
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
logger.debug(f"模板 {self.chat_stream.context.get_template_name()}")
loop_info = await self._observe_process_plan_action_loop(cycle_timers, thinking_id)
self._current_cycle.set_loop_info(loop_info)

View File

@@ -125,7 +125,6 @@ class PromptBuilder:
relation_prompt += await relationship_manager.build_relationship_info(person)
else:
logger.warning(f"Invalid person tuple encountered for relationship prompt: {person}")
mood_prompt = mood_manager.get_mood_prompt()
reply_styles1 = [
("然后给出日常且口语化的回复,平淡一些", 0.4),
@@ -146,9 +145,11 @@ class PromptBuilder:
[style[0] for style in reply_styles2], weights=[style[1] for style in reply_styles2], k=1
)[0]
memory_prompt = ""
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
)
related_memory_info = ""
if related_memory:
for memory in related_memory:

View File

@@ -1,7 +1,7 @@
import time
from typing import Optional
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_stream import ChatStream, chat_manager
from src.chat.message_receive.message import UserInfo
from src.common.logger_manager import get_logger
import json

View File

@@ -77,6 +77,7 @@ class ChatBot:
message = MessageRecv(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
chat_manager.register_message(message)
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:
@@ -86,7 +87,7 @@ class ChatBot:
if isinstance(template_items, dict):
for k in template_items.keys():
await Prompt.create_async(template_items[k], k)
print(f"注册{template_items[k]},{k}")
logger.debug(f"注册{template_items[k]},{k}")
else:
template_group_name = None

View File

@@ -2,13 +2,17 @@ import asyncio
import hashlib
import time
import copy
from typing import Dict, Optional
from typing import Dict, Optional, TYPE_CHECKING
from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入
from maim_message import GroupInfo, UserInfo
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
from src.common.logger_manager import get_logger
from rich.traceback import install
@@ -18,6 +22,23 @@ install(extra_lines=3)
logger = get_logger("chat_stream")
class ChatMessageContext:
"""聊天消息上下文,存储消息的上下文信息"""
def __init__(self, message: "MessageRecv"):
self.message = message
def get_template_name(self) -> str:
"""获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name
return None
def get_last_message(self) -> "MessageRecv":
"""获取最后一条消息"""
return self.message
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
@@ -36,6 +57,7 @@ class ChatStream:
self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息
def to_dict(self) -> dict:
"""转换为字典格式"""
@@ -67,6 +89,10 @@ class ChatStream:
self.last_active_time = time.time()
self.saved = False
def set_context(self, message: "MessageRecv"):
"""设置聊天消息上下文"""
self.context = ChatMessageContext(message)
class ChatManager:
"""聊天管理器,管理所有聊天流"""
@@ -82,6 +108,7 @@ class ChatManager:
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
@@ -113,6 +140,16 @@ class ChatManager:
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
stream_id = self._generate_stream_id(
message.message_info.platform,
message.message_info.user_info,
message.message_info.group_info,
)
self.last_messages[stream_id] = message
logger.debug(f"注册消息到聊天流: {stream_id}")
@staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID"""
@@ -146,12 +183,19 @@ class ChatManager:
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.update_active_time()
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
stream.user_info = user_info
if group_info:
stream.group_info = group_info
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
return stream
# 检查数据库中是否存在
@@ -202,14 +246,24 @@ class ChatManager:
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e
stream = copy.deepcopy(stream)
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
return copy.deepcopy(stream)
return stream
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流"""
return self.streams.get(stream_id)
stream = self.streams.get(stream_id)
if stream_id in self.last_messages:
stream.set_context(self.last_messages[stream_id])
return stream
def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
@@ -306,6 +360,8 @@ class ChatManager:
stream = ChatStream.from_dict(data)
stream.saved = True
self.streams[stream.stream_id] = stream
if stream.stream_id in self.last_messages:
stream.set_context(self.last_messages[stream.stream_id])
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)

View File

@@ -1,11 +1,13 @@
import time
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional, Any
from typing import Optional, Any, TYPE_CHECKING
import urllib3
from src.common.logger_manager import get_logger
if TYPE_CHECKING:
from .chat_stream import ChatStream
from ..utils.utils_image import image_manager
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
@@ -25,7 +27,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@dataclass
class Message(MessageBase):
chat_stream: ChatStream = None
chat_stream: "ChatStream" = None
reply: Optional["Message"] = None
detailed_plain_text: str = ""
processed_plain_text: str = ""
@@ -34,7 +36,7 @@ class Message(MessageBase):
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
chat_stream: "ChatStream",
user_info: UserInfo,
message_segment: Optional[Seg] = None,
timestamp: Optional[float] = None,
@@ -111,7 +113,7 @@ class MessageRecv(Message):
self.detailed_plain_text = "" # 初始化为空字符串
self.is_emoji = False
def update_chat_stream(self, chat_stream: ChatStream):
def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream
async def process(self) -> None:
@@ -165,7 +167,7 @@ class MessageProcessBase(Message):
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional["MessageRecv"] = None,
@@ -241,7 +243,7 @@ class MessageThinking(MessageProcessBase):
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
reply: Optional["MessageRecv"] = None,
thinking_start_time: float = 0,
@@ -269,7 +271,7 @@ class MessageSending(MessageProcessBase):
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复
message_segment: Seg,
@@ -353,7 +355,7 @@ class MessageSending(MessageProcessBase):
class MessageSet:
"""消息集合类,可以存储多个发送消息"""
def __init__(self, chat_stream: ChatStream, message_id: str):
def __init__(self, chat_stream: "ChatStream", message_id: str):
self.chat_stream = chat_stream
self.message_id = message_id
self.messages: list[MessageSending] = []

View File

@@ -14,6 +14,7 @@ from src.chat.message_receive.chat_stream import ChatStream, chat_manager
from src.chat.person_info.relationship_manager import relationship_manager
from src.chat.utils.info_catcher import info_catcher_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.prompt_builder import global_prompt_manager
from .normal_chat_generator import NormalChatGenerator
from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
from src.chat.message_receive.message_sender import message_manager
@@ -194,13 +195,13 @@ class NormalChat:
通常由start_monitoring_interest()启动
"""
while True:
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
await asyncio.sleep(0.5) # 每秒检查一次
# 检查任务是否已被取消
if self._chat_task is None or self._chat_task.cancelled():
logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出")
break
items_to_process = list(self.interest_dict.items())
if not items_to_process:
continue

View File

@@ -2,6 +2,7 @@ from typing import Dict, Any, Optional, List, Union
import re
from contextlib import asynccontextmanager
import asyncio
import contextvars
from src.common.logger import get_module_logger
# import traceback
@@ -15,29 +16,59 @@ logger = get_module_logger("prompt_build")
class PromptContext:
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._current_context: Optional[str] = None
self._context_lock = asyncio.Lock() # 添加异步锁
# 使用contextvars创建协程上下文变量
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock() # 保留锁用于其他操作
@property
def _current_context(self) -> Optional[str]:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value)
@asynccontextmanager
async def async_scope(self, context_id: str):
async def async_scope(self, context_id: Optional[str] = None):
"""创建一个异步的临时提示模板作用域"""
# 保存当前上下文并设置新上下文
if context_id is not None:
async with self._context_lock:
if context_id not in self._context_prompts:
self._context_prompts[context_id] = {}
# 保存当前协程的上下文值,不影响其他协程
previous_context = self._current_context
self._current_context = context_id
# 设置当前协程的新上下文
token = self._current_context_var.set(context_id)
else:
# 如果没有提供新上下文,保持当前上下文不变
previous_context = self._current_context
token = None
try:
yield self
finally:
async with self._context_lock:
# 恢复之前的上下文
if context_id is not None:
if token:
self._current_context_var.reset(token)
else:
self._current_context = previous_context
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板"""
async with self._context_lock:
if self._current_context and name in self._context_prompts[self._current_context]:
return self._context_prompts[self._current_context][name]
current_context = self._current_context
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
if (
current_context
and current_context in self._context_prompts
and name in self._context_prompts[current_context]
):
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
@@ -56,8 +87,8 @@ class PromptManager:
self._lock = asyncio.Lock()
@asynccontextmanager
async def async_message_scope(self, message_id: str):
"""为消息处理创建异步临时作用域"""
async def async_message_scope(self, message_id: Optional[str] = None):
"""为消息处理创建异步临时作用域,支持 message_id 为 None 的情况"""
async with self._context.async_scope(message_id):
yield self
@@ -65,9 +96,11 @@ class PromptManager:
# 首先尝试从当前上下文获取
context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt
# 如果上下文中不存在,则使用全局提示模板
async with self._lock:
logger.debug(f"从全局获取提示词: {name}")
if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]