Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -23,7 +23,8 @@ from src.chat.utils.chat_message_builder import (
|
||||
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
|
||||
|
||||
# 导入新的统一Prompt系统
|
||||
from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -8,13 +8,14 @@ import contextvars
|
||||
import re
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.utils.prompt_component_manager import prompt_component_manager
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
@@ -23,80 +24,6 @@ install(extra_lines=3)
|
||||
logger = get_logger("unified_prompt")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
sender: str = ""
|
||||
target: str = ""
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
bot_name: str = ""
|
||||
bot_nickname: str = ""
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
enable_expression: bool = True
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
memory_block: str = ""
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
time_block: str = ""
|
||||
identity_block: str = ""
|
||||
schedule_block: str = ""
|
||||
moderation_prompt_block: str = ""
|
||||
safety_guidelines_block: str = ""
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
errors.append("chat_id不能为空")
|
||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
|
||||
|
||||
class PromptContext:
|
||||
"""提示词上下文管理器"""
|
||||
|
||||
@@ -131,7 +58,7 @@ class PromptContext:
|
||||
context_id = None
|
||||
|
||||
previous_context = self._current_context
|
||||
token = self._current_context_var.set(context_id) if context_id else None
|
||||
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
|
||||
else:
|
||||
previous_context = self._current_context
|
||||
token = None
|
||||
@@ -184,16 +111,42 @@ class PromptManager:
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
|
||||
async def get_prompt_async(self, name: str) -> "Prompt":
|
||||
"""异步获取提示模板"""
|
||||
async def get_prompt_async(self, name: str, parameters: PromptParameters | None = None) -> "Prompt":
|
||||
"""
|
||||
异步获取提示模板,并动态注入插件内容
|
||||
"""
|
||||
original_prompt = None
|
||||
context_prompt = await self._context.get_prompt_async(name)
|
||||
if context_prompt is not None:
|
||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||
return context_prompt
|
||||
|
||||
if name not in self._prompts:
|
||||
original_prompt = context_prompt
|
||||
elif name in self._prompts:
|
||||
original_prompt = self._prompts[name]
|
||||
else:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
# 动态注入插件内容
|
||||
if original_prompt.name:
|
||||
# 确保我们有有效的parameters实例
|
||||
params_for_injection = parameters or original_prompt.parameters
|
||||
|
||||
components_prefix = await prompt_component_manager.execute_components_for(
|
||||
injection_point=original_prompt.name, params=params_for_injection
|
||||
)
|
||||
logger.info(components_prefix)
|
||||
if components_prefix:
|
||||
logger.info(f"为'{name}'注入插件内容: \n{components_prefix}")
|
||||
# 创建一个新的临时Prompt实例,不进行注册
|
||||
new_template = f"{components_prefix}\n\n{original_prompt.template}"
|
||||
temp_prompt = Prompt(
|
||||
template=new_template,
|
||||
name=original_prompt.name,
|
||||
parameters=original_prompt.parameters,
|
||||
should_register=False, # 确保不重新注册
|
||||
)
|
||||
return temp_prompt
|
||||
|
||||
return original_prompt
|
||||
|
||||
def generate_name(self, template: str) -> str:
|
||||
"""为未命名的prompt生成名称"""
|
||||
@@ -215,7 +168,9 @@ class PromptManager:
|
||||
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
"""格式化提示模板"""
|
||||
prompt = await self.get_prompt_async(name)
|
||||
# 提取parameters用于注入
|
||||
parameters = kwargs.get("parameters")
|
||||
prompt = await self.get_prompt_async(name, parameters=parameters)
|
||||
result = prompt.format(**kwargs)
|
||||
return result
|
||||
|
||||
@@ -303,11 +258,14 @@ class Prompt:
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 构建上下文数据
|
||||
# 1. 构建核心上下文数据
|
||||
context_data = await self._build_context_data()
|
||||
|
||||
# 格式化模板
|
||||
result = await self._format_with_context(context_data)
|
||||
# 2. 格式化主模板
|
||||
main_formatted_prompt = await self._format_with_context(context_data)
|
||||
|
||||
# 3. 拼接组件内容和主模板内容 (逻辑已前置到 get_prompt_async)
|
||||
result = main_formatted_prompt
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
||||
@@ -467,9 +425,13 @@ class Prompt:
|
||||
if not self.parameters.message_list_before_now_long:
|
||||
return
|
||||
|
||||
target_user_id = ""
|
||||
if self.parameters.target_user_info:
|
||||
target_user_id = self.parameters.target_user_info.get("user_id") or ""
|
||||
|
||||
read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts(
|
||||
self.parameters.message_list_before_now_long,
|
||||
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
|
||||
target_user_id,
|
||||
self.parameters.sender,
|
||||
self.parameters.chat_id,
|
||||
)
|
||||
@@ -495,11 +457,14 @@ class Prompt:
|
||||
|
||||
# 创建临时生成器实例来使用其方法
|
||||
temp_generator = await get_replyer(None, chat_id, request_type="prompt_building")
|
||||
return await temp_generator.build_s4u_chat_history_prompts(
|
||||
message_list_before_now, target_user_id, sender, chat_id
|
||||
)
|
||||
if temp_generator:
|
||||
return await temp_generator.build_s4u_chat_history_prompts(
|
||||
message_list_before_now, target_user_id, sender, chat_id
|
||||
)
|
||||
return "", ""
|
||||
except Exception as e:
|
||||
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
||||
return "", ""
|
||||
|
||||
async def _build_expression_habits(self) -> dict[str, Any]:
|
||||
"""构建表达习惯"""
|
||||
@@ -586,10 +551,10 @@ class Prompt:
|
||||
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
|
||||
|
||||
# 处理可能的异常结果
|
||||
if isinstance(running_memories, Exception):
|
||||
if isinstance(running_memories, BaseException):
|
||||
logger.warning(f"长期记忆查询失败: {running_memories}")
|
||||
running_memories = []
|
||||
if isinstance(instant_memory, Exception):
|
||||
if isinstance(instant_memory, BaseException):
|
||||
logger.warning(f"即时记忆查询失败: {instant_memory}")
|
||||
instant_memory = None
|
||||
|
||||
@@ -1103,8 +1068,24 @@ def create_prompt(
|
||||
async def create_prompt_async(
|
||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
"""异步创建Prompt实例,并动态注入插件内容"""
|
||||
# 确保有可用的parameters实例
|
||||
final_params = parameters or PromptParameters(**kwargs)
|
||||
|
||||
# 动态注入插件内容
|
||||
if name:
|
||||
components_prefix = await prompt_component_manager.execute_components_for(
|
||||
injection_point=name, params=final_params
|
||||
)
|
||||
if components_prefix:
|
||||
logger.debug(f"为'{name}'注入插件内容: \n{components_prefix}")
|
||||
template = f"{components_prefix}\n\n{template}"
|
||||
|
||||
# 使用可能已修改的模板创建实例
|
||||
prompt = create_prompt(template, name, final_params)
|
||||
|
||||
# 如果在特定上下文中,则异步注册
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
109
src/chat/utils/prompt_component_manager.py
Normal file
109
src/chat/utils/prompt_component_manager.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import asyncio
|
||||
from typing import Type
|
||||
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_prompt import BasePrompt
|
||||
from src.plugin_system.base.component_types import ComponentType, PromptInfo
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
logger = get_logger("prompt_component_manager")
|
||||
|
||||
|
||||
class PromptComponentManager:
|
||||
"""
|
||||
管理所有 `BasePrompt` 组件的单例类。
|
||||
|
||||
该管理器负责:
|
||||
1. 从 `component_registry` 中查询 `BasePrompt` 子类。
|
||||
2. 根据注入点(目标Prompt名称)对它们进行筛选。
|
||||
3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。
|
||||
"""
|
||||
|
||||
def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
|
||||
"""
|
||||
获取指定注入点的所有已注册组件类。
|
||||
|
||||
Args:
|
||||
injection_point: 目标Prompt的名称。
|
||||
|
||||
Returns:
|
||||
list[Type[BasePrompt]]: 与该注入点关联的组件类列表。
|
||||
"""
|
||||
# 从组件注册中心获取所有启用的Prompt组件
|
||||
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
|
||||
|
||||
matching_components: list[Type[BasePrompt]] = []
|
||||
|
||||
for prompt_name, prompt_info in enabled_prompts.items():
|
||||
# 确保 prompt_info 是 PromptInfo 类型
|
||||
if not isinstance(prompt_info, PromptInfo):
|
||||
continue
|
||||
|
||||
# 获取注入点信息
|
||||
injection_points = prompt_info.injection_point
|
||||
if isinstance(injection_points, str):
|
||||
injection_points = [injection_points]
|
||||
|
||||
# 检查当前注入点是否匹配
|
||||
if injection_point in injection_points:
|
||||
# 获取组件类
|
||||
component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
|
||||
if component_class and issubclass(component_class, BasePrompt):
|
||||
matching_components.append(component_class)
|
||||
|
||||
return matching_components
|
||||
|
||||
async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str:
|
||||
"""
|
||||
实例化并执行指定注入点的所有组件,然后将它们的输出拼接成一个字符串。
|
||||
|
||||
Args:
|
||||
injection_point: 目标Prompt的名称。
|
||||
params: 用于初始化组件的 PromptParameters 对象。
|
||||
|
||||
Returns:
|
||||
str: 所有相关组件生成的、用换行符连接的文本内容。
|
||||
"""
|
||||
component_classes = self.get_components_for(injection_point)
|
||||
if not component_classes:
|
||||
return ""
|
||||
|
||||
tasks = []
|
||||
for component_class in component_classes:
|
||||
try:
|
||||
# 从注册中心获取组件信息
|
||||
prompt_info = component_registry.get_component_info(
|
||||
component_class.prompt_name, ComponentType.PROMPT
|
||||
)
|
||||
if not isinstance(prompt_info, PromptInfo):
|
||||
logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置")
|
||||
plugin_config = {}
|
||||
else:
|
||||
plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
|
||||
|
||||
instance = component_class(params=params, plugin_config=plugin_config)
|
||||
tasks.append(instance.execute())
|
||||
except Exception as e:
|
||||
logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")
|
||||
|
||||
if not tasks:
|
||||
return ""
|
||||
|
||||
# 并行执行所有组件
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 过滤掉执行失败的结果和空字符串
|
||||
valid_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"执行 Prompt 组件 '{component_classes[i].prompt_name}' 失败: {result}")
|
||||
elif result and isinstance(result, str) and result.strip():
|
||||
valid_results.append(result.strip())
|
||||
|
||||
# 使用换行符拼接所有有效结果
|
||||
return "\n".join(valid_results)
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
prompt_component_manager = PromptComponentManager()
|
||||
79
src/chat/utils/prompt_params.py
Normal file
79
src/chat/utils/prompt_params.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
This module contains the PromptParameters class, which is used to define the parameters for a prompt.
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
sender: str = ""
|
||||
target: str = ""
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
bot_name: str = ""
|
||||
bot_nickname: str = ""
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
enable_expression: bool = True
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
memory_block: str = ""
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
time_block: str = ""
|
||||
identity_block: str = ""
|
||||
schedule_block: str = ""
|
||||
moderation_prompt_block: str = ""
|
||||
safety_guidelines_block: str = ""
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
errors.append("chat_id不能为空")
|
||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
Reference in New Issue
Block a user