Merge branch 'dev' of https://github.com/SnowindMe/MaiBot into dev
This commit is contained in:
@@ -195,7 +195,7 @@ class PromptBuilder:
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"reasoning_prompt_main",
|
||||
relation_prompt_all=await global_prompt_manager.get_prompt_async("relationship_prompt"),
|
||||
replation_prompt=relation_prompt,
|
||||
relation_prompt=relation_prompt,
|
||||
sender_name=sender_name,
|
||||
memory_prompt=memory_prompt,
|
||||
prompt_info=prompt_info,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# import re
|
||||
import ast
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
|
||||
@@ -96,15 +94,13 @@ class Prompt(str):
|
||||
# 如果传入的是元组,转换为列表
|
||||
if isinstance(args, tuple):
|
||||
args = list(args)
|
||||
|
||||
should_register = kwargs.pop("_should_register", True)
|
||||
# 解析模板
|
||||
tree = ast.parse(f"f'''{fstr}'''", mode="eval")
|
||||
template_args = set()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FormattedValue):
|
||||
expr = ast.get_source_segment(fstr, node.value)
|
||||
if expr:
|
||||
template_args.add(expr)
|
||||
template_args = []
|
||||
result = re.findall(r"\{(.*?)\}", fstr)
|
||||
for expr in result:
|
||||
if expr and expr not in template_args:
|
||||
template_args.append(expr)
|
||||
|
||||
# 如果提供了初始参数,立即格式化
|
||||
if kwargs or args:
|
||||
@@ -120,13 +116,14 @@ class Prompt(str):
|
||||
obj._kwargs = kwargs
|
||||
|
||||
# 修改自动注册逻辑
|
||||
if global_prompt_manager._context._current_context:
|
||||
# 如果存在当前上下文,则注册到上下文中
|
||||
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
|
||||
pass
|
||||
else:
|
||||
# 否则注册到全局管理器
|
||||
global_prompt_manager.register(obj)
|
||||
if should_register:
|
||||
if global_prompt_manager._context._current_context:
|
||||
# 如果存在当前上下文,则注册到上下文中
|
||||
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
|
||||
pass
|
||||
else:
|
||||
# 否则注册到全局管理器
|
||||
global_prompt_manager.register(obj)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
@@ -141,14 +138,11 @@ class Prompt(str):
|
||||
|
||||
@classmethod
|
||||
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
|
||||
fmt_str = f"f'''{template}'''"
|
||||
tree = ast.parse(fmt_str, mode="eval")
|
||||
template_args = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FormattedValue):
|
||||
expr = ast.get_source_segment(fmt_str, node.value)
|
||||
if expr and expr not in template_args:
|
||||
template_args.append(expr)
|
||||
result = re.findall(r"\{(.*?)\}", template)
|
||||
for expr in result:
|
||||
if expr and expr not in template_args:
|
||||
template_args.append(expr)
|
||||
formatted_args = {}
|
||||
formatted_kwargs = {}
|
||||
|
||||
@@ -180,14 +174,22 @@ class Prompt(str):
|
||||
template = template.format(**formatted_kwargs)
|
||||
return template
|
||||
except (IndexError, KeyError) as e:
|
||||
raise ValueError(f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs}") from e
|
||||
raise ValueError(
|
||||
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
|
||||
) from e
|
||||
|
||||
def format(self, *args, **kwargs) -> "Prompt":
|
||||
"""支持位置参数和关键字参数的格式化,使用"""
|
||||
ret = type(self)(
|
||||
self.template, self.name, args=list(args) if args else self._args, **kwargs if kwargs else self._kwargs
|
||||
self.template,
|
||||
self.name,
|
||||
args=list(args) if args else self._args,
|
||||
_should_register=False,
|
||||
**kwargs if kwargs else self._kwargs,
|
||||
)
|
||||
# print(f"prompt build result: {ret} name: {ret.name} ")
|
||||
ret.template = str(ret)
|
||||
print(f"prompt build result: {ret} name: {ret.name} ")
|
||||
print(global_prompt_manager._prompts["schedule_prompt"])
|
||||
return ret
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
Reference in New Issue
Block a user