fix: 修复了prompt模板功能的若干bug

This commit is contained in:
tcmofashi
2025-04-13 02:02:52 +08:00
parent 46da415d98
commit e61bcdb435
2 changed files with 15 additions and 19 deletions

View File

@@ -195,7 +195,7 @@ class PromptBuilder:
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"reasoning_prompt_main", "reasoning_prompt_main",
relation_prompt_all=await global_prompt_manager.get_prompt_async("relationship_prompt"), relation_prompt_all=await global_prompt_manager.get_prompt_async("relationship_prompt"),
replation_prompt=relation_prompt, relation_prompt=relation_prompt,
sender_name=sender_name, sender_name=sender_name,
memory_prompt=memory_prompt, memory_prompt=memory_prompt,
prompt_info=prompt_info, prompt_info=prompt_info,

View File

@@ -1,7 +1,5 @@
# import re
import ast
from typing import Dict, Any, Optional, List, Union from typing import Dict, Any, Optional, List, Union
import re
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import asyncio import asyncio
@@ -98,13 +96,11 @@ class Prompt(str):
args = list(args) args = list(args)
# 解析模板 # 解析模板
tree = ast.parse(f"f'''{fstr}'''", mode="eval") template_args = []
template_args = set() result = re.findall(r"\{(.*?)\}", fstr)
for node in ast.walk(tree): for expr in result:
if isinstance(node, ast.FormattedValue): if expr and expr not in template_args:
expr = ast.get_source_segment(fstr, node.value) template_args.append(expr)
if expr:
template_args.add(expr)
# 如果提供了初始参数,立即格式化 # 如果提供了初始参数,立即格式化
if kwargs or args: if kwargs or args:
@@ -141,14 +137,11 @@ class Prompt(str):
@classmethod @classmethod
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str: def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
fmt_str = f"f'''{template}'''"
tree = ast.parse(fmt_str, mode="eval")
template_args = [] template_args = []
for node in ast.walk(tree): result = re.findall(r"\{(.*?)\}", template)
if isinstance(node, ast.FormattedValue): for expr in result:
expr = ast.get_source_segment(fmt_str, node.value) if expr and expr not in template_args:
if expr and expr not in template_args: template_args.append(expr)
template_args.append(expr)
formatted_args = {} formatted_args = {}
formatted_kwargs = {} formatted_kwargs = {}
@@ -180,13 +173,16 @@ class Prompt(str):
template = template.format(**formatted_kwargs) template = template.format(**formatted_kwargs)
return template return template
except (IndexError, KeyError) as e: except (IndexError, KeyError) as e:
raise ValueError(f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs}") from e raise ValueError(
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
) from e
def format(self, *args, **kwargs) -> "Prompt": def format(self, *args, **kwargs) -> "Prompt":
"""支持位置参数和关键字参数的格式化,使用""" """支持位置参数和关键字参数的格式化,使用"""
ret = type(self)( ret = type(self)(
self.template, self.name, args=list(args) if args else self._args, **kwargs if kwargs else self._kwargs self.template, self.name, args=list(args) if args else self._args, **kwargs if kwargs else self._kwargs
) )
ret.template = str(ret)
# print(f"prompt build result: {ret} name: {ret.name} ") # print(f"prompt build result: {ret} name: {ret.name} ")
return ret return ret