fix: 恢复template_info功能
This commit is contained in:
@@ -2,6 +2,7 @@ from typing import Dict, Any, Optional, List, Union
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
import contextvars
|
||||
from src.common.logger import get_module_logger
|
||||
|
||||
# import traceback
|
||||
@@ -15,29 +16,59 @@ logger = get_module_logger("prompt_build")
|
||||
class PromptContext:
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
self._current_context: Optional[str] = None
|
||||
self._context_lock = asyncio.Lock() # 添加异步锁
|
||||
# 使用contextvars创建协程上下文变量
|
||||
self._current_context_var = contextvars.ContextVar("current_context", default=None)
|
||||
self._context_lock = asyncio.Lock() # 保留锁用于其他操作
|
||||
|
||||
@property
|
||||
def _current_context(self) -> Optional[str]:
|
||||
"""获取当前协程的上下文ID"""
|
||||
return self._current_context_var.get()
|
||||
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value)
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: str):
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
async with self._context_lock:
|
||||
if context_id not in self._context_prompts:
|
||||
self._context_prompts[context_id] = {}
|
||||
# 保存当前上下文并设置新上下文
|
||||
if context_id is not None:
|
||||
async with self._context_lock:
|
||||
if context_id not in self._context_prompts:
|
||||
self._context_prompts[context_id] = {}
|
||||
|
||||
# 保存当前协程的上下文值,不影响其他协程
|
||||
previous_context = self._current_context
|
||||
self._current_context = context_id
|
||||
# 设置当前协程的新上下文
|
||||
token = self._current_context_var.set(context_id)
|
||||
else:
|
||||
# 如果没有提供新上下文,保持当前上下文不变
|
||||
previous_context = self._current_context
|
||||
token = None
|
||||
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
async with self._context_lock:
|
||||
self._current_context = previous_context
|
||||
# 恢复之前的上下文
|
||||
if context_id is not None:
|
||||
if token:
|
||||
self._current_context_var.reset(token)
|
||||
else:
|
||||
self._current_context = previous_context
|
||||
|
||||
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
||||
"""异步获取当前作用域中的提示模板"""
|
||||
async with self._context_lock:
|
||||
if self._current_context and name in self._context_prompts[self._current_context]:
|
||||
return self._context_prompts[self._current_context][name]
|
||||
current_context = self._current_context
|
||||
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
|
||||
if (
|
||||
current_context
|
||||
and current_context in self._context_prompts
|
||||
and name in self._context_prompts[current_context]
|
||||
):
|
||||
return self._context_prompts[current_context][name]
|
||||
return None
|
||||
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
@@ -56,8 +87,8 @@ class PromptManager:
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: str):
|
||||
"""为消息处理创建异步临时作用域"""
|
||||
async def async_message_scope(self, message_id: Optional[str] = None):
|
||||
"""为消息处理创建异步临时作用域,支持 message_id 为 None 的情况"""
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
|
||||
@@ -65,9 +96,11 @@ class PromptManager:
|
||||
# 首先尝试从当前上下文获取
|
||||
context_prompt = await self._context.get_prompt_async(name)
|
||||
if context_prompt is not None:
|
||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||
return context_prompt
|
||||
# 如果上下文中不存在,则使用全局提示模板
|
||||
async with self._lock:
|
||||
logger.debug(f"从全局获取提示词: {name}")
|
||||
if name not in self._prompts:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
Reference in New Issue
Block a user