fix: 恢复template_info功能

This commit is contained in:
tcmofashi
2025-05-23 11:04:49 +08:00
parent 75eeea8d92
commit ff9efb1c5e
8 changed files with 149 additions and 53 deletions

View File

@@ -2,6 +2,7 @@ from typing import Dict, Any, Optional, List, Union
import re
from contextlib import asynccontextmanager
import asyncio
import contextvars
from src.common.logger import get_module_logger
# import traceback
@@ -15,29 +16,59 @@ logger = get_module_logger("prompt_build")
class PromptContext:
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._current_context: Optional[str] = None
self._context_lock = asyncio.Lock() # 添加异步锁
# 使用contextvars创建协程上下文变量
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock() # 保留锁用于其他操作
@property
def _current_context(self) -> Optional[str]:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value)
@asynccontextmanager
async def async_scope(self, context_id: str):
async def async_scope(self, context_id: Optional[str] = None):
"""创建一个异步的临时提示模板作用域"""
async with self._context_lock:
if context_id not in self._context_prompts:
self._context_prompts[context_id] = {}
# 保存当前上下文并设置新上下文
if context_id is not None:
async with self._context_lock:
if context_id not in self._context_prompts:
self._context_prompts[context_id] = {}
# 保存当前协程的上下文值,不影响其他协程
previous_context = self._current_context
self._current_context = context_id
# 设置当前协程的新上下文
token = self._current_context_var.set(context_id)
else:
# 如果没有提供新上下文,保持当前上下文不变
previous_context = self._current_context
token = None
try:
yield self
finally:
async with self._context_lock:
self._current_context = previous_context
# 恢复之前的上下文
if context_id is not None:
if token:
self._current_context_var.reset(token)
else:
self._current_context = previous_context
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板"""
async with self._context_lock:
if self._current_context and name in self._context_prompts[self._current_context]:
return self._context_prompts[self._current_context][name]
current_context = self._current_context
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
if (
current_context
and current_context in self._context_prompts
and name in self._context_prompts[current_context]
):
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
@@ -56,8 +87,8 @@ class PromptManager:
self._lock = asyncio.Lock()
@asynccontextmanager
async def async_message_scope(self, message_id: str):
"""为消息处理创建异步临时作用域"""
async def async_message_scope(self, message_id: Optional[str] = None):
"""为消息处理创建异步临时作用域,支持 message_id 为 None 的情况"""
async with self._context.async_scope(message_id):
yield self
@@ -65,9 +96,11 @@ class PromptManager:
# 首先尝试从当前上下文获取
context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt
# 如果上下文中不存在,则使用全局提示模板
async with self._lock:
logger.debug(f"从全局获取提示词: {name}")
if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]