Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -6,6 +6,8 @@ from src.plugin_system import (
|
||||
BaseAction,
|
||||
BaseEventHandler,
|
||||
BasePlugin,
|
||||
BasePrompt,
|
||||
ToolParamType,
|
||||
BaseTool,
|
||||
ChatType,
|
||||
CommandArgs,
|
||||
@@ -36,7 +38,17 @@ class GetSystemInfoTool(BaseTool):
|
||||
name = "get_system_info"
|
||||
description = "获取当前系统的模拟版本和状态信息。"
|
||||
available_for_llm = True
|
||||
parameters = []
|
||||
parameters = [
|
||||
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
|
||||
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量,默认为5。", False, None),
|
||||
(
|
||||
"time_range",
|
||||
ToolParamType.STRING,
|
||||
"指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。",
|
||||
False,
|
||||
["any", "week", "month"],
|
||||
),
|
||||
] # type: ignore
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
return {"name": self.name, "content": "系统版本: 1.0.1, 状态: 运行正常"}
|
||||
@@ -99,7 +111,6 @@ class LLMJudgeExampleAction(BaseAction):
|
||||
async def go_activate(self, chat_content: str = "", llm_judge_model=None) -> bool:
|
||||
"""LLM 判断激活:判断用户是否情绪低落"""
|
||||
return await self._llm_judge_activation(
|
||||
chat_content=chat_content,
|
||||
judge_prompt="""
|
||||
判断用户是否表达了以下情绪或需求:
|
||||
1. 感到难过、沮丧或失落
|
||||
@@ -169,6 +180,19 @@ class RandomEmojiAction(BaseAction):
|
||||
return True, "成功发送了一个随机表情"
|
||||
|
||||
|
||||
class WeatherPrompt(BasePrompt):
|
||||
"""一个简单的Prompt组件,用于向Planner注入天气信息。"""
|
||||
|
||||
prompt_name = "weather_info_prompt"
|
||||
prompt_description = "向Planner注入当前天气信息,以丰富对话上下文。"
|
||||
injection_point = "planner_prompt"
|
||||
|
||||
async def execute(self) -> str:
|
||||
# 在实际应用中,这里可以调用天气API
|
||||
# 为了演示,我们返回一个固定的天气信息
|
||||
return "当前天气:晴朗,温度25°C。"
|
||||
|
||||
|
||||
@register_plugin
|
||||
class HelloWorldPlugin(BasePlugin):
|
||||
"""一个包含四大核心组件和高级配置功能的入门示例插件。"""
|
||||
@@ -178,7 +202,6 @@ class HelloWorldPlugin(BasePlugin):
|
||||
dependencies = []
|
||||
python_dependencies = []
|
||||
config_file_name = "config.toml"
|
||||
enable_plugin = False
|
||||
|
||||
config_schema = {
|
||||
"meta": {
|
||||
@@ -208,4 +231,7 @@ class HelloWorldPlugin(BasePlugin):
|
||||
if self.get_config("components.random_emoji_action_enabled", True):
|
||||
components.append((RandomEmojiAction.get_action_info(), RandomEmojiAction))
|
||||
|
||||
# 注册新的Prompt组件
|
||||
components.append((WeatherPrompt.get_prompt_info(), WeatherPrompt))
|
||||
|
||||
return components
|
||||
|
||||
@@ -23,7 +23,8 @@ from src.chat.utils.chat_message_builder import (
|
||||
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
|
||||
|
||||
# 导入新的统一Prompt系统
|
||||
from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -8,13 +8,14 @@ import contextvars
|
||||
import re
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.utils.prompt_component_manager import prompt_component_manager
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
@@ -23,80 +24,6 @@ install(extra_lines=3)
|
||||
logger = get_logger("unified_prompt")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
sender: str = ""
|
||||
target: str = ""
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
bot_name: str = ""
|
||||
bot_nickname: str = ""
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
enable_expression: bool = True
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
memory_block: str = ""
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
time_block: str = ""
|
||||
identity_block: str = ""
|
||||
schedule_block: str = ""
|
||||
moderation_prompt_block: str = ""
|
||||
safety_guidelines_block: str = ""
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
errors.append("chat_id不能为空")
|
||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
|
||||
|
||||
class PromptContext:
|
||||
"""提示词上下文管理器"""
|
||||
|
||||
@@ -131,7 +58,7 @@ class PromptContext:
|
||||
context_id = None
|
||||
|
||||
previous_context = self._current_context
|
||||
token = self._current_context_var.set(context_id) if context_id else None
|
||||
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
|
||||
else:
|
||||
previous_context = self._current_context
|
||||
token = None
|
||||
@@ -184,16 +111,42 @@ class PromptManager:
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
|
||||
async def get_prompt_async(self, name: str) -> "Prompt":
|
||||
"""异步获取提示模板"""
|
||||
async def get_prompt_async(self, name: str, parameters: PromptParameters | None = None) -> "Prompt":
|
||||
"""
|
||||
异步获取提示模板,并动态注入插件内容
|
||||
"""
|
||||
original_prompt = None
|
||||
context_prompt = await self._context.get_prompt_async(name)
|
||||
if context_prompt is not None:
|
||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||
return context_prompt
|
||||
|
||||
if name not in self._prompts:
|
||||
original_prompt = context_prompt
|
||||
elif name in self._prompts:
|
||||
original_prompt = self._prompts[name]
|
||||
else:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
# 动态注入插件内容
|
||||
if original_prompt.name:
|
||||
# 确保我们有有效的parameters实例
|
||||
params_for_injection = parameters or original_prompt.parameters
|
||||
|
||||
components_prefix = await prompt_component_manager.execute_components_for(
|
||||
injection_point=original_prompt.name, params=params_for_injection
|
||||
)
|
||||
logger.info(components_prefix)
|
||||
if components_prefix:
|
||||
logger.info(f"为'{name}'注入插件内容: \n{components_prefix}")
|
||||
# 创建一个新的临时Prompt实例,不进行注册
|
||||
new_template = f"{components_prefix}\n\n{original_prompt.template}"
|
||||
temp_prompt = Prompt(
|
||||
template=new_template,
|
||||
name=original_prompt.name,
|
||||
parameters=original_prompt.parameters,
|
||||
should_register=False, # 确保不重新注册
|
||||
)
|
||||
return temp_prompt
|
||||
|
||||
return original_prompt
|
||||
|
||||
def generate_name(self, template: str) -> str:
|
||||
"""为未命名的prompt生成名称"""
|
||||
@@ -215,7 +168,9 @@ class PromptManager:
|
||||
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
"""格式化提示模板"""
|
||||
prompt = await self.get_prompt_async(name)
|
||||
# 提取parameters用于注入
|
||||
parameters = kwargs.get("parameters")
|
||||
prompt = await self.get_prompt_async(name, parameters=parameters)
|
||||
result = prompt.format(**kwargs)
|
||||
return result
|
||||
|
||||
@@ -303,11 +258,14 @@ class Prompt:
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 构建上下文数据
|
||||
# 1. 构建核心上下文数据
|
||||
context_data = await self._build_context_data()
|
||||
|
||||
# 格式化模板
|
||||
result = await self._format_with_context(context_data)
|
||||
# 2. 格式化主模板
|
||||
main_formatted_prompt = await self._format_with_context(context_data)
|
||||
|
||||
# 3. 拼接组件内容和主模板内容 (逻辑已前置到 get_prompt_async)
|
||||
result = main_formatted_prompt
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
||||
@@ -467,9 +425,13 @@ class Prompt:
|
||||
if not self.parameters.message_list_before_now_long:
|
||||
return
|
||||
|
||||
target_user_id = ""
|
||||
if self.parameters.target_user_info:
|
||||
target_user_id = self.parameters.target_user_info.get("user_id") or ""
|
||||
|
||||
read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts(
|
||||
self.parameters.message_list_before_now_long,
|
||||
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
|
||||
target_user_id,
|
||||
self.parameters.sender,
|
||||
self.parameters.chat_id,
|
||||
)
|
||||
@@ -495,11 +457,14 @@ class Prompt:
|
||||
|
||||
# 创建临时生成器实例来使用其方法
|
||||
temp_generator = await get_replyer(None, chat_id, request_type="prompt_building")
|
||||
return await temp_generator.build_s4u_chat_history_prompts(
|
||||
message_list_before_now, target_user_id, sender, chat_id
|
||||
)
|
||||
if temp_generator:
|
||||
return await temp_generator.build_s4u_chat_history_prompts(
|
||||
message_list_before_now, target_user_id, sender, chat_id
|
||||
)
|
||||
return "", ""
|
||||
except Exception as e:
|
||||
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
||||
return "", ""
|
||||
|
||||
async def _build_expression_habits(self) -> dict[str, Any]:
|
||||
"""构建表达习惯"""
|
||||
@@ -586,10 +551,10 @@ class Prompt:
|
||||
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
|
||||
|
||||
# 处理可能的异常结果
|
||||
if isinstance(running_memories, Exception):
|
||||
if isinstance(running_memories, BaseException):
|
||||
logger.warning(f"长期记忆查询失败: {running_memories}")
|
||||
running_memories = []
|
||||
if isinstance(instant_memory, Exception):
|
||||
if isinstance(instant_memory, BaseException):
|
||||
logger.warning(f"即时记忆查询失败: {instant_memory}")
|
||||
instant_memory = None
|
||||
|
||||
@@ -1103,8 +1068,24 @@ def create_prompt(
|
||||
async def create_prompt_async(
|
||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
"""异步创建Prompt实例,并动态注入插件内容"""
|
||||
# 确保有可用的parameters实例
|
||||
final_params = parameters or PromptParameters(**kwargs)
|
||||
|
||||
# 动态注入插件内容
|
||||
if name:
|
||||
components_prefix = await prompt_component_manager.execute_components_for(
|
||||
injection_point=name, params=final_params
|
||||
)
|
||||
if components_prefix:
|
||||
logger.debug(f"为'{name}'注入插件内容: \n{components_prefix}")
|
||||
template = f"{components_prefix}\n\n{template}"
|
||||
|
||||
# 使用可能已修改的模板创建实例
|
||||
prompt = create_prompt(template, name, final_params)
|
||||
|
||||
# 如果在特定上下文中,则异步注册
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
109
src/chat/utils/prompt_component_manager.py
Normal file
109
src/chat/utils/prompt_component_manager.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import asyncio
|
||||
from typing import Type
|
||||
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_prompt import BasePrompt
|
||||
from src.plugin_system.base.component_types import ComponentType, PromptInfo
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
logger = get_logger("prompt_component_manager")
|
||||
|
||||
|
||||
class PromptComponentManager:
|
||||
"""
|
||||
管理所有 `BasePrompt` 组件的单例类。
|
||||
|
||||
该管理器负责:
|
||||
1. 从 `component_registry` 中查询 `BasePrompt` 子类。
|
||||
2. 根据注入点(目标Prompt名称)对它们进行筛选。
|
||||
3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。
|
||||
"""
|
||||
|
||||
def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
|
||||
"""
|
||||
获取指定注入点的所有已注册组件类。
|
||||
|
||||
Args:
|
||||
injection_point: 目标Prompt的名称。
|
||||
|
||||
Returns:
|
||||
list[Type[BasePrompt]]: 与该注入点关联的组件类列表。
|
||||
"""
|
||||
# 从组件注册中心获取所有启用的Prompt组件
|
||||
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
|
||||
|
||||
matching_components: list[Type[BasePrompt]] = []
|
||||
|
||||
for prompt_name, prompt_info in enabled_prompts.items():
|
||||
# 确保 prompt_info 是 PromptInfo 类型
|
||||
if not isinstance(prompt_info, PromptInfo):
|
||||
continue
|
||||
|
||||
# 获取注入点信息
|
||||
injection_points = prompt_info.injection_point
|
||||
if isinstance(injection_points, str):
|
||||
injection_points = [injection_points]
|
||||
|
||||
# 检查当前注入点是否匹配
|
||||
if injection_point in injection_points:
|
||||
# 获取组件类
|
||||
component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
|
||||
if component_class and issubclass(component_class, BasePrompt):
|
||||
matching_components.append(component_class)
|
||||
|
||||
return matching_components
|
||||
|
||||
async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str:
|
||||
"""
|
||||
实例化并执行指定注入点的所有组件,然后将它们的输出拼接成一个字符串。
|
||||
|
||||
Args:
|
||||
injection_point: 目标Prompt的名称。
|
||||
params: 用于初始化组件的 PromptParameters 对象。
|
||||
|
||||
Returns:
|
||||
str: 所有相关组件生成的、用换行符连接的文本内容。
|
||||
"""
|
||||
component_classes = self.get_components_for(injection_point)
|
||||
if not component_classes:
|
||||
return ""
|
||||
|
||||
tasks = []
|
||||
for component_class in component_classes:
|
||||
try:
|
||||
# 从注册中心获取组件信息
|
||||
prompt_info = component_registry.get_component_info(
|
||||
component_class.prompt_name, ComponentType.PROMPT
|
||||
)
|
||||
if not isinstance(prompt_info, PromptInfo):
|
||||
logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置")
|
||||
plugin_config = {}
|
||||
else:
|
||||
plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
|
||||
|
||||
instance = component_class(params=params, plugin_config=plugin_config)
|
||||
tasks.append(instance.execute())
|
||||
except Exception as e:
|
||||
logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")
|
||||
|
||||
if not tasks:
|
||||
return ""
|
||||
|
||||
# 并行执行所有组件
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 过滤掉执行失败的结果和空字符串
|
||||
valid_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"执行 Prompt 组件 '{component_classes[i].prompt_name}' 失败: {result}")
|
||||
elif result and isinstance(result, str) and result.strip():
|
||||
valid_results.append(result.strip())
|
||||
|
||||
# 使用换行符拼接所有有效结果
|
||||
return "\n".join(valid_results)
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
prompt_component_manager = PromptComponentManager()
|
||||
79
src/chat/utils/prompt_params.py
Normal file
79
src/chat/utils/prompt_params.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
This module contains the PromptParameters class, which is used to define the parameters for a prompt.
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
sender: str = ""
|
||||
target: str = ""
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
bot_name: str = ""
|
||||
bot_nickname: str = ""
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
enable_expression: bool = True
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
memory_block: str = ""
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
time_block: str = ""
|
||||
identity_block: str = ""
|
||||
schedule_block: str = ""
|
||||
moderation_prompt_block: str = ""
|
||||
safety_guidelines_block: str = ""
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
errors.append("chat_id不能为空")
|
||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
@@ -26,6 +26,7 @@ from .base import (
|
||||
ActionInfo,
|
||||
BaseAction,
|
||||
BaseCommand,
|
||||
BasePrompt,
|
||||
BaseEventHandler,
|
||||
BasePlugin,
|
||||
BaseTool,
|
||||
@@ -64,6 +65,7 @@ __all__ = [
|
||||
"BaseEventHandler",
|
||||
# 基础类
|
||||
"BasePlugin",
|
||||
"BasePrompt",
|
||||
"BaseTool",
|
||||
"ChatMode",
|
||||
"ChatType",
|
||||
|
||||
@@ -8,6 +8,7 @@ from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .base_plugin import BasePlugin
|
||||
from .base_prompt import BasePrompt
|
||||
from .base_tool import BaseTool
|
||||
from .command_args import CommandArgs
|
||||
from .component_types import (
|
||||
@@ -37,6 +38,7 @@ __all__ = [
|
||||
"BaseCommand",
|
||||
"BaseEventHandler",
|
||||
"BasePlugin",
|
||||
"BasePrompt",
|
||||
"BaseTool",
|
||||
"ChatMode",
|
||||
"ChatType",
|
||||
|
||||
@@ -8,6 +8,7 @@ from src.plugin_system.base.component_types import (
|
||||
EventHandlerInfo,
|
||||
InterestCalculatorInfo,
|
||||
PlusCommandInfo,
|
||||
PromptInfo,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
@@ -15,6 +16,7 @@ from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .base_interest_calculator import BaseInterestCalculator
|
||||
from .base_prompt import BasePrompt
|
||||
from .base_tool import BaseTool
|
||||
from .plugin_base import PluginBase
|
||||
from .plus_command import PlusCommand
|
||||
@@ -80,6 +82,13 @@ class BasePlugin(PluginBase):
|
||||
logger.warning("EventHandler的get_info逻辑尚未实现")
|
||||
return None
|
||||
|
||||
elif component_type == ComponentType.PROMPT:
|
||||
if hasattr(component_class, "get_prompt_info"):
|
||||
return component_class.get_prompt_info()
|
||||
else:
|
||||
logger.warning(f"Prompt类 {component_class.__name__} 缺少 get_prompt_info 方法")
|
||||
return None
|
||||
|
||||
else:
|
||||
logger.error(f"不支持的组件类型: {component_type}")
|
||||
return None
|
||||
@@ -109,6 +118,7 @@ class BasePlugin(PluginBase):
|
||||
| tuple[EventHandlerInfo, type[BaseEventHandler]]
|
||||
| tuple[ToolInfo, type[BaseTool]]
|
||||
| tuple[InterestCalculatorInfo, type[BaseInterestCalculator]]
|
||||
| tuple[PromptInfo, type[BasePrompt]]
|
||||
]:
|
||||
"""获取插件包含的组件列表
|
||||
|
||||
|
||||
95
src/plugin_system/base/base_prompt.py
Normal file
95
src/plugin_system/base/base_prompt.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ComponentType, PromptInfo
|
||||
|
||||
logger = get_logger("base_prompt")
|
||||
|
||||
|
||||
class BasePrompt(ABC):
|
||||
"""Prompt组件基类
|
||||
|
||||
Prompt是插件的一种组件类型,用于动态地向现有的核心Prompt模板中注入额外的上下文信息。
|
||||
它的主要作用是在不修改核心代码的情况下,扩展和定制模型的行为。
|
||||
|
||||
子类可以通过类属性定义其行为:
|
||||
- prompt_name: Prompt组件的唯一名称。
|
||||
- injection_point: 指定要注入的目标Prompt名称(或名称列表)。
|
||||
"""
|
||||
|
||||
prompt_name: str = ""
|
||||
"""Prompt组件的名称"""
|
||||
prompt_description: str = ""
|
||||
"""Prompt组件的描述"""
|
||||
|
||||
# 定义此组件希望注入到哪个或哪些核心Prompt中
|
||||
# 可以是一个字符串(单个目标)或字符串列表(多个目标)
|
||||
# 例如: "planner_prompt" 或 ["s4u_style_prompt", "normal_style_prompt"]
|
||||
injection_point: str | list[str] = ""
|
||||
"""要注入的目标Prompt名称或列表"""
|
||||
|
||||
def __init__(self, params: PromptParameters, plugin_config: dict | None = None):
|
||||
"""初始化Prompt组件
|
||||
|
||||
Args:
|
||||
params: 统一提示词参数,包含所有构建提示词所需的上下文信息。
|
||||
plugin_config: 插件配置字典。
|
||||
"""
|
||||
self.params = params
|
||||
self.plugin_config = plugin_config or {}
|
||||
self.log_prefix = "[PromptComponent]"
|
||||
|
||||
logger.debug(f"{self.log_prefix} Prompt组件 '{self.prompt_name}' 初始化完成")
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> str:
|
||||
"""执行Prompt生成的抽象方法,子类必须实现。
|
||||
|
||||
此方法应根据初始化时传入的 `self.params` 来构建并返回一个字符串。
|
||||
返回的字符串将被拼接到目标Prompt的最前面。
|
||||
|
||||
Returns:
|
||||
str: 生成的文本内容。
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""获取插件配置值,支持嵌套键访问。
|
||||
|
||||
Args:
|
||||
key: 配置键名,使用点号进行嵌套访问,如 "section.subsection.key"。
|
||||
default: 未找到键时返回的默认值。
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值。
|
||||
"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
keys = key.split(".")
|
||||
current = self.plugin_config
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
@classmethod
|
||||
def get_prompt_info(cls) -> "PromptInfo":
|
||||
"""从类属性生成PromptInfo,用于组件注册和管理。
|
||||
|
||||
Returns:
|
||||
PromptInfo: 生成的Prompt信息对象。
|
||||
"""
|
||||
if not cls.prompt_name:
|
||||
raise ValueError("Prompt组件必须定义 'prompt_name' 类属性。")
|
||||
|
||||
return PromptInfo(
|
||||
name=cls.prompt_name,
|
||||
component_type=ComponentType.PROMPT,
|
||||
description=cls.prompt_description,
|
||||
injection_point=cls.injection_point,
|
||||
)
|
||||
@@ -20,6 +20,7 @@ class ComponentType(Enum):
|
||||
EVENT_HANDLER = "event_handler" # 事件处理组件
|
||||
CHATTER = "chatter" # 聊天处理器组件
|
||||
INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件
|
||||
PROMPT = "prompt" # Prompt组件
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
@@ -266,6 +267,18 @@ class EventInfo(ComponentInfo):
|
||||
self.component_type = ComponentType.EVENT_HANDLER
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptInfo(ComponentInfo):
|
||||
"""Prompt组件信息"""
|
||||
|
||||
injection_point: str | list[str] = ""
|
||||
"""要注入的目标Prompt名称或列表"""
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.PROMPT
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginInfo:
|
||||
"""插件信息"""
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
|
||||
from src.plugin_system.base.base_prompt import BasePrompt
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import (
|
||||
ActionInfo,
|
||||
@@ -22,6 +23,7 @@ from src.plugin_system.base.component_types import (
|
||||
InterestCalculatorInfo,
|
||||
PluginInfo,
|
||||
PlusCommandInfo,
|
||||
PromptInfo,
|
||||
ToolInfo,
|
||||
)
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
@@ -37,6 +39,7 @@ ComponentClassType = (
|
||||
| type[PlusCommand]
|
||||
| type[BaseChatter]
|
||||
| type[BaseInterestCalculator]
|
||||
| type[BasePrompt]
|
||||
)
|
||||
|
||||
|
||||
@@ -183,6 +186,10 @@ class ComponentRegistry:
|
||||
assert isinstance(component_info, InterestCalculatorInfo)
|
||||
assert issubclass(component_class, BaseInterestCalculator)
|
||||
ret = self._register_interest_calculator_component(component_info, component_class)
|
||||
case ComponentType.PROMPT:
|
||||
assert isinstance(component_info, PromptInfo)
|
||||
assert issubclass(component_class, BasePrompt)
|
||||
ret = self._register_prompt_component(component_info, component_class)
|
||||
case _:
|
||||
logger.warning(f"未知组件类型: {component_type}")
|
||||
ret = False
|
||||
@@ -346,6 +353,31 @@ class ComponentRegistry:
|
||||
logger.debug(f"已注册InterestCalculator组件: {calculator_name}")
|
||||
return True
|
||||
|
||||
def _register_prompt_component(
|
||||
self, prompt_info: PromptInfo, prompt_class: "ComponentClassType"
|
||||
) -> bool:
|
||||
"""注册Prompt组件到Prompt特定注册表"""
|
||||
prompt_name = prompt_info.name
|
||||
if not prompt_name:
|
||||
logger.error(f"Prompt组件 {prompt_class.__name__} 必须指定名称")
|
||||
return False
|
||||
|
||||
if not hasattr(self, "_prompt_registry"):
|
||||
self._prompt_registry: dict[str, type[BasePrompt]] = {}
|
||||
if not hasattr(self, "_enabled_prompt_registry"):
|
||||
self._enabled_prompt_registry: dict[str, type[BasePrompt]] = {}
|
||||
|
||||
_assign_plugin_attrs(
|
||||
prompt_class, prompt_info.plugin_name, self.get_plugin_config(prompt_info.plugin_name) or {}
|
||||
)
|
||||
self._prompt_registry[prompt_name] = prompt_class # type: ignore
|
||||
|
||||
if prompt_info.enabled:
|
||||
self._enabled_prompt_registry[prompt_name] = prompt_class # type: ignore
|
||||
|
||||
logger.debug(f"已注册Prompt组件: {prompt_name}")
|
||||
return True
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
||||
@@ -580,7 +612,17 @@ class ComponentRegistry:
|
||||
component_name: str,
|
||||
component_type: ComponentType | None = None,
|
||||
) -> (
|
||||
type[BaseCommand | BaseAction | BaseEventHandler | BaseTool | PlusCommand | BaseChatter | BaseInterestCalculator] | None
|
||||
type[
|
||||
BaseCommand
|
||||
| BaseAction
|
||||
| BaseEventHandler
|
||||
| BaseTool
|
||||
| PlusCommand
|
||||
| BaseChatter
|
||||
| BaseInterestCalculator
|
||||
| BasePrompt
|
||||
]
|
||||
| None
|
||||
):
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
@@ -829,6 +871,7 @@ class ComponentRegistry:
|
||||
events_handlers: int = 0
|
||||
plus_command_components: int = 0
|
||||
chatter_components: int = 0
|
||||
prompt_components: int = 0
|
||||
for component in self._components.values():
|
||||
if component.component_type == ComponentType.ACTION:
|
||||
action_components += 1
|
||||
@@ -842,6 +885,8 @@ class ComponentRegistry:
|
||||
plus_command_components += 1
|
||||
elif component.component_type == ComponentType.CHATTER:
|
||||
chatter_components += 1
|
||||
elif component.component_type == ComponentType.PROMPT:
|
||||
prompt_components += 1
|
||||
return {
|
||||
"action_components": action_components,
|
||||
"command_components": command_components,
|
||||
@@ -849,6 +894,7 @@ class ComponentRegistry:
|
||||
"event_handlers": events_handlers,
|
||||
"plus_command_components": plus_command_components,
|
||||
"chatter_components": chatter_components,
|
||||
"prompt_components": prompt_components,
|
||||
"total_components": len(self._components),
|
||||
"total_plugins": len(self._plugins),
|
||||
"components_by_type": {
|
||||
|
||||
@@ -358,13 +358,14 @@ class PluginManager:
|
||||
event_handler_count = stats.get("event_handlers", 0)
|
||||
plus_command_count = stats.get("plus_command_components", 0)
|
||||
chatter_count = stats.get("chatter_components", 0)
|
||||
prompt_count = stats.get("prompt_components", 0)
|
||||
total_components = stats.get("total_components", 0)
|
||||
|
||||
# 📋 显示插件加载总览
|
||||
if total_registered > 0:
|
||||
logger.info("🎉 插件系统加载完成!")
|
||||
logger.info(
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count})"
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count}, Prompt: {prompt_count})"
|
||||
)
|
||||
|
||||
# 显示详细的插件列表
|
||||
@@ -402,6 +403,9 @@ class PluginManager:
|
||||
plus_command_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.PLUS_COMMAND
|
||||
]
|
||||
prompt_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.PROMPT
|
||||
]
|
||||
|
||||
if action_components:
|
||||
action_details = [format_component(c) for c in action_components]
|
||||
@@ -425,6 +429,9 @@ class PluginManager:
|
||||
if event_handler_components:
|
||||
event_handler_details = [format_component(c) for c in event_handler_components]
|
||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_details)}")
|
||||
if prompt_components:
|
||||
prompt_details = [format_component(c) for c in prompt_components]
|
||||
logger.info(f" 📝 Prompt组件: {', '.join(prompt_details)}")
|
||||
|
||||
# 权限节点信息
|
||||
if plugin_instance := self.loaded_plugins.get(plugin_name):
|
||||
|
||||
@@ -155,87 +155,22 @@ class ChatterPlanFilter:
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||
|
||||
schedule_block = ""
|
||||
# 优先检查是否被吵醒
|
||||
|
||||
angry_prompt_addition = ""
|
||||
try:
|
||||
from src.plugins.built_in.sleep_system.api import get_wakeup_manager
|
||||
wakeup_mgr = get_wakeup_manager()
|
||||
except ImportError:
|
||||
logger.debug("无法导入睡眠系统API,将跳过相关检查。")
|
||||
wakeup_mgr = None
|
||||
|
||||
if wakeup_mgr:
|
||||
|
||||
# 双重检查确保愤怒状态不会丢失
|
||||
# 检查1: 直接从 wakeup_manager 获取
|
||||
if wakeup_mgr.is_in_angry_state():
|
||||
angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition()
|
||||
|
||||
# 检查2: 如果上面没获取到,再从 mood_manager 确认
|
||||
if not angry_prompt_addition:
|
||||
chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||
if chat_mood_for_check.is_angry_from_wakeup:
|
||||
angry_prompt_addition = global_config.sleep_system.angry_prompt
|
||||
|
||||
if angry_prompt_addition:
|
||||
schedule_block = angry_prompt_addition
|
||||
elif global_config.planning_system.schedule_enable:
|
||||
if global_config.planning_system.schedule_enable:
|
||||
if activity_info := schedule_manager.get_current_activity():
|
||||
activity = activity_info.get("activity", "未知活动")
|
||||
schedule_block = f"你当前正在:{activity},但注意它与群聊的聊天无关。"
|
||||
|
||||
mood_block = ""
|
||||
# 如果被吵醒,则心情也是愤怒的,不需要另外的情绪模块
|
||||
if not angry_prompt_addition and global_config.mood.enable_mood:
|
||||
# 需要情绪模块打开才能获得情绪,否则会引发报错
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||
mood_block = f"你现在的心情是:{chat_mood.mood_state}"
|
||||
|
||||
if plan.mode == ChatMode.PROACTIVE:
|
||||
long_term_memory_block = await self._get_long_term_memory_context()
|
||||
|
||||
chat_content_block, message_id_list = await build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
|
||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
)
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
prompt = prompt_template.format(
|
||||
time_block=time_block,
|
||||
identity_block=identity_block,
|
||||
schedule_block=schedule_block,
|
||||
mood_block=mood_block,
|
||||
long_term_memory_block=long_term_memory_block,
|
||||
chat_content_block=chat_content_block or "最近没有聊天内容。",
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
)
|
||||
return prompt, message_id_list
|
||||
|
||||
# 构建已读/未读历史消息
|
||||
read_history_block, unread_history_block, message_id_list = await self._build_read_unread_history_blocks(
|
||||
plan
|
||||
)
|
||||
|
||||
# 为了兼容性,保留原有的chat_content_block
|
||||
chat_content_block, _ = await build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
@@ -285,7 +220,7 @@ class ChatterPlanFilter:
|
||||
is_group_chat = plan.chat_type == ChatType.GROUP
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
if not is_group_chat and plan.target_info:
|
||||
chat_target_name = plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方"
|
||||
chat_target_name = plan.target_info.person_name or plan.target_info.user_nickname or "对方"
|
||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||
|
||||
action_options_block = await self._build_action_options(plan.available_actions)
|
||||
|
||||
@@ -9,7 +9,7 @@ from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import Plan, TargetPersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType, ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
|
||||
@@ -55,6 +55,11 @@ class ChatterPlanGenerator:
|
||||
try:
|
||||
# 获取聊天类型和目标信息
|
||||
chat_type, target_info = await get_chat_type_and_target_info(self.chat_id)
|
||||
if chat_type:
|
||||
chat_type = ChatType.GROUP
|
||||
else:
|
||||
#遇到未知类型也当私聊处理
|
||||
chat_type = ChatType.PRIVATE
|
||||
|
||||
# 获取可用动作列表
|
||||
available_actions = await self._get_available_actions(chat_type, mode)
|
||||
@@ -62,12 +67,16 @@ class ChatterPlanGenerator:
|
||||
# 获取聊天历史记录
|
||||
recent_messages = await self._get_recent_messages()
|
||||
|
||||
# 构建计划对象
|
||||
# 使用 target_info 字典创建 TargetPersonInfo 实例
|
||||
target_person_info = TargetPersonInfo(**target_info) if target_info else TargetPersonInfo()
|
||||
|
||||
# 构建计划对象
|
||||
plan = Plan(
|
||||
chat_id=self.chat_id,
|
||||
chat_type=chat_type,
|
||||
mode=mode,
|
||||
target_info=target_info,
|
||||
target_info=target_person_info,
|
||||
available_actions=available_actions,
|
||||
chat_history=recent_messages,
|
||||
)
|
||||
@@ -77,6 +86,7 @@ class ChatterPlanGenerator:
|
||||
except Exception:
|
||||
# 如果生成失败,返回一个基本的空计划
|
||||
return Plan(
|
||||
chat_type = ChatType.PRIVATE,#空计划默认当成私聊
|
||||
chat_id=self.chat_id,
|
||||
mode=mode,
|
||||
target_info=TargetPersonInfo(),
|
||||
@@ -124,7 +134,7 @@ class ChatterPlanGenerator:
|
||||
try:
|
||||
# 获取最近的消息记录
|
||||
raw_messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_id, timestamp=time.time(), limit=global_config.memory.short_memory_length
|
||||
chat_id=self.chat_id, timestamp=time.time(), limit=global_config.chat.max_context_size
|
||||
)
|
||||
|
||||
# 转换为 DatabaseMessages 对象
|
||||
|
||||
@@ -70,6 +70,7 @@ class ChatterActionPlanner:
|
||||
"replies_generated": 0,
|
||||
"other_actions_executed": 0,
|
||||
}
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
async def plan(self, context: "StreamContext | None" = None) -> tuple[list[dict[str, Any]], Any | None]:
|
||||
"""
|
||||
@@ -157,7 +158,9 @@ class ChatterActionPlanner:
|
||||
)
|
||||
|
||||
if interest_updates:
|
||||
asyncio.create_task(self._commit_interest_updates(interest_updates))
|
||||
task = asyncio.create_task(self._commit_interest_updates(interest_updates))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._handle_task_result)
|
||||
|
||||
# 检查兴趣度是否达到非回复动作阈值
|
||||
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
@@ -266,6 +269,17 @@ class ChatterActionPlanner:
|
||||
|
||||
return final_actions_dict, final_target_message_dict
|
||||
|
||||
def _handle_task_result(self, task: asyncio.Task) -> None:
|
||||
"""处理后台任务的结果,记录异常。"""
|
||||
try:
|
||||
task.result()
|
||||
except asyncio.CancelledError:
|
||||
pass # 任务被取消是正常现象
|
||||
except Exception as e:
|
||||
logger.error(f"后台任务执行失败: {e}", exc_info=True)
|
||||
finally:
|
||||
self._background_tasks.discard(task)
|
||||
|
||||
def get_planner_stats(self) -> dict[str, Any]:
|
||||
"""获取规划器统计"""
|
||||
return self.planner_stats.copy()
|
||||
|
||||
Reference in New Issue
Block a user