fix: 恢复template_info功能

This commit is contained in:
tcmofashi
2025-05-23 11:04:49 +08:00
parent 75eeea8d92
commit ff9efb1c5e
8 changed files with 149 additions and 53 deletions

View File

@@ -77,6 +77,7 @@ class ChatBot:
message = MessageRecv(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
chat_manager.register_message(message)
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:
@@ -86,7 +87,7 @@ class ChatBot:
if isinstance(template_items, dict):
for k in template_items.keys():
await Prompt.create_async(template_items[k], k)
print(f"注册{template_items[k]},{k}")
logger.debug(f"注册{template_items[k]},{k}")
else:
template_group_name = None

View File

@@ -2,13 +2,17 @@ import asyncio
import hashlib
import time
import copy
from typing import Dict, Optional
from typing import Dict, Optional, TYPE_CHECKING
from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入
from maim_message import GroupInfo, UserInfo
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
from src.common.logger_manager import get_logger
from rich.traceback import install
@@ -18,6 +22,23 @@ install(extra_lines=3)
logger = get_logger("chat_stream")
class ChatMessageContext:
"""聊天消息上下文,存储消息的上下文信息"""
def __init__(self, message: "MessageRecv"):
self.message = message
def get_template_name(self) -> str:
"""获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name
return None
def get_last_message(self) -> "MessageRecv":
"""获取最后一条消息"""
return self.message
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
@@ -36,6 +57,7 @@ class ChatStream:
self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息
def to_dict(self) -> dict:
"""转换为字典格式"""
@@ -67,6 +89,10 @@ class ChatStream:
self.last_active_time = time.time()
self.saved = False
def set_context(self, message: "MessageRecv"):
"""设置聊天消息上下文"""
self.context = ChatMessageContext(message)
class ChatManager:
"""聊天管理器,管理所有聊天流"""
@@ -82,6 +108,7 @@ class ChatManager:
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
@@ -113,6 +140,16 @@ class ChatManager:
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
stream_id = self._generate_stream_id(
message.message_info.platform,
message.message_info.user_info,
message.message_info.group_info,
)
self.last_messages[stream_id] = message
logger.debug(f"注册消息到聊天流: {stream_id}")
@staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID"""
@@ -146,12 +183,19 @@ class ChatManager:
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.update_active_time()
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
stream.user_info = user_info
if group_info:
stream.group_info = group_info
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
return stream
# 检查数据库中是否存在
@@ -202,14 +246,24 @@ class ChatManager:
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e
stream = copy.deepcopy(stream)
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
return copy.deepcopy(stream)
return stream
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流"""
return self.streams.get(stream_id)
stream = self.streams.get(stream_id)
if stream_id in self.last_messages:
stream.set_context(self.last_messages[stream_id])
return stream
def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
@@ -306,6 +360,8 @@ class ChatManager:
stream = ChatStream.from_dict(data)
stream.saved = True
self.streams[stream.stream_id] = stream
if stream.stream_id in self.last_messages:
stream.set_context(self.last_messages[stream.stream_id])
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)

View File

@@ -1,12 +1,14 @@
import time
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional, Any
from typing import Optional, Any, TYPE_CHECKING
import urllib3
from src.common.logger_manager import get_logger
from .chat_stream import ChatStream
if TYPE_CHECKING:
from .chat_stream import ChatStream
from ..utils.utils_image import image_manager
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from rich.traceback import install
@@ -25,7 +27,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@dataclass
class Message(MessageBase):
chat_stream: ChatStream = None
chat_stream: "ChatStream" = None
reply: Optional["Message"] = None
detailed_plain_text: str = ""
processed_plain_text: str = ""
@@ -34,7 +36,7 @@ class Message(MessageBase):
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
chat_stream: "ChatStream",
user_info: UserInfo,
message_segment: Optional[Seg] = None,
timestamp: Optional[float] = None,
@@ -111,7 +113,7 @@ class MessageRecv(Message):
self.detailed_plain_text = "" # 初始化为空字符串
self.is_emoji = False
def update_chat_stream(self, chat_stream: ChatStream):
def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream
async def process(self) -> None:
@@ -165,7 +167,7 @@ class MessageProcessBase(Message):
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional["MessageRecv"] = None,
@@ -241,7 +243,7 @@ class MessageThinking(MessageProcessBase):
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
reply: Optional["MessageRecv"] = None,
thinking_start_time: float = 0,
@@ -269,7 +271,7 @@ class MessageSending(MessageProcessBase):
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复
message_segment: Seg,
@@ -353,7 +355,7 @@ class MessageSending(MessageProcessBase):
class MessageSet:
"""消息集合类,可以存储多个发送消息"""
def __init__(self, chat_stream: ChatStream, message_id: str):
def __init__(self, chat_stream: "ChatStream", message_id: str):
self.chat_stream = chat_stream
self.message_id = message_id
self.messages: list[MessageSending] = []