Files
Mofox-Core/src/plugins/utils/prompt_builder.py
2025-04-11 19:59:33 +08:00

138 lines
4.6 KiB
Python

import re
import ast
from typing import Dict, Any, Optional, List, Union
class PromptManager:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._prompts = {}
cls._instance._counter = 0
return cls._instance
def generate_name(self, template: str) -> str:
"""为未命名的prompt生成名称"""
self._counter += 1
return f"prompt_{self._counter}"
def register(self, prompt: "Prompt") -> None:
"""注册一个prompt"""
if not prompt.name:
prompt.name = self.generate_name(prompt.template)
self._prompts[prompt.name] = prompt
def add_prompt(self, name: str, fstr: str) -> "Prompt":
prompt = Prompt(fstr, name=name)
self._prompts[prompt.name] = prompt
return prompt
def get_prompt(self, name: str) -> "Prompt":
if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]
def format_prompt(self, name: str, **kwargs) -> str:
prompt = self.get_prompt(name)
return prompt.format(**kwargs)
# 全局单例
global_prompt_manager = PromptManager()
class Prompt(str):
def __new__(cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs):
# 如果传入的是元组,转换为列表
if isinstance(args, tuple):
args = list(args)
# 解析模板
tree = ast.parse(f"f'''{fstr}'''", mode="eval")
template_args = set()
for node in ast.walk(tree):
if isinstance(node, ast.FormattedValue):
expr = ast.get_source_segment(fstr, node.value)
if expr:
template_args.add(expr)
# 如果提供了初始参数,立即格式化
if kwargs or args:
formatted = cls._format_template(fstr, args=args, kwargs=kwargs)
obj = super().__new__(cls, formatted)
else:
obj = super().__new__(cls, "")
obj.template = fstr
obj.name = name
obj.args = template_args
obj._args = args or []
obj._kwargs = kwargs
# 自动注册到全局管理器
global_prompt_manager.register(obj)
return obj
@classmethod
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
fmt_str = f"f'''{template}'''"
tree = ast.parse(fmt_str, mode="eval")
template_args = []
for node in ast.walk(tree):
if isinstance(node, ast.FormattedValue):
expr = ast.get_source_segment(fmt_str, node.value)
if expr and expr not in template_args:
template_args.append(expr)
formatted_args = {}
formatted_kwargs = {}
# 处理位置参数
if args:
for i in range(len(args)):
arg = args[i]
if isinstance(arg, Prompt):
formatted_args[template_args[i]] = arg.format(**kwargs)
else:
formatted_args[template_args[i]] = arg
# 处理关键字参数
if kwargs:
for key, value in kwargs.items():
if isinstance(value, Prompt):
remaining_kwargs = {k: v for k, v in kwargs.items() if k != key}
formatted_kwargs[key] = value.format(**remaining_kwargs)
else:
formatted_kwargs[key] = value
try:
# 先用位置参数格式化
if args:
template = template.format(**formatted_args)
# 再用关键字参数格式化
if kwargs:
template = template.format(**formatted_kwargs)
return template
except (IndexError, KeyError) as e:
raise ValueError(
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)} keys:{template_args}"
) from e
def format(self, *args, **kwargs) -> "Prompt":
"""支持位置参数和关键字参数的格式化,使用"""
ret = type(self)(
self.template, self.name, args=list(args) if args else self._args, **kwargs if kwargs else self._kwargs
)
print(f"prompt build result: {ret} name: {ret.name} ")
return ret
def __str__(self) -> str:
if self._kwargs or self._args:
return super().__str__()
return self.template
def __repr__(self) -> str:
return f"Prompt(template='{self.template}', name='{self.name}')"