This commit is contained in:
SnowindMe
2025-04-13 02:27:29 +08:00
2 changed files with 31 additions and 29 deletions

View File

@@ -195,7 +195,7 @@ class PromptBuilder:
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"reasoning_prompt_main", "reasoning_prompt_main",
relation_prompt_all=await global_prompt_manager.get_prompt_async("relationship_prompt"), relation_prompt_all=await global_prompt_manager.get_prompt_async("relationship_prompt"),
replation_prompt=relation_prompt, relation_prompt=relation_prompt,
sender_name=sender_name, sender_name=sender_name,
memory_prompt=memory_prompt, memory_prompt=memory_prompt,
prompt_info=prompt_info, prompt_info=prompt_info,

View File

@@ -1,7 +1,5 @@
# import re
import ast
from typing import Dict, Any, Optional, List, Union from typing import Dict, Any, Optional, List, Union
import re
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import asyncio import asyncio
@@ -96,15 +94,13 @@ class Prompt(str):
# 如果传入的是元组,转换为列表 # 如果传入的是元组,转换为列表
if isinstance(args, tuple): if isinstance(args, tuple):
args = list(args) args = list(args)
should_register = kwargs.pop("_should_register", True)
# 解析模板 # 解析模板
tree = ast.parse(f"f'''{fstr}'''", mode="eval") template_args = []
template_args = set() result = re.findall(r"\{(.*?)\}", fstr)
for node in ast.walk(tree): for expr in result:
if isinstance(node, ast.FormattedValue): if expr and expr not in template_args:
expr = ast.get_source_segment(fstr, node.value) template_args.append(expr)
if expr:
template_args.add(expr)
# 如果提供了初始参数,立即格式化 # 如果提供了初始参数,立即格式化
if kwargs or args: if kwargs or args:
@@ -120,6 +116,7 @@ class Prompt(str):
obj._kwargs = kwargs obj._kwargs = kwargs
# 修改自动注册逻辑 # 修改自动注册逻辑
if should_register:
if global_prompt_manager._context._current_context: if global_prompt_manager._context._current_context:
# 如果存在当前上下文,则注册到上下文中 # 如果存在当前上下文,则注册到上下文中
# asyncio.create_task(global_prompt_manager._context.register_async(obj)) # asyncio.create_task(global_prompt_manager._context.register_async(obj))
@@ -141,12 +138,9 @@ class Prompt(str):
@classmethod @classmethod
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str: def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
fmt_str = f"f'''{template}'''"
tree = ast.parse(fmt_str, mode="eval")
template_args = [] template_args = []
for node in ast.walk(tree): result = re.findall(r"\{(.*?)\}", template)
if isinstance(node, ast.FormattedValue): for expr in result:
expr = ast.get_source_segment(fmt_str, node.value)
if expr and expr not in template_args: if expr and expr not in template_args:
template_args.append(expr) template_args.append(expr)
formatted_args = {} formatted_args = {}
@@ -180,14 +174,22 @@ class Prompt(str):
template = template.format(**formatted_kwargs) template = template.format(**formatted_kwargs)
return template return template
except (IndexError, KeyError) as e: except (IndexError, KeyError) as e:
raise ValueError(f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs}") from e raise ValueError(
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
) from e
def format(self, *args, **kwargs) -> "Prompt": def format(self, *args, **kwargs) -> "Prompt":
"""支持位置参数和关键字参数的格式化,使用""" """支持位置参数和关键字参数的格式化,使用"""
ret = type(self)( ret = type(self)(
self.template, self.name, args=list(args) if args else self._args, **kwargs if kwargs else self._kwargs self.template,
self.name,
args=list(args) if args else self._args,
_should_register=False,
**kwargs if kwargs else self._kwargs,
) )
# print(f"prompt build result: {ret} name: {ret.name} ") ret.template = str(ret)
print(f"prompt build result: {ret} name: {ret.name} ")
print(global_prompt_manager._prompts["schedule_prompt"])
return ret return ret
def __str__(self) -> str: def __str__(self) -> str: