This commit is contained in:
SnowindMe
2025-04-13 02:27:29 +08:00
2 changed files with 31 additions and 29 deletions

View File

@@ -195,7 +195,7 @@ class PromptBuilder:
prompt = await global_prompt_manager.format_prompt(
"reasoning_prompt_main",
relation_prompt_all=await global_prompt_manager.get_prompt_async("relationship_prompt"),
replation_prompt=relation_prompt,
relation_prompt=relation_prompt,
sender_name=sender_name,
memory_prompt=memory_prompt,
prompt_info=prompt_info,

View File

@@ -1,7 +1,5 @@
# import re
import ast
from typing import Dict, Any, Optional, List, Union
import re
from contextlib import asynccontextmanager
import asyncio
@@ -96,15 +94,13 @@ class Prompt(str):
# 如果传入的是元组,转换为列表
if isinstance(args, tuple):
args = list(args)
should_register = kwargs.pop("_should_register", True)
# 解析模板
tree = ast.parse(f"f'''{fstr}'''", mode="eval")
template_args = set()
for node in ast.walk(tree):
if isinstance(node, ast.FormattedValue):
expr = ast.get_source_segment(fstr, node.value)
if expr:
template_args.add(expr)
template_args = []
result = re.findall(r"\{(.*?)\}", fstr)
for expr in result:
if expr and expr not in template_args:
template_args.append(expr)
# 如果提供了初始参数,立即格式化
if kwargs or args:
@@ -120,13 +116,14 @@ class Prompt(str):
obj._kwargs = kwargs
# 修改自动注册逻辑
if global_prompt_manager._context._current_context:
# 如果存在当前上下文,则注册到上下文中
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
pass
else:
# 否则注册到全局管理器
global_prompt_manager.register(obj)
if should_register:
if global_prompt_manager._context._current_context:
# 如果存在当前上下文,则注册到上下文中
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
pass
else:
# 否则注册到全局管理器
global_prompt_manager.register(obj)
return obj
@classmethod
@@ -141,14 +138,11 @@ class Prompt(str):
@classmethod
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
fmt_str = f"f'''{template}'''"
tree = ast.parse(fmt_str, mode="eval")
template_args = []
for node in ast.walk(tree):
if isinstance(node, ast.FormattedValue):
expr = ast.get_source_segment(fmt_str, node.value)
if expr and expr not in template_args:
template_args.append(expr)
result = re.findall(r"\{(.*?)\}", template)
for expr in result:
if expr and expr not in template_args:
template_args.append(expr)
formatted_args = {}
formatted_kwargs = {}
@@ -180,14 +174,22 @@ class Prompt(str):
template = template.format(**formatted_kwargs)
return template
except (IndexError, KeyError) as e:
raise ValueError(f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs}") from e
raise ValueError(
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
) from e
def format(self, *args, **kwargs) -> "Prompt":
"""支持位置参数和关键字参数的格式化,使用"""
ret = type(self)(
self.template, self.name, args=list(args) if args else self._args, **kwargs if kwargs else self._kwargs
self.template,
self.name,
args=list(args) if args else self._args,
_should_register=False,
**kwargs if kwargs else self._kwargs,
)
# print(f"prompt build result: {ret} name: {ret.name} ")
ret.template = str(ret)
print(f"prompt build result: {ret} name: {ret.name} ")
print(global_prompt_manager._prompts["schedule_prompt"])
return ret
def __str__(self) -> str: