fix: 恢复template_info功能
This commit is contained in:
@@ -7,6 +7,7 @@ from typing import List, Optional, Dict, Any, Deque, Callable, Coroutine
|
|||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.chat.message_receive.chat_stream import chat_manager
|
from src.chat.message_receive.chat_stream import chat_manager
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
from src.chat.heart_flow.observation.observation import Observation
|
from src.chat.heart_flow.observation.observation import Observation
|
||||||
@@ -228,7 +229,8 @@ class HeartFChatting:
|
|||||||
thinking_id = "tid" + str(round(time.time(), 2))
|
thinking_id = "tid" + str(round(time.time(), 2))
|
||||||
self._current_cycle.set_thinking_id(thinking_id)
|
self._current_cycle.set_thinking_id(thinking_id)
|
||||||
# 主循环:思考->决策->执行
|
# 主循环:思考->决策->执行
|
||||||
|
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||||
|
logger.debug(f"模板 {self.chat_stream.context.get_template_name()}")
|
||||||
loop_info = await self._observe_process_plan_action_loop(cycle_timers, thinking_id)
|
loop_info = await self._observe_process_plan_action_loop(cycle_timers, thinking_id)
|
||||||
|
|
||||||
self._current_cycle.set_loop_info(loop_info)
|
self._current_cycle.set_loop_info(loop_info)
|
||||||
|
|||||||
@@ -125,7 +125,6 @@ class PromptBuilder:
|
|||||||
relation_prompt += await relationship_manager.build_relationship_info(person)
|
relation_prompt += await relationship_manager.build_relationship_info(person)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Invalid person tuple encountered for relationship prompt: {person}")
|
logger.warning(f"Invalid person tuple encountered for relationship prompt: {person}")
|
||||||
|
|
||||||
mood_prompt = mood_manager.get_mood_prompt()
|
mood_prompt = mood_manager.get_mood_prompt()
|
||||||
reply_styles1 = [
|
reply_styles1 = [
|
||||||
("然后给出日常且口语化的回复,平淡一些", 0.4),
|
("然后给出日常且口语化的回复,平淡一些", 0.4),
|
||||||
@@ -146,9 +145,11 @@ class PromptBuilder:
|
|||||||
[style[0] for style in reply_styles2], weights=[style[1] for style in reply_styles2], k=1
|
[style[0] for style in reply_styles2], weights=[style[1] for style in reply_styles2], k=1
|
||||||
)[0]
|
)[0]
|
||||||
memory_prompt = ""
|
memory_prompt = ""
|
||||||
|
|
||||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||||
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||||
)
|
)
|
||||||
|
|
||||||
related_memory_info = ""
|
related_memory_info = ""
|
||||||
if related_memory:
|
if related_memory:
|
||||||
for memory in related_memory:
|
for memory in related_memory:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
|
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream, chat_manager
|
||||||
from src.chat.message_receive.message import UserInfo
|
from src.chat.message_receive.message import UserInfo
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
import json
|
import json
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ class ChatBot:
|
|||||||
message = MessageRecv(message_data)
|
message = MessageRecv(message_data)
|
||||||
group_info = message.message_info.group_info
|
group_info = message.message_info.group_info
|
||||||
user_info = message.message_info.user_info
|
user_info = message.message_info.user_info
|
||||||
|
chat_manager.register_message(message)
|
||||||
|
|
||||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||||
@@ -86,7 +87,7 @@ class ChatBot:
|
|||||||
if isinstance(template_items, dict):
|
if isinstance(template_items, dict):
|
||||||
for k in template_items.keys():
|
for k in template_items.keys():
|
||||||
await Prompt.create_async(template_items[k], k)
|
await Prompt.create_async(template_items[k], k)
|
||||||
print(f"注册{template_items[k]},{k}")
|
logger.debug(f"注册{template_items[k]},{k}")
|
||||||
else:
|
else:
|
||||||
template_group_name = None
|
template_group_name = None
|
||||||
|
|
||||||
|
|||||||
@@ -2,13 +2,17 @@ import asyncio
|
|||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
import copy
|
import copy
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
from ...common.database.database import db
|
from ...common.database.database import db
|
||||||
from ...common.database.database_model import ChatStreams # 新增导入
|
from ...common.database.database_model import ChatStreams # 新增导入
|
||||||
from maim_message import GroupInfo, UserInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
|
|
||||||
|
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .message import MessageRecv
|
||||||
|
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
@@ -18,6 +22,23 @@ install(extra_lines=3)
|
|||||||
logger = get_logger("chat_stream")
|
logger = get_logger("chat_stream")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessageContext:
|
||||||
|
"""聊天消息上下文,存储消息的上下文信息"""
|
||||||
|
|
||||||
|
def __init__(self, message: "MessageRecv"):
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def get_template_name(self) -> str:
|
||||||
|
"""获取模板名称"""
|
||||||
|
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
||||||
|
return self.message.message_info.template_info.template_name
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_last_message(self) -> "MessageRecv":
|
||||||
|
"""获取最后一条消息"""
|
||||||
|
return self.message
|
||||||
|
|
||||||
|
|
||||||
class ChatStream:
|
class ChatStream:
|
||||||
"""聊天流对象,存储一个完整的聊天上下文"""
|
"""聊天流对象,存储一个完整的聊天上下文"""
|
||||||
|
|
||||||
@@ -36,6 +57,7 @@ class ChatStream:
|
|||||||
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
||||||
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
||||||
self.saved = False
|
self.saved = False
|
||||||
|
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
@@ -67,6 +89,10 @@ class ChatStream:
|
|||||||
self.last_active_time = time.time()
|
self.last_active_time = time.time()
|
||||||
self.saved = False
|
self.saved = False
|
||||||
|
|
||||||
|
def set_context(self, message: "MessageRecv"):
|
||||||
|
"""设置聊天消息上下文"""
|
||||||
|
self.context = ChatMessageContext(message)
|
||||||
|
|
||||||
|
|
||||||
class ChatManager:
|
class ChatManager:
|
||||||
"""聊天管理器,管理所有聊天流"""
|
"""聊天管理器,管理所有聊天流"""
|
||||||
@@ -82,6 +108,7 @@ class ChatManager:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||||
|
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||||
try:
|
try:
|
||||||
db.connect(reuse_if_open=True)
|
db.connect(reuse_if_open=True)
|
||||||
# 确保 ChatStreams 表存在
|
# 确保 ChatStreams 表存在
|
||||||
@@ -113,6 +140,16 @@ class ChatManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||||
|
|
||||||
|
def register_message(self, message: "MessageRecv"):
|
||||||
|
"""注册消息到聊天流"""
|
||||||
|
stream_id = self._generate_stream_id(
|
||||||
|
message.message_info.platform,
|
||||||
|
message.message_info.user_info,
|
||||||
|
message.message_info.group_info,
|
||||||
|
)
|
||||||
|
self.last_messages[stream_id] = message
|
||||||
|
logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
||||||
"""生成聊天流唯一ID"""
|
"""生成聊天流唯一ID"""
|
||||||
@@ -146,12 +183,19 @@ class ChatManager:
|
|||||||
# 检查内存中是否存在
|
# 检查内存中是否存在
|
||||||
if stream_id in self.streams:
|
if stream_id in self.streams:
|
||||||
stream = self.streams[stream_id]
|
stream = self.streams[stream_id]
|
||||||
|
|
||||||
# 更新用户信息和群组信息
|
# 更新用户信息和群组信息
|
||||||
stream.update_active_time()
|
stream.update_active_time()
|
||||||
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
|
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
|
||||||
stream.user_info = user_info
|
stream.user_info = user_info
|
||||||
if group_info:
|
if group_info:
|
||||||
stream.group_info = group_info
|
stream.group_info = group_info
|
||||||
|
from .message import MessageRecv # 延迟导入,避免循环引用
|
||||||
|
|
||||||
|
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
||||||
|
stream.set_context(self.last_messages[stream_id])
|
||||||
|
else:
|
||||||
|
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
# 检查数据库中是否存在
|
# 检查数据库中是否存在
|
||||||
@@ -202,14 +246,24 @@ class ChatManager:
|
|||||||
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
|
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
stream = copy.deepcopy(stream)
|
||||||
|
from .message import MessageRecv # 延迟导入,避免循环引用
|
||||||
|
|
||||||
|
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
||||||
|
stream.set_context(self.last_messages[stream_id])
|
||||||
|
else:
|
||||||
|
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
||||||
# 保存到内存和数据库
|
# 保存到内存和数据库
|
||||||
self.streams[stream_id] = stream
|
self.streams[stream_id] = stream
|
||||||
await self._save_stream(stream)
|
await self._save_stream(stream)
|
||||||
return copy.deepcopy(stream)
|
return stream
|
||||||
|
|
||||||
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
||||||
"""通过stream_id获取聊天流"""
|
"""通过stream_id获取聊天流"""
|
||||||
return self.streams.get(stream_id)
|
stream = self.streams.get(stream_id)
|
||||||
|
if stream_id in self.last_messages:
|
||||||
|
stream.set_context(self.last_messages[stream_id])
|
||||||
|
return stream
|
||||||
|
|
||||||
def get_stream_by_info(
|
def get_stream_by_info(
|
||||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||||
@@ -306,6 +360,8 @@ class ChatManager:
|
|||||||
stream = ChatStream.from_dict(data)
|
stream = ChatStream.from_dict(data)
|
||||||
stream.saved = True
|
stream.saved = True
|
||||||
self.streams[stream.stream_id] = stream
|
self.streams[stream.stream_id] = stream
|
||||||
|
if stream.stream_id in self.last_messages:
|
||||||
|
stream.set_context(self.last_messages[stream.stream_id])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
|
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any, TYPE_CHECKING
|
||||||
|
|
||||||
import urllib3
|
import urllib3
|
||||||
|
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
from ..utils.utils_image import image_manager
|
from ..utils.utils_image import image_manager
|
||||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||||
@@ -25,7 +27,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Message(MessageBase):
|
class Message(MessageBase):
|
||||||
chat_stream: ChatStream = None
|
chat_stream: "ChatStream" = None
|
||||||
reply: Optional["Message"] = None
|
reply: Optional["Message"] = None
|
||||||
detailed_plain_text: str = ""
|
detailed_plain_text: str = ""
|
||||||
processed_plain_text: str = ""
|
processed_plain_text: str = ""
|
||||||
@@ -34,7 +36,7 @@ class Message(MessageBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
chat_stream: ChatStream,
|
chat_stream: "ChatStream",
|
||||||
user_info: UserInfo,
|
user_info: UserInfo,
|
||||||
message_segment: Optional[Seg] = None,
|
message_segment: Optional[Seg] = None,
|
||||||
timestamp: Optional[float] = None,
|
timestamp: Optional[float] = None,
|
||||||
@@ -111,7 +113,7 @@ class MessageRecv(Message):
|
|||||||
self.detailed_plain_text = "" # 初始化为空字符串
|
self.detailed_plain_text = "" # 初始化为空字符串
|
||||||
self.is_emoji = False
|
self.is_emoji = False
|
||||||
|
|
||||||
def update_chat_stream(self, chat_stream: ChatStream):
|
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
|
|
||||||
async def process(self) -> None:
|
async def process(self) -> None:
|
||||||
@@ -165,7 +167,7 @@ class MessageProcessBase(Message):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
chat_stream: ChatStream,
|
chat_stream: "ChatStream",
|
||||||
bot_user_info: UserInfo,
|
bot_user_info: UserInfo,
|
||||||
message_segment: Optional[Seg] = None,
|
message_segment: Optional[Seg] = None,
|
||||||
reply: Optional["MessageRecv"] = None,
|
reply: Optional["MessageRecv"] = None,
|
||||||
@@ -241,7 +243,7 @@ class MessageThinking(MessageProcessBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
chat_stream: ChatStream,
|
chat_stream: "ChatStream",
|
||||||
bot_user_info: UserInfo,
|
bot_user_info: UserInfo,
|
||||||
reply: Optional["MessageRecv"] = None,
|
reply: Optional["MessageRecv"] = None,
|
||||||
thinking_start_time: float = 0,
|
thinking_start_time: float = 0,
|
||||||
@@ -269,7 +271,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
chat_stream: ChatStream,
|
chat_stream: "ChatStream",
|
||||||
bot_user_info: UserInfo,
|
bot_user_info: UserInfo,
|
||||||
sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复
|
sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复
|
||||||
message_segment: Seg,
|
message_segment: Seg,
|
||||||
@@ -353,7 +355,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
class MessageSet:
|
class MessageSet:
|
||||||
"""消息集合类,可以存储多个发送消息"""
|
"""消息集合类,可以存储多个发送消息"""
|
||||||
|
|
||||||
def __init__(self, chat_stream: ChatStream, message_id: str):
|
def __init__(self, chat_stream: "ChatStream", message_id: str):
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
self.message_id = message_id
|
self.message_id = message_id
|
||||||
self.messages: list[MessageSending] = []
|
self.messages: list[MessageSending] = []
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from src.chat.message_receive.chat_stream import ChatStream, chat_manager
|
|||||||
from src.chat.person_info.relationship_manager import relationship_manager
|
from src.chat.person_info.relationship_manager import relationship_manager
|
||||||
from src.chat.utils.info_catcher import info_catcher_manager
|
from src.chat.utils.info_catcher import info_catcher_manager
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
|
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||||
from .normal_chat_generator import NormalChatGenerator
|
from .normal_chat_generator import NormalChatGenerator
|
||||||
from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||||
from src.chat.message_receive.message_sender import message_manager
|
from src.chat.message_receive.message_sender import message_manager
|
||||||
@@ -194,13 +195,13 @@ class NormalChat:
|
|||||||
通常由start_monitoring_interest()启动
|
通常由start_monitoring_interest()启动
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
|
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||||
await asyncio.sleep(0.5) # 每秒检查一次
|
await asyncio.sleep(0.5) # 每秒检查一次
|
||||||
# 检查任务是否已被取消
|
# 检查任务是否已被取消
|
||||||
if self._chat_task is None or self._chat_task.cancelled():
|
if self._chat_task is None or self._chat_task.cancelled():
|
||||||
logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出")
|
logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
items_to_process = list(self.interest_dict.items())
|
items_to_process = list(self.interest_dict.items())
|
||||||
if not items_to_process:
|
if not items_to_process:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Dict, Any, Optional, List, Union
|
|||||||
import re
|
import re
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextvars
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
|
||||||
# import traceback
|
# import traceback
|
||||||
@@ -15,29 +16,59 @@ logger = get_module_logger("prompt_build")
|
|||||||
class PromptContext:
|
class PromptContext:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||||
self._current_context: Optional[str] = None
|
# 使用contextvars创建协程上下文变量
|
||||||
self._context_lock = asyncio.Lock() # 添加异步锁
|
self._current_context_var = contextvars.ContextVar("current_context", default=None)
|
||||||
|
self._context_lock = asyncio.Lock() # 保留锁用于其他操作
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _current_context(self) -> Optional[str]:
|
||||||
|
"""获取当前协程的上下文ID"""
|
||||||
|
return self._current_context_var.get()
|
||||||
|
|
||||||
|
@_current_context.setter
|
||||||
|
def _current_context(self, value: Optional[str]):
|
||||||
|
"""设置当前协程的上下文ID"""
|
||||||
|
self._current_context_var.set(value)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def async_scope(self, context_id: str):
|
async def async_scope(self, context_id: Optional[str] = None):
|
||||||
"""创建一个异步的临时提示模板作用域"""
|
"""创建一个异步的临时提示模板作用域"""
|
||||||
|
# 保存当前上下文并设置新上下文
|
||||||
|
if context_id is not None:
|
||||||
async with self._context_lock:
|
async with self._context_lock:
|
||||||
if context_id not in self._context_prompts:
|
if context_id not in self._context_prompts:
|
||||||
self._context_prompts[context_id] = {}
|
self._context_prompts[context_id] = {}
|
||||||
|
|
||||||
|
# 保存当前协程的上下文值,不影响其他协程
|
||||||
previous_context = self._current_context
|
previous_context = self._current_context
|
||||||
self._current_context = context_id
|
# 设置当前协程的新上下文
|
||||||
|
token = self._current_context_var.set(context_id)
|
||||||
|
else:
|
||||||
|
# 如果没有提供新上下文,保持当前上下文不变
|
||||||
|
previous_context = self._current_context
|
||||||
|
token = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self
|
yield self
|
||||||
finally:
|
finally:
|
||||||
async with self._context_lock:
|
# 恢复之前的上下文
|
||||||
|
if context_id is not None:
|
||||||
|
if token:
|
||||||
|
self._current_context_var.reset(token)
|
||||||
|
else:
|
||||||
self._current_context = previous_context
|
self._current_context = previous_context
|
||||||
|
|
||||||
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
||||||
"""异步获取当前作用域中的提示模板"""
|
"""异步获取当前作用域中的提示模板"""
|
||||||
async with self._context_lock:
|
async with self._context_lock:
|
||||||
if self._current_context and name in self._context_prompts[self._current_context]:
|
current_context = self._current_context
|
||||||
return self._context_prompts[self._current_context][name]
|
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
|
||||||
|
if (
|
||||||
|
current_context
|
||||||
|
and current_context in self._context_prompts
|
||||||
|
and name in self._context_prompts[current_context]
|
||||||
|
):
|
||||||
|
return self._context_prompts[current_context][name]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||||
@@ -56,8 +87,8 @@ class PromptManager:
|
|||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def async_message_scope(self, message_id: str):
|
async def async_message_scope(self, message_id: Optional[str] = None):
|
||||||
"""为消息处理创建异步临时作用域"""
|
"""为消息处理创建异步临时作用域,支持 message_id 为 None 的情况"""
|
||||||
async with self._context.async_scope(message_id):
|
async with self._context.async_scope(message_id):
|
||||||
yield self
|
yield self
|
||||||
|
|
||||||
@@ -65,9 +96,11 @@ class PromptManager:
|
|||||||
# 首先尝试从当前上下文获取
|
# 首先尝试从当前上下文获取
|
||||||
context_prompt = await self._context.get_prompt_async(name)
|
context_prompt = await self._context.get_prompt_async(name)
|
||||||
if context_prompt is not None:
|
if context_prompt is not None:
|
||||||
|
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||||
return context_prompt
|
return context_prompt
|
||||||
# 如果上下文中不存在,则使用全局提示模板
|
# 如果上下文中不存在,则使用全局提示模板
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
|
logger.debug(f"从全局获取提示词: {name}")
|
||||||
if name not in self._prompts:
|
if name not in self._prompts:
|
||||||
raise KeyError(f"Prompt '{name}' not found")
|
raise KeyError(f"Prompt '{name}' not found")
|
||||||
return self._prompts[name]
|
return self._prompts[name]
|
||||||
|
|||||||
Reference in New Issue
Block a user