feat: 支持外部message重载prompt template
This commit is contained in:
@@ -2,16 +2,69 @@
|
||||
import ast
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional, Dict
|
||||
import asyncio
|
||||
|
||||
class PromptContext:
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
self._current_context: Optional[str] = None
|
||||
self._context_lock = asyncio.Lock() # 添加异步锁
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: str):
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
async with self._context_lock:
|
||||
if context_id not in self._context_prompts:
|
||||
self._context_prompts[context_id] = {}
|
||||
|
||||
previous_context = self._current_context
|
||||
self._current_context = context_id
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
async with self._context_lock:
|
||||
self._current_context = previous_context
|
||||
|
||||
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
||||
"""异步获取当前作用域中的提示模板"""
|
||||
async with self._context_lock:
|
||||
if self._current_context and name in self._context_prompts[self._current_context]:
|
||||
return self._context_prompts[self._current_context][name]
|
||||
return None
|
||||
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
target_context = context_id or self._current_context
|
||||
if target_context:
|
||||
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
|
||||
|
||||
|
||||
class PromptManager:
|
||||
_instance = None
|
||||
def __init__(self):
|
||||
self._prompts = {}
|
||||
self._counter = 0
|
||||
self._context = PromptContext()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._prompts = {}
|
||||
cls._instance._counter = 0
|
||||
return cls._instance
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: str):
|
||||
"""为消息处理创建异步临时作用域"""
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
|
||||
async def get_prompt_async(self, name: str) -> "Prompt":
|
||||
# 首先尝试从当前上下文获取
|
||||
context_prompt = await self._context.get_prompt_async(name)
|
||||
if context_prompt is not None:
|
||||
return context_prompt
|
||||
# 如果上下文中不存在,则使用全局提示模板
|
||||
async with self._lock:
|
||||
if name not in self._prompts:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
def generate_name(self, template: str) -> str:
|
||||
"""为未命名的prompt生成名称"""
|
||||
@@ -29,13 +82,8 @@ class PromptManager:
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
def get_prompt(self, name: str) -> "Prompt":
|
||||
if name not in self._prompts:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
def format_prompt(self, name: str, **kwargs) -> str:
|
||||
prompt = self.get_prompt(name)
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
prompt = await self.get_prompt_async(name)
|
||||
return prompt.format(**kwargs)
|
||||
|
||||
|
||||
@@ -71,10 +119,24 @@ class Prompt(str):
|
||||
obj._args = args or []
|
||||
obj._kwargs = kwargs
|
||||
|
||||
# 自动注册到全局管理器
|
||||
global_prompt_manager.register(obj)
|
||||
# 修改自动注册逻辑
|
||||
if global_prompt_manager._context._current_context:
|
||||
# 如果存在当前上下文,则注册到上下文中
|
||||
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
|
||||
pass
|
||||
else:
|
||||
# 否则注册到全局管理器
|
||||
global_prompt_manager.register(obj)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
async def create_async(cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs):
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = cls(fstr, name, args, **kwargs)
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
|
||||
fmt_str = f"f'''{template}'''"
|
||||
|
||||
Reference in New Issue
Block a user