refactor: 优化自定义提示词转义实现方案,减少对原代码的修改
This commit is contained in:
@@ -2,10 +2,6 @@ from typing import Dict, Any, Optional, List, Union
|
|||||||
import re
|
import re
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import asyncio
|
import asyncio
|
||||||
from src.common.logger import get_module_logger
|
|
||||||
# import traceback
|
|
||||||
|
|
||||||
logger = get_module_logger("prompt_build")
|
|
||||||
|
|
||||||
|
|
||||||
class PromptContext:
|
class PromptContext:
|
||||||
@@ -100,14 +96,17 @@ class Prompt(str):
|
|||||||
args = list(args)
|
args = list(args)
|
||||||
should_register = kwargs.pop("_should_register", True)
|
should_register = kwargs.pop("_should_register", True)
|
||||||
|
|
||||||
# 预处理模板字符串,替换转义的花括号
|
# 预处理模板中的转义花括号
|
||||||
processed_fstr = fstr
|
processed_fstr = fstr
|
||||||
|
temp_left = "__ESCAPED_LEFT_BRACE__"
|
||||||
|
temp_right = "__ESCAPED_RIGHT_BRACE__"
|
||||||
|
processed_fstr = processed_fstr.replace("\\{", temp_left).replace("\\}", temp_right)
|
||||||
|
|
||||||
# 解析模板
|
# 解析模板
|
||||||
template_args = []
|
template_args = []
|
||||||
result = re.findall(r"\{(.*?)\}", processed_fstr)
|
result = re.findall(r"\{(.*?)\}", processed_fstr)
|
||||||
for expr in result:
|
for expr in result:
|
||||||
if expr and expr not in template_args and not cls._is_escaped(processed_fstr, expr):
|
if expr and expr not in template_args:
|
||||||
template_args.append(expr)
|
template_args.append(expr)
|
||||||
|
|
||||||
# 如果提供了初始参数,立即格式化
|
# 如果提供了初始参数,立即格式化
|
||||||
@@ -134,67 +133,6 @@ class Prompt(str):
|
|||||||
global_prompt_manager.register(obj)
|
global_prompt_manager.register(obj)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_escaped(s: str, expr: str) -> bool:
|
|
||||||
"""判断表达式是否被转义"""
|
|
||||||
pattern = r"\\{" + re.escape(expr) + r"}"
|
|
||||||
return bool(re.search(pattern, s))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _preprocess_template(template: str) -> tuple[str, dict]:
|
|
||||||
"""
|
|
||||||
预处理模板,将转义的花括号替换为唯一的临时标记
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template: 原始模板字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (处理后的模板, 占位符映射字典)
|
|
||||||
"""
|
|
||||||
placeholders = {}
|
|
||||||
counter = 0
|
|
||||||
processed = template
|
|
||||||
|
|
||||||
# 定义替换函数 - 用于生成唯一占位符
|
|
||||||
def create_placeholder(char_type):
|
|
||||||
nonlocal counter
|
|
||||||
placeholder = f"__ESC_{char_type}_{counter}__"
|
|
||||||
counter += 1
|
|
||||||
return placeholder
|
|
||||||
|
|
||||||
# 处理转义的左花括号 \{
|
|
||||||
left_brace_pattern = r"\\{"
|
|
||||||
while re.search(left_brace_pattern, processed):
|
|
||||||
placeholder = create_placeholder("LEFT_BRACE")
|
|
||||||
placeholders[placeholder] = "{"
|
|
||||||
processed = re.sub(left_brace_pattern, placeholder, processed, count=1)
|
|
||||||
|
|
||||||
# 处理转义的右花括号 \}
|
|
||||||
right_brace_pattern = r"\\}"
|
|
||||||
while re.search(right_brace_pattern, processed):
|
|
||||||
placeholder = create_placeholder("RIGHT_BRACE")
|
|
||||||
placeholders[placeholder] = "}"
|
|
||||||
processed = re.sub(right_brace_pattern, placeholder, processed, count=1)
|
|
||||||
|
|
||||||
return processed, placeholders
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _restore_template(template: str, placeholders: dict) -> str:
|
|
||||||
"""
|
|
||||||
还原预处理后的模板中的占位符为实际字符
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template: 处理后的模板字符串
|
|
||||||
placeholders: 占位符映射字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 还原后的字符串
|
|
||||||
"""
|
|
||||||
result = template
|
|
||||||
for placeholder, value in placeholders.items():
|
|
||||||
result = result.replace(placeholder, value)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_async(
|
async def create_async(
|
||||||
cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs
|
cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs
|
||||||
@@ -207,43 +145,29 @@ class Prompt(str):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
|
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
|
||||||
"""
|
# 预处理模板中的转义花括号
|
||||||
格式化模板字符串,同时处理转义的花括号
|
processed_template = template
|
||||||
|
# 临时替换转义的花括号
|
||||||
|
temp_left = "__ESCAPED_LEFT_BRACE__"
|
||||||
|
temp_right = "__ESCAPED_RIGHT_BRACE__"
|
||||||
|
processed_template = processed_template.replace("\\{", temp_left).replace("\\}", temp_right)
|
||||||
|
|
||||||
处理流程:
|
|
||||||
1. 预处理模板,替换转义的花括号为临时占位符
|
|
||||||
2. 解析模板中的参数
|
|
||||||
3. 应用参数进行格式化
|
|
||||||
4. 还原临时占位符为实际花括号
|
|
||||||
"""
|
|
||||||
# 1. 预处理:替换转义的花括号为临时占位符
|
|
||||||
processed_template, placeholders = cls._preprocess_template(template)
|
|
||||||
|
|
||||||
# 2. 解析参数
|
|
||||||
template_args = []
|
template_args = []
|
||||||
result = re.findall(r"\{(.*?)\}", processed_template)
|
result = re.findall(r"\{(.*?)\}", processed_template)
|
||||||
for expr in result:
|
for expr in result:
|
||||||
if expr and expr not in template_args:
|
if expr and expr not in template_args:
|
||||||
template_args.append(expr)
|
template_args.append(expr)
|
||||||
|
|
||||||
formatted_args = {}
|
formatted_args = {}
|
||||||
formatted_kwargs = {}
|
formatted_kwargs = {}
|
||||||
|
|
||||||
# 3. 处理位置参数
|
# 处理位置参数
|
||||||
if args:
|
if args:
|
||||||
# print(len(template_args), len(args), template_args, args)
|
|
||||||
for i in range(len(args)):
|
for i in range(len(args)):
|
||||||
if i < len(template_args):
|
arg = args[i]
|
||||||
arg = args[i]
|
if isinstance(arg, Prompt):
|
||||||
if isinstance(arg, Prompt):
|
formatted_args[template_args[i]] = arg.format(**kwargs)
|
||||||
formatted_args[template_args[i]] = arg.format(**kwargs)
|
|
||||||
else:
|
|
||||||
formatted_args[template_args[i]] = arg
|
|
||||||
else:
|
else:
|
||||||
logger.error(
|
formatted_args[template_args[i]] = arg
|
||||||
f"构建提示词模板失败,解析到的参数列表{template_args},长度为{len(template_args)},输入的参数列表为{args},提示词模板为{template}"
|
|
||||||
)
|
|
||||||
raise ValueError("格式化模板失败")
|
|
||||||
|
|
||||||
# 处理关键字参数
|
# 处理关键字参数
|
||||||
if kwargs:
|
if kwargs:
|
||||||
@@ -255,21 +179,22 @@ class Prompt(str):
|
|||||||
formatted_kwargs[key] = value
|
formatted_kwargs[key] = value
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 应用格式化
|
# 先用位置参数格式化
|
||||||
if args:
|
if args:
|
||||||
processed_template = processed_template.format(**formatted_args)
|
processed_template = processed_template.format(**formatted_args)
|
||||||
|
# 再用关键字参数格式化
|
||||||
if kwargs:
|
if kwargs:
|
||||||
processed_template = processed_template.format(**formatted_kwargs)
|
processed_template = processed_template.format(**formatted_kwargs)
|
||||||
|
|
||||||
# 4. 还原占位符为实际的花括号
|
# 将临时标记还原为实际的花括号
|
||||||
final_result = cls._restore_template(processed_template, placeholders)
|
result = processed_template.replace(temp_left, "{").replace(temp_right, "}")
|
||||||
return final_result
|
return result
|
||||||
except (IndexError, KeyError) as e:
|
except (IndexError, KeyError) as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
|
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
def format(self, *args, **kwargs) -> "str":
|
def format(self, *args, **kwargs) -> "Prompt":
|
||||||
"""支持位置参数和关键字参数的格式化,使用"""
|
"""支持位置参数和关键字参数的格式化,使用"""
|
||||||
ret = type(self)(
|
ret = type(self)(
|
||||||
self.template,
|
self.template,
|
||||||
|
|||||||
Reference in New Issue
Block a user