This commit is contained in:
minecraft1024a
2025-09-06 10:43:10 +08:00
61 changed files with 1250 additions and 2209 deletions

View File

@@ -1,5 +0,0 @@
from .config import global_config
__all__ = [
"global_config",
]

View File

@@ -1,151 +0,0 @@
import os
from dataclasses import dataclass
from datetime import datetime
import tomlkit
import shutil
from tomlkit import TOMLDocument
from tomlkit.items import Table
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from rich.traceback import install
from .config_base import ConfigBase
from .official_configs import (
DebugConfig,
MaiBotServerConfig,
NapcatServerConfig,
NicknameConfig,
SlicingConfig,
VoiceConfig,
)
install(extra_lines=3)
TEMPLATE_DIR = "plugins/napcat_adapter_plugin/template"
CONFIG_DIR = "plugins/napcat_adapter_plugin/config"
OLD_CONFIG_DIR = "plugins/napcat_adapter_plugin/config/old"
def ensure_config_directories():
"""确保配置目录存在"""
os.makedirs(CONFIG_DIR, exist_ok=True)
os.makedirs(OLD_CONFIG_DIR, exist_ok=True)
def update_config():
"""更新配置文件,统一使用 config/old 目录进行备份"""
# 确保目录存在
ensure_config_directories()
# 定义文件路径
template_path = f"{TEMPLATE_DIR}/template_config.toml"
config_path = f"{CONFIG_DIR}/config.toml"
# 检查配置文件是否存在
if not os.path.exists(config_path):
logger.info("主配置文件不存在,从模板创建新配置")
shutil.copy2(template_path, config_path)
logger.info(f"已创建新配置文件: {config_path}")
logger.info("程序将退出,请检查配置文件后重启")
# 读取配置文件和模板文件
with open(config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
with open(template_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version")
new_version = new_config["inner"].get("version")
if old_version and new_version and old_version == new_version:
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
return
else:
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
else:
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
# 创建备份文件
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = os.path.join(OLD_CONFIG_DIR, f"config.toml.bak.{timestamp}")
# 备份旧配置文件
shutil.copy2(config_path, backup_path)
logger.info(f"已备份旧配置文件到: {backup_path}")
# 复制模板文件到配置目录
shutil.copy2(template_path, config_path)
logger.info(f"已创建新配置文件: {config_path}")
def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
"""将source字典的值更新到target字典中如果target中存在相同的键"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
update_dict(target[key], value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中
logger.info("开始合并新旧配置...")
update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式)
with open(config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
@dataclass
class Config(ConfigBase):
"""总配置类"""
nickname: NicknameConfig
napcat_server: NapcatServerConfig
maibot_server: MaiBotServerConfig
voice: VoiceConfig
slicing: SlicingConfig
debug: DebugConfig
def load_config(config_path: str) -> Config:
"""
加载配置文件
:param config_path: 配置文件路径
:return: Config对象
"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建Config对象
try:
return Config.from_dict(config_data)
except Exception as e:
logger.critical("配置文件解析失败")
raise e
# 更新配置
update_config()
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=f"{CONFIG_DIR}/config.toml")
logger.info("非常的新鲜,非常的美味!")

View File

@@ -1,359 +0,0 @@
import asyncio
from dataclasses import dataclass, field
from typing import Literal, Optional
from pathlib import Path
import tomlkit
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from .config_base import ConfigBase
from .config_utils import create_config_from_template, create_default_config_dict
@dataclass
class FeaturesConfig(ConfigBase):
"""功能配置类"""
group_list_type: Literal["whitelist", "blacklist"] = "whitelist"
"""群聊列表类型 白名单/黑名单"""
group_list: list[int] = field(default_factory=list)
"""群聊列表"""
private_list_type: Literal["whitelist", "blacklist"] = "whitelist"
"""私聊列表类型 白名单/黑名单"""
private_list: list[int] = field(default_factory=list)
"""私聊列表"""
ban_user_id: list[int] = field(default_factory=list)
"""被封禁的用户ID列表封禁后将无法与其进行交互"""
ban_qq_bot: bool = False
"""是否屏蔽QQ官方机器人若为True则所有QQ官方机器人将无法与MaiMCore进行交互"""
enable_poke: bool = True
"""是否启用戳一戳功能"""
ignore_non_self_poke: bool = False
"""是否无视不是针对自己的戳一戳"""
poke_debounce_seconds: int = 3
"""戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"""
enable_reply_at: bool = True
"""是否启用引用回复时艾特用户的功能"""
reply_at_rate: float = 0.5
"""引用回复时艾特用户的几率 (0.0 ~ 1.0)"""
enable_video_analysis: bool = True
"""是否启用视频识别功能"""
max_video_size_mb: int = 100
"""视频文件最大大小限制MB"""
download_timeout: int = 60
"""视频下载超时时间(秒)"""
supported_formats: list[str] = field(default_factory=lambda: ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"])
"""支持的视频格式"""
# 消息缓冲配置
enable_message_buffer: bool = True
"""是否启用消息缓冲合并功能"""
message_buffer_enable_group: bool = True
"""是否启用群消息缓冲合并"""
message_buffer_enable_private: bool = True
"""是否启用私聊消息缓冲合并"""
message_buffer_interval: float = 3.0
"""消息合并间隔时间(秒),在此时间内的连续消息将被合并"""
message_buffer_initial_delay: float = 0.5
"""消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"""
message_buffer_max_components: int = 50
"""单个会话最大缓冲消息组件数量,超过此数量将强制合并"""
message_buffer_block_prefixes: list[str] = field(default_factory=lambda: ["/", "!", "", ".", "", "#", "%"])
"""消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"""
class FeaturesManager:
"""功能管理器,支持热重载"""
def __init__(self, config_path: str = "plugins/napcat_adapter_plugin/config/features.toml"):
self.config_path = Path(config_path)
self.config: Optional[FeaturesConfig] = None
self._file_watcher_task: Optional[asyncio.Task] = None
self._last_modified: Optional[float] = None
self._callbacks: list = []
def add_reload_callback(self, callback):
"""添加配置重载回调函数"""
self._callbacks.append(callback)
def remove_reload_callback(self, callback):
"""移除配置重载回调函数"""
if callback in self._callbacks:
self._callbacks.remove(callback)
async def _notify_callbacks(self):
"""通知所有回调函数配置已重载"""
for callback in self._callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(self.config)
else:
callback(self.config)
except Exception as e:
logger.error(f"配置重载回调执行失败: {e}")
def load_config(self) -> FeaturesConfig:
"""加载功能配置文件"""
try:
# 检查配置文件是否存在,如果不存在则创建并退出程序
if not self.config_path.exists():
logger.info(f"功能配置文件不存在: {self.config_path}")
self._create_default_config()
# 配置文件创建后程序应该退出,让用户检查配置
logger.info("程序将退出,请检查功能配置文件后重启")
quit(0)
with open(self.config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
self.config = FeaturesConfig.from_dict(config_data)
self._last_modified = self.config_path.stat().st_mtime
logger.info(f"功能配置加载成功: {self.config_path}")
return self.config
except Exception as e:
logger.error(f"功能配置加载失败: {e}")
logger.critical("无法加载功能配置文件,程序退出")
quit(1)
def _create_default_config(self):
"""创建默认功能配置文件"""
template_path = "template/features_template.toml"
# 尝试从模板创建配置文件
if create_config_from_template(
str(self.config_path),
template_path,
"功能配置文件",
should_exit=False, # 不在这里退出,由调用方决定
):
return
# 如果模板文件不存在,创建基本配置
logger.info("模板文件不存在,创建基本功能配置")
default_config = {
"group_list_type": "whitelist",
"group_list": [],
"private_list_type": "whitelist",
"private_list": [],
"ban_user_id": [],
"ban_qq_bot": False,
"enable_poke": True,
"ignore_non_self_poke": False,
"poke_debounce_seconds": 3,
"enable_reply_at": True,
"reply_at_rate": 0.5,
"enable_video_analysis": True,
"max_video_size_mb": 100,
"download_timeout": 60,
"supported_formats": ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"],
# 消息缓冲配置
"enable_message_buffer": True,
"message_buffer_enable_group": True,
"message_buffer_enable_private": True,
"message_buffer_interval": 3.0,
"message_buffer_initial_delay": 0.5,
"message_buffer_max_components": 50,
"message_buffer_block_prefixes": ["/", "!", "", ".", "", "#", "%"],
}
if not create_default_config_dict(default_config, str(self.config_path), "功能配置文件"):
logger.critical("无法创建功能配置文件")
quit(1)
async def reload_config(self) -> bool:
"""重新加载配置文件"""
try:
if not self.config_path.exists():
logger.warning(f"功能配置文件不存在,无法重载: {self.config_path}")
return False
current_modified = self.config_path.stat().st_mtime
if self._last_modified and current_modified <= self._last_modified:
return False # 文件未修改
old_config = self.config
new_config = self.load_config()
# 检查配置是否真的发生了变化
if old_config and self._configs_equal(old_config, new_config):
return False
logger.info("功能配置已重载")
await self._notify_callbacks()
return True
except Exception as e:
logger.error(f"功能配置重载失败: {e}")
return False
def _configs_equal(self, config1: FeaturesConfig, config2: FeaturesConfig) -> bool:
"""比较两个配置是否相等"""
return (
config1.group_list_type == config2.group_list_type
and set(config1.group_list) == set(config2.group_list)
and config1.private_list_type == config2.private_list_type
and set(config1.private_list) == set(config2.private_list)
and set(config1.ban_user_id) == set(config2.ban_user_id)
and config1.ban_qq_bot == config2.ban_qq_bot
and config1.enable_poke == config2.enable_poke
and config1.ignore_non_self_poke == config2.ignore_non_self_poke
and config1.poke_debounce_seconds == config2.poke_debounce_seconds
and config1.enable_reply_at == config2.enable_reply_at
and config1.reply_at_rate == config2.reply_at_rate
and config1.enable_video_analysis == config2.enable_video_analysis
and config1.max_video_size_mb == config2.max_video_size_mb
and config1.download_timeout == config2.download_timeout
and set(config1.supported_formats) == set(config2.supported_formats)
and
# 消息缓冲配置比较
config1.enable_message_buffer == config2.enable_message_buffer
and config1.message_buffer_enable_group == config2.message_buffer_enable_group
and config1.message_buffer_enable_private == config2.message_buffer_enable_private
and config1.message_buffer_interval == config2.message_buffer_interval
and config1.message_buffer_initial_delay == config2.message_buffer_initial_delay
and config1.message_buffer_max_components == config2.message_buffer_max_components
and set(config1.message_buffer_block_prefixes) == set(config2.message_buffer_block_prefixes)
)
async def start_file_watcher(self, check_interval: float = 1.0):
"""启动文件监控,定期检查配置文件变化"""
if self._file_watcher_task and not self._file_watcher_task.done():
logger.warning("文件监控已在运行")
return
self._file_watcher_task = asyncio.create_task(self._file_watcher_loop(check_interval))
logger.info(f"功能配置文件监控已启动,检查间隔: {check_interval}")
async def stop_file_watcher(self):
"""停止文件监控"""
if self._file_watcher_task and not self._file_watcher_task.done():
self._file_watcher_task.cancel()
try:
await self._file_watcher_task
except asyncio.CancelledError:
pass
logger.info("功能配置文件监控已停止")
async def _file_watcher_loop(self, check_interval: float):
"""文件监控循环"""
while True:
try:
await asyncio.sleep(check_interval)
await self.reload_config()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"文件监控循环出错: {e}")
await asyncio.sleep(check_interval)
def get_config(self) -> FeaturesConfig:
"""获取当前功能配置"""
if self.config is None:
return self.load_config()
return self.config
def is_group_allowed(self, group_id: int) -> bool:
"""检查群聊是否被允许"""
config = self.get_config()
if config.group_list_type == "whitelist":
return group_id in config.group_list
else: # blacklist
return group_id not in config.group_list
def is_private_allowed(self, user_id: int) -> bool:
"""检查私聊是否被允许"""
config = self.get_config()
if config.private_list_type == "whitelist":
return user_id in config.private_list
else: # blacklist
return user_id not in config.private_list
def is_user_banned(self, user_id: int) -> bool:
"""检查用户是否被全局禁止"""
config = self.get_config()
return user_id in config.ban_user_id
def is_qq_bot_banned(self) -> bool:
"""检查是否禁止QQ官方机器人"""
config = self.get_config()
return config.ban_qq_bot
def is_poke_enabled(self) -> bool:
"""检查戳一戳功能是否启用"""
config = self.get_config()
return config.enable_poke
def is_non_self_poke_ignored(self) -> bool:
"""检查是否忽略非自己戳一戳"""
config = self.get_config()
return config.ignore_non_self_poke
def is_message_buffer_enabled(self) -> bool:
"""检查消息缓冲功能是否启用"""
config = self.get_config()
return config.enable_message_buffer
def is_message_buffer_group_enabled(self) -> bool:
"""检查群消息缓冲是否启用"""
config = self.get_config()
return config.message_buffer_enable_group
def is_message_buffer_private_enabled(self) -> bool:
"""检查私聊消息缓冲是否启用"""
config = self.get_config()
return config.message_buffer_enable_private
def get_message_buffer_interval(self) -> float:
"""获取消息缓冲间隔时间"""
config = self.get_config()
return config.message_buffer_interval
def get_message_buffer_initial_delay(self) -> float:
"""获取消息缓冲初始延迟"""
config = self.get_config()
return config.message_buffer_initial_delay
def get_message_buffer_max_components(self) -> int:
"""获取消息缓冲最大组件数量"""
config = self.get_config()
return config.message_buffer_max_components
def is_message_buffer_group_enabled(self) -> bool:
"""检查是否启用群聊消息缓冲"""
config = self.get_config()
return config.message_buffer_enable_group
def is_message_buffer_private_enabled(self) -> bool:
"""检查是否启用私聊消息缓冲"""
config = self.get_config()
return config.message_buffer_enable_private
def get_message_buffer_block_prefixes(self) -> list[str]:
"""获取消息缓冲屏蔽前缀列表"""
config = self.get_config()
return config.message_buffer_block_prefixes
# 全局功能管理器实例
features_manager = FeaturesManager()

View File

@@ -1,26 +0,0 @@
from maim_message import Router, RouteConfig, TargetConfig
from .config import global_config
from src.common.logger import get_logger
from .send_handler import send_handler
logger = get_logger("napcat_adapter")
route_config = RouteConfig(
route_config={
global_config.maibot_server.platform_name: TargetConfig(
url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws",
token=None,
)
}
)
router = Router(route_config)
async def mmc_start_com():
logger.info("正在连接MaiBot")
router.register_class_handler(send_handler.handle_message)
await router.run()
async def mmc_stop_com():
await router.stop()

View File

@@ -0,0 +1,61 @@
"""
更新Prompt类导入脚本
将旧的prompt_builder.Prompt导入更新为unified_prompt.Prompt
"""
import os
import re
from pathlib import Path
# 需要更新的文件列表
files_to_update = [
"src/person_info/relationship_fetcher.py",
"src/mood/mood_manager.py",
"src/mais4u/mais4u_chat/body_emotion_action_manager.py",
"src/chat/express/expression_learner.py",
"src/chat/planner_actions/planner.py",
"src/mais4u/mais4u_chat/s4u_prompt.py",
"src/chat/message_receive/bot.py",
"src/chat/replyer/default_generator.py",
"src/chat/express/expression_selector.py",
"src/mais4u/mai_think.py",
"src/mais4u/mais4u_chat/s4u_mood_manager.py",
"src/plugin_system/core/tool_use.py",
"src/chat/memory_system/memory_activator.py",
"src/chat/utils/smart_prompt.py"
]
def update_prompt_imports(file_path):
"""更新文件中的Prompt导入"""
if not os.path.exists(file_path):
print(f"文件不存在: {file_path}")
return False
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
# 替换导入语句
old_import = "from src.chat.utils.prompt_builder import Prompt, global_prompt_manager"
new_import = "from src.chat.utils.prompt import Prompt, global_prompt_manager"
if old_import in content:
new_content = content.replace(old_import, new_import)
with open(file_path, 'w', encoding='utf-8') as f:
f.write(new_content)
print(f"已更新: {file_path}")
return True
else:
print(f"无需更新: {file_path}")
return False
def main():
"""主函数"""
updated_count = 0
for file_path in files_to_update:
if update_prompt_imports(file_path):
updated_count += 1
print(f"\n更新完成!共更新了 {updated_count} 个文件")
if __name__ == "__main__":
main()

View File

@@ -149,7 +149,7 @@ class CycleProcessor:
logger.info(f"{self.log_prefix} 开始第{self.context.cycle_counter}次思考")
if ENABLE_S4U:
await send_typing()
await send_typing(self.context.chat_stream.user_info.user_id)
loop_start_time = time.time()

View File

@@ -121,7 +121,7 @@ class CycleDetail:
self.loop_action_info = loop_info["loop_action_info"]
async def send_typing():
async def send_typing(user_id):
"""
发送打字状态指示
@@ -139,6 +139,11 @@ async def send_typing():
group_info=group_info,
)
from plugin_system.core.event_manager import event_manager
from src.plugins.built_in.napcat_adapter_plugin.event_types import NapcatEvent
# 设置正在输入状态
await event_manager.trigger_event(NapcatEvent.PERSONAL.SET_INPUT_STATUS,user_id=user_id,event_type=1)
await send_api.custom_to_stream(
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
)

View File

@@ -12,7 +12,7 @@ from .hfc_context import HfcContext
# 导入反注入系统
from src.chat.antipromptinjector import get_anti_injector
from src.chat.antipromptinjector.types import ProcessResult
from src.chat.utils.prompt_builder import Prompt
from src.chat.utils.prompt import Prompt
logger = get_logger("hfc")
anti_injector_logger = get_logger("anti_injector")

View File

@@ -13,7 +13,7 @@ from src.common.database.sqlalchemy_models import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager

View File

@@ -11,7 +11,7 @@ from src.config.config import global_config, model_config
from src.common.logger import get_logger
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_db_session
logger = get_logger("expression_selector")

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager

View File

@@ -12,7 +12,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor

View File

@@ -9,7 +9,7 @@ from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_actions,
get_actions_by_timestamp_with_chat,

View File

@@ -1,6 +1,6 @@
"""
默认回复生成器 - 集成SmartPrompt系统
使用重构后的SmartPrompt系统替换原有的复杂提示词构建逻辑
默认回复生成器 - 集成统一Prompt系统
使用重构后的统一Prompt系统替换原有的复杂提示词构建逻辑
"""
import traceback
@@ -11,7 +11,6 @@ import re
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.chat.utils.prompt_utils import PromptUtils
from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
@@ -22,7 +21,7 @@ from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
@@ -37,8 +36,8 @@ from src.person_info.person_info import get_person_info_manager
from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
# 导入新的智能Prompt系统
from src.chat.utils.smart_prompt import SmartPrompt, SmartPromptParameters
# 导入新的统一Prompt系统
from src.chat.utils.prompt import Prompt, PromptParameters
logger = get_logger("replyer")
@@ -598,7 +597,8 @@ class DefaultReplyer:
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息 - 使用共享工具"""
return PromptUtils.parse_reply_target(target_message)
from src.chat.utils.prompt import Prompt
return Prompt.parse_reply_target(target_message)
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
"""构建关键词反应提示
@@ -873,7 +873,8 @@ class DefaultReplyer:
target_user_info = None
if sender:
target_user_info = await person_info_manager.get_person_info_by_name(sender)
from src.chat.utils.prompt import Prompt
# 并行执行六个构建任务
task_results = await asyncio.gather(
self._time_and_run_task(
@@ -886,7 +887,7 @@ class DefaultReplyer:
),
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
self._time_and_run_task(
PromptUtils.build_cross_context(chat_id, target_user_info, global_config.personality.prompt_mode),
Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info),
"cross_context",
),
)
@@ -970,8 +971,8 @@ class DefaultReplyer:
# 根据配置选择模板
current_prompt_mode = global_config.personality.prompt_mode
# 使用重构后的SmartPrompt系统
prompt_params = SmartPromptParameters(
# 使用新的统一Prompt系统 - 创建PromptParameters
prompt_parameters = PromptParameters(
chat_id=chat_id,
is_group_chat=is_group_chat,
sender=sender,
@@ -1004,12 +1005,19 @@ class DefaultReplyer:
action_descriptions=action_descriptions,
)
# 使用重构后的SmartPrompt系统
smart_prompt = SmartPrompt(
template_name=None, # 由current_prompt_mode自动选择
parameters=prompt_params,
)
prompt_text = await smart_prompt.build_prompt()
# 使用新的统一Prompt系统 - 使用正确的模板名称
template_name = None
if current_prompt_mode == "s4u":
template_name = "s4u_style_prompt"
elif current_prompt_mode == "normal":
template_name = "normal_style_prompt"
elif current_prompt_mode == "minimal":
template_name = "default_expressor_prompt"
# 获取模板内容
template_prompt = await global_prompt_manager.get_prompt_async(template_name)
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
prompt_text = await prompt.build()
return prompt_text
@@ -1110,8 +1118,8 @@ class DefaultReplyer:
template_name = "default_expressor_prompt"
# 使用重构后的SmartPrompt系统 - Expressor模式
prompt_params = SmartPromptParameters(
# 使用新的统一Prompt系统 - Expressor模式创建PromptParameters
prompt_parameters = PromptParameters(
chat_id=chat_id,
is_group_chat=is_group_chat,
sender=sender,
@@ -1131,8 +1139,10 @@ class DefaultReplyer:
relation_info_block=relation_info,
)
smart_prompt = SmartPrompt(parameters=prompt_params)
prompt_text = await smart_prompt.build_prompt()
# 使用新的统一Prompt系统 - Expressor模式
template_prompt = await global_prompt_manager.get_prompt_async("default_expressor_prompt")
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
prompt_text = await prompt.build()
return prompt_text

823
src/chat/utils/prompt.py Normal file
View File

@@ -0,0 +1,823 @@
"""
统一提示词系统 - 合并模板管理和智能构建功能
将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类
"""
import re
import asyncio
import time
import contextvars
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Literal, Tuple
from contextlib import asynccontextmanager
from rich.traceback import install
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
install(extra_lines=3)
logger = get_logger("unified_prompt")
@dataclass
class PromptParameters:
"""统一提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
sender: str = ""
target: str = ""
reply_to: str = ""
extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
enable_expression: bool = True
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: Optional[Dict[str, Any]] = None
# 已构建的内容块
expression_habits_block: str = ""
relation_info_block: str = ""
memory_block: str = ""
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
# 其他内容块
keywords_reaction_prompt: str = ""
extra_info_block: str = ""
time_block: str = ""
identity_block: str = ""
schedule_block: str = ""
moderation_prompt_block: str = ""
reply_target_block: str = ""
mood_prompt: str = ""
action_descriptions: str = ""
# 可用动作信息
available_actions: Optional[Dict[str, Any]] = None
def validate(self) -> List[str]:
"""参数验证"""
errors = []
if not self.chat_id:
errors.append("chat_id不能为空")
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
errors.append("prompt_mode必须是's4u''normal''minimal'")
if self.max_context_messages <= 0:
errors.append("max_context_messages必须大于0")
return errors
class PromptContext:
"""提示词上下文管理器"""
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock()
@property
def _current_context(self) -> Optional[str]:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value) # type: ignore
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
"""创建一个异步的临时提示模板作用域"""
if context_id is not None:
try:
await asyncio.wait_for(self._context_lock.acquire(), timeout=5.0)
try:
if context_id not in self._context_prompts:
self._context_prompts[context_id] = {}
finally:
self._context_lock.release()
except asyncio.TimeoutError:
logger.warning(f"获取上下文锁超时context_id: {context_id}")
context_id = None
previous_context = self._current_context
token = self._current_context_var.set(context_id) if context_id else None
else:
previous_context = self._current_context
token = None
try:
yield self
finally:
if context_id is not None and token is not None:
try:
self._current_context_var.reset(token)
except Exception as e:
logger.warning(f"恢复上下文时出错: {e}")
try:
self._current_context = previous_context
except Exception:
...
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板"""
async with self._context_lock:
current_context = self._current_context
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
if (
current_context
and current_context in self._context_prompts
and name in self._context_prompts[current_context]
):
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
if target_context := context_id or self._current_context:
if prompt.name:
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
class PromptManager:
"""统一提示词管理器"""
def __init__(self):
self._prompts = {}
self._counter = 0
self._context = PromptContext()
self._lock = asyncio.Lock()
@asynccontextmanager
async def async_message_scope(self, message_id: Optional[str] = None):
"""为消息处理创建异步临时作用域"""
async with self._context.async_scope(message_id):
yield self
async def get_prompt_async(self, name: str) -> "Prompt":
"""异步获取提示模板"""
context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt
async with self._lock:
if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]
def generate_name(self, template: str) -> str:
"""为未命名的prompt生成名称"""
self._counter += 1
return f"prompt_{self._counter}"
def register(self, prompt: "Prompt") -> None:
"""注册一个prompt"""
if not prompt.name:
prompt.name = self.generate_name(prompt.template)
self._prompts[prompt.name] = prompt
def add_prompt(self, name: str, fstr: str) -> "Prompt":
"""添加新提示模板"""
prompt = Prompt(fstr, name=name)
if prompt.name:
self._prompts[prompt.name] = prompt
return prompt
async def format_prompt(self, name: str, **kwargs) -> str:
"""格式化提示模板"""
prompt = await self.get_prompt_async(name)
result = prompt.format(**kwargs)
return result
# 全局单例
global_prompt_manager = PromptManager()
class Prompt:
"""
统一提示词类 - 合并模板管理和智能构建功能
真正的Prompt类支持模板管理和智能上下文构建
"""
# 临时标记,作为类常量
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
def __init__(
self,
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
should_register: bool = True
):
"""
初始化统一提示词
Args:
template: 提示词模板字符串
name: 提示词名称
parameters: 构建参数
should_register: 是否自动注册到全局管理器
"""
self.template = template
self.name = name
self.parameters = parameters or PromptParameters()
self.args = self._parse_template_args(template)
self._formatted_result = ""
# 预处理模板中的转义花括号
self._processed_template = self._process_escaped_braces(template)
# 自动注册
if should_register and not global_prompt_manager._context._current_context:
global_prompt_manager.register(self)
@staticmethod
def _process_escaped_braces(template) -> str:
"""处理模板中的转义花括号"""
if isinstance(template, list):
template = "\n".join(str(item) for item in template)
elif not isinstance(template, str):
template = str(template)
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
@staticmethod
def _restore_escaped_braces(template: str) -> str:
"""将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def _parse_template_args(self, template: str) -> List[str]:
"""解析模板参数"""
template_args = []
processed_template = self._process_escaped_braces(template)
result = re.findall(r"\{(.*?)}", processed_template)
for expr in result:
if expr and expr not in template_args:
template_args.append(expr)
return template_args
async def build(self) -> str:
"""
构建完整的提示词,包含智能上下文
Returns:
str: 构建完成的提示词文本
"""
# 参数验证
errors = self.parameters.validate()
if errors:
logger.error(f"参数验证失败: {', '.join(errors)}")
raise ValueError(f"参数验证失败: {', '.join(errors)}")
start_time = time.time()
try:
# 构建上下文数据
context_data = await self._build_context_data()
# 格式化模板
result = await self._format_with_context(context_data)
total_time = time.time() - start_time
logger.debug(f"Prompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
self._formatted_result = result
return result
except asyncio.TimeoutError as e:
logger.error(f"构建Prompt超时: {e}")
raise TimeoutError(f"构建Prompt超时: {e}")
except Exception as e:
logger.error(f"构建Prompt失败: {e}")
raise RuntimeError(f"构建Prompt失败: {e}")
async def _build_context_data(self) -> Dict[str, Any]:
"""构建智能上下文数据"""
# 并行执行所有构建任务
start_time = time.time()
timing_logs = {}
try:
# 准备构建任务
tasks = []
task_names = []
# 初始化预构建参数
pre_built_params = {}
if self.parameters.expression_habits_block:
pre_built_params["expression_habits_block"] = self.parameters.expression_habits_block
if self.parameters.relation_info_block:
pre_built_params["relation_info_block"] = self.parameters.relation_info_block
if self.parameters.memory_block:
pre_built_params["memory_block"] = self.parameters.memory_block
if self.parameters.tool_info_block:
pre_built_params["tool_info_block"] = self.parameters.tool_info_block
if self.parameters.knowledge_prompt:
pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt
if self.parameters.cross_context_block:
pre_built_params["cross_context_block"] = self.parameters.cross_context_block
# 根据参数确定要构建的项
if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"):
tasks.append(self._build_expression_habits())
task_names.append("expression_habits")
if self.parameters.enable_memory and not pre_built_params.get("memory_block"):
tasks.append(self._build_memory_block())
task_names.append("memory_block")
if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"):
tasks.append(self._build_relation_info())
task_names.append("relation_info")
if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"):
tasks.append(self._build_tool_info())
task_names.append("tool_info")
if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
tasks.append(self._build_knowledge_info())
task_names.append("knowledge_info")
if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"):
tasks.append(self._build_cross_context())
task_names.append("cross_context")
# 性能优化
base_timeout = 10.0
task_timeout = 2.0
timeout_seconds = min(
max(base_timeout, len(tasks) * task_timeout),
30.0,
)
max_concurrent_tasks = 5
if len(tasks) > max_concurrent_tasks:
results = []
for i in range(0, len(tasks), max_concurrent_tasks):
batch_tasks = tasks[i : i + max_concurrent_tasks]
batch_names = task_names[i : i + max_concurrent_tasks]
batch_results = await asyncio.wait_for(
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
)
results.extend(batch_results)
else:
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
)
# 处理结果
context_data = {}
for i, result in enumerate(results):
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
if isinstance(result, Exception):
logger.error(f"构建任务{task_name}失败: {str(result)}")
elif isinstance(result, dict):
context_data.update(result)
# 添加预构建的参数
for key, value in pre_built_params.items():
if value:
context_data[key] = value
except asyncio.TimeoutError:
logger.error(f"构建超时 ({timeout_seconds}s)")
context_data = {}
for key, value in pre_built_params.items():
if value:
context_data[key] = value
# 构建聊天历史
if self.parameters.prompt_mode == "s4u":
await self._build_s4u_chat_context(context_data)
else:
await self._build_normal_chat_context(context_data)
# 补充基础信息
context_data.update({
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
"extra_info_block": self.parameters.extra_info_block,
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
"identity": self.parameters.identity_block,
"schedule_block": self.parameters.schedule_block,
"moderation_prompt": self.parameters.moderation_prompt_block,
"reply_target_block": self.parameters.reply_target_block,
"mood_state": self.parameters.mood_prompt,
"action_descriptions": self.parameters.action_descriptions,
})
total_time = time.time() - start_time
logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s")
return context_data
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
"""构建S4U模式的聊天上下文"""
if not self.parameters.message_list_before_now_long:
return
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
self.parameters.message_list_before_now_long,
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
self.parameters.sender
)
context_data["core_dialogue_prompt"] = core_dialogue
context_data["background_dialogue_prompt"] = background_dialogue
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
"""构建normal模式的聊天上下文"""
if not self.parameters.chat_talking_prompt_short:
return
context_data["chat_info"] = f"""群里的聊天内容:
{self.parameters.chat_talking_prompt_short}"""
async def _build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
) -> Tuple[str, str]:
"""构建S4U风格的分离对话prompt"""
# 实现逻辑与原有SmartPromptBuilder相同
core_dialogue_list = []
bot_id = str(global_config.bot.qq_account)
for msg_dict in message_list_before_now:
try:
msg_user_id = str(msg_dict.get("user_id"))
reply_to = msg_dict.get("reply_to", "")
platform, reply_to_user_id = Prompt.parse_reply_target(reply_to)
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
core_dialogue_list.append(msg_dict)
except Exception as e:
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
# 构建背景对话 prompt
all_dialogue_prompt = ""
if message_list_before_now:
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
all_dialogue_prompt_str = build_readable_messages(
latest_25_msgs,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True,
)
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
# 构建核心对话 prompt
core_dialogue_prompt = ""
if core_dialogue_list:
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
if not has_bot_message:
core_dialogue_prompt = ""
else:
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :]
core_dialogue_prompt_str = build_readable_messages(
core_dialogue_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
core_dialogue_prompt = f"""--------------------------------
这是你和{sender}的对话,你们正在交流中:
{core_dialogue_prompt_str}
--------------------------------
"""
return core_dialogue_prompt, all_dialogue_prompt
async def _build_expression_habits(self) -> Dict[str, Any]:
"""构建表达习惯"""
# 简化的实现,完整实现需要导入相关模块
return {"expression_habits_block": ""}
async def _build_memory_block(self) -> Dict[str, Any]:
"""构建记忆块"""
# 简化的实现
return {"memory_block": ""}
async def _build_relation_info(self) -> Dict[str, Any]:
"""构建关系信息"""
try:
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
return {"relation_info_block": relation_info}
except Exception as e:
logger.error(f"构建关系信息失败: {e}")
return {"relation_info_block": ""}
async def _build_tool_info(self) -> Dict[str, Any]:
"""构建工具信息"""
# 简化的实现
return {"tool_info_block": ""}
async def _build_knowledge_info(self) -> Dict[str, Any]:
"""构建知识信息"""
# 简化的实现
return {"knowledge_prompt": ""}
async def _build_cross_context(self) -> Dict[str, Any]:
"""构建跨群上下文"""
try:
cross_context = await Prompt.build_cross_context(
self.parameters.chat_id, self.parameters.prompt_mode, self.parameters.target_user_info
)
return {"cross_context_block": cross_context}
except Exception as e:
logger.error(f"构建跨群上下文失败: {e}")
return {"cross_context_block": ""}
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
"""使用上下文数据格式化模板"""
if self.parameters.prompt_mode == "s4u":
params = self._prepare_s4u_params(context_data)
elif self.parameters.prompt_mode == "normal":
params = self._prepare_normal_params(context_data)
else:
params = self._prepare_default_params(context_data)
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备S4U模式的参数"""
return {
**context_data,
"expression_habits_block": context_data.get("expression_habits_block", ""),
"tool_info_block": context_data.get("tool_info_block", ""),
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"sender_name": self.parameters.sender,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""),
"time_block": context_data.get("time_block", ""),
"core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备Normal模式的参数"""
return {
**context_data,
"expression_habits_block": context_data.get("expression_habits_block", ""),
"tool_info_block": context_data.get("tool_info_block", ""),
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""),
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"config_expression_style": global_config.personality.reply_style,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备默认模式的参数"""
return {
"expression_habits_block": context_data.get("expression_habits_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"chat_target": "",
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"chat_target_2": "",
"reply_target_block": context_data.get("reply_target_block", ""),
"raw_reply": self.parameters.target,
"reason": "",
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
def format(self, *args, **kwargs) -> str:
"""格式化模板,支持位置参数和关键字参数"""
try:
# 先用位置参数格式化
if args:
formatted_args = {}
for i in range(len(args)):
if i < len(self.args):
formatted_args[self.args[i]] = args[i]
processed_template = self._processed_template.format(**formatted_args)
else:
processed_template = self._processed_template
# 再用关键字参数格式化
if kwargs:
processed_template = processed_template.format(**kwargs)
# 将临时标记还原为实际的花括号
result = self._restore_escaped_braces(processed_template)
return result
except (IndexError, KeyError) as e:
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
def __str__(self) -> str:
"""返回格式化后的结果或原始模板"""
return self._formatted_result if self._formatted_result else self.template
def __repr__(self) -> str:
"""返回提示词的表示形式"""
return f"Prompt(template='{self.template}', name='{self.name}')"
# =============================================================================
# PromptUtils功能迁移 - 静态工具方法
# 这些方法原来在PromptUtils类中现在作为Prompt类的静态方法
# 解决循环导入问题
# =============================================================================
@staticmethod
def parse_reply_target(target_message: str) -> Tuple[str, str]:
"""
解析回复目标消息 - 统一实现
Args:
target_message: 目标消息,格式为 "发送者:消息内容""发送者:消息内容"
Returns:
Tuple[str, str]: (发送者名称, 消息内容)
"""
sender = ""
target = ""
# 添加None检查防止NoneType错误
if target_message is None:
return sender, target
if ":" in target_message or "" in target_message:
# 使用正则表达式匹配中文或英文冒号
parts = re.split(pattern=r"[:]", string=target_message, maxsplit=1)
if len(parts) == 2:
sender = parts[0].strip()
target = parts[1].strip()
return sender, target
@staticmethod
async def build_relation_info(chat_id: str, reply_to: str) -> str:
"""
构建关系信息 - 统一实现
Args:
chat_id: 聊天ID
reply_to: 回复目标字符串
Returns:
str: 关系信息字符串
"""
if not global_config.relationship.enable_relationship:
return ""
from src.person_info.relationship_fetcher import relationship_fetcher_manager
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
if not reply_to:
return ""
sender, text = Prompt.parse_reply_target(reply_to)
if not sender or not text:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if not person_id:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
@staticmethod
async def build_cross_context(
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
) -> str:
"""
构建跨群聊上下文 - 统一实现
Args:
chat_id: 聊天ID
prompt_mode: 当前提示词模式
target_user_info: 目标用户信息
Returns:
str: 跨群聊上下文字符串
"""
if not global_config.cross_context.enable:
return ""
from src.plugin_system.apis import cross_context_api
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
if not other_chat_raw_ids:
return ""
chat_stream = get_chat_manager().get_stream(chat_id)
if not chat_stream:
return ""
if prompt_mode == "normal":
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
elif prompt_mode == "s4u":
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
return ""
@staticmethod
def parse_reply_target_id(reply_to: str) -> str:
"""
解析回复目标中的用户ID
Args:
reply_to: 回复目标字符串
Returns:
str: 用户ID
"""
if not reply_to:
return ""
# 复用parse_reply_target方法的逻辑
sender, _ = Prompt.parse_reply_target(reply_to)
if not sender:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if person_id:
user_id = person_info_manager.get_value_sync(person_id, "user_id")
return str(user_id) if user_id else ""
return ""
# 工厂函数
def create_prompt(
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
) -> Prompt:
"""快速创建Prompt实例的工厂函数"""
if parameters is None:
parameters = PromptParameters(**kwargs)
return Prompt(template, name, parameters)
async def create_prompt_async(
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
) -> Prompt:
"""异步创建Prompt实例"""
prompt = create_prompt(template, name, parameters, **kwargs)
if global_prompt_manager._context._current_context:
await global_prompt_manager._context.register_async(prompt)
return prompt

View File

@@ -1,299 +0,0 @@
import re
import asyncio
import contextvars
from rich.traceback import install
from contextlib import asynccontextmanager
from typing import Dict, Any, Optional, List, Union
from src.common.logger import get_logger
install(extra_lines=3)
logger = get_logger("prompt_build")
class PromptContext:
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
# 使用contextvars创建协程上下文变量
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock() # 保留锁用于其他操作
@property
def _current_context(self) -> Optional[str]:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value) # type: ignore
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
# sourcery skip: hoist-statement-from-if, use-contextlib-suppress
"""创建一个异步的临时提示模板作用域"""
# 保存当前上下文并设置新上下文
if context_id is not None:
try:
# 添加超时保护,避免长时间等待锁
await asyncio.wait_for(self._context_lock.acquire(), timeout=5.0)
try:
if context_id not in self._context_prompts:
self._context_prompts[context_id] = {}
finally:
self._context_lock.release()
except asyncio.TimeoutError:
logger.warning(f"获取上下文锁超时context_id: {context_id}")
# 超时时直接进入,不设置上下文
context_id = None
# 保存当前协程的上下文值,不影响其他协程
previous_context = self._current_context
# 设置当前协程的新上下文
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
else:
# 如果没有提供新上下文,保持当前上下文不变
previous_context = self._current_context
token = None
try:
yield self
finally:
# 恢复之前的上下文,添加异常保护
if context_id is not None and token is not None:
try:
self._current_context_var.reset(token)
except Exception as e:
logger.warning(f"恢复上下文时出错: {e}")
# 如果reset失败尝试直接设置
try:
self._current_context = previous_context
except Exception:
...
# 静默忽略恢复失败
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板"""
async with self._context_lock:
current_context = self._current_context
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
if (
current_context
and current_context in self._context_prompts
and name in self._context_prompts[current_context]
):
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
if target_context := context_id or self._current_context:
if prompt.name:
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
class PromptManager:
def __init__(self):
self._prompts = {}
self._counter = 0
self._context = PromptContext()
self._lock = asyncio.Lock()
@asynccontextmanager
async def async_message_scope(self, message_id: Optional[str] = None):
"""为消息处理创建异步临时作用域,支持 message_id 为 None 的情况"""
async with self._context.async_scope(message_id):
yield self
async def get_prompt_async(self, name: str) -> "Prompt":
# 首先尝试从当前上下文获取
context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt
# 如果上下文中不存在,则使用全局提示模板
async with self._lock:
# logger.debug(f"从全局获取提示词: {name}")
if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]
def generate_name(self, template: str) -> str:
"""为未命名的prompt生成名称"""
self._counter += 1
return f"prompt_{self._counter}"
def register(self, prompt: "Prompt") -> None:
"""注册一个prompt"""
if not prompt.name:
prompt.name = self.generate_name(prompt.template)
self._prompts[prompt.name] = prompt
def add_prompt(self, name: str, fstr: str) -> "Prompt":
prompt = Prompt(fstr, name=name)
if prompt.name:
self._prompts[prompt.name] = prompt
return prompt
async def format_prompt(self, name: str, **kwargs) -> str:
# 获取当前提示词
prompt = await self.get_prompt_async(name)
# 获取基本格式化结果
result = prompt.format(**kwargs)
return result
# 全局单例
global_prompt_manager = PromptManager()
class Prompt(str):
template: str
name: Optional[str]
args: List[str]
_args: List[Any]
_kwargs: Dict[str, Any]
# 临时标记,作为类常量
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
@staticmethod
def _process_escaped_braces(template) -> str:
"""处理模板中的转义花括号,将 \\{\\} 替换为临时标记""" # type: ignore
# 如果传入的是列表,将其转换为字符串
if isinstance(template, list):
template = "\n".join(str(item) for item in template)
elif not isinstance(template, str):
template = str(template)
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
@staticmethod
def _restore_escaped_braces(template: str) -> str:
"""将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def __new__(
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
):
# 如果传入的是元组,转换为列表
if isinstance(args, tuple):
args = list(args)
should_register = kwargs.pop("_should_register", True)
# 预处理模板中的转义花括号
processed_fstr = cls._process_escaped_braces(fstr)
# 解析模板
template_args = []
result = re.findall(r"\{(.*?)}", processed_fstr)
for expr in result:
if expr and expr not in template_args:
template_args.append(expr)
# 如果提供了初始参数,立即格式化
if kwargs or args:
formatted = cls._format_template(fstr, args=args, kwargs=kwargs)
obj = super().__new__(cls, formatted)
else:
obj = super().__new__(cls, "")
obj.template = fstr
obj.name = name
obj.args = template_args
obj._args = args or []
obj._kwargs = kwargs
# 修改自动注册逻辑
if should_register and not global_prompt_manager._context._current_context:
global_prompt_manager.register(obj)
return obj
@classmethod
async def create_async(
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
):
"""异步创建Prompt实例"""
prompt = cls(fstr, name, args, **kwargs)
if global_prompt_manager._context._current_context:
await global_prompt_manager._context.register_async(prompt)
return prompt
@classmethod
def _format_template(
cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None
) -> str:
if kwargs is None:
kwargs = {}
# 预处理模板中的转义花括号
processed_template = cls._process_escaped_braces(template)
template_args = []
result = re.findall(r"\{(.*?)}", processed_template)
for expr in result:
if expr and expr not in template_args:
template_args.append(expr)
formatted_args = {}
formatted_kwargs = {}
# 处理位置参数
if args:
# print(len(template_args), len(args), template_args, args)
for i in range(len(args)):
if i < len(template_args):
arg = args[i]
if isinstance(arg, Prompt):
formatted_args[template_args[i]] = arg.format(**kwargs)
else:
formatted_args[template_args[i]] = arg
else:
logger.error(
f"构建提示词模板失败,解析到的参数列表{template_args},长度为{len(template_args)},输入的参数列表为{args},提示词模板为{template}"
)
raise ValueError("格式化模板失败")
# 处理关键字参数
if kwargs:
for key, value in kwargs.items():
if isinstance(value, Prompt):
remaining_kwargs = {k: v for k, v in kwargs.items() if k != key}
formatted_kwargs[key] = value.format(**remaining_kwargs)
else:
formatted_kwargs[key] = value
try:
# 先用位置参数格式化
if args:
processed_template = processed_template.format(**formatted_args)
# 再用关键字参数格式化
if kwargs:
processed_template = processed_template.format(**formatted_kwargs)
# 将临时标记还原为实际的花括号
result = cls._restore_escaped_braces(processed_template)
return result
except (IndexError, KeyError) as e:
raise ValueError(
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
) from e
def format(self, *args, **kwargs) -> "str":
"""支持位置参数和关键字参数的格式化,使用"""
ret = type(self)(
self.template,
self.name,
args=list(args) if args else self._args,
_should_register=False,
**kwargs or self._kwargs,
)
# print(f"prompt build result: {ret} name: {ret.name} ")
return str(ret)
def __str__(self) -> str:
return super().__str__() if self._kwargs or self._args else self.template
def __repr__(self) -> str:
return f"Prompt(template='{self.template}', name='{self.name}')"

View File

@@ -1,156 +0,0 @@
"""
智能提示词参数模块 - 优化参数结构
简化SmartPromptParameters减少冗余和重复
"""
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Literal
@dataclass
class SmartPromptParameters:
"""简化的智能提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
sender: str = ""
target: str = ""
reply_to: str = ""
extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
enable_expression: bool = True
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: Optional[Dict[str, Any]] = None
# 已构建的内容块
expression_habits_block: str = ""
relation_info_block: str = ""
memory_block: str = ""
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
# 其他内容块
keywords_reaction_prompt: str = ""
extra_info_block: str = ""
time_block: str = ""
identity_block: str = ""
schedule_block: str = ""
moderation_prompt_block: str = ""
reply_target_block: str = ""
mood_prompt: str = ""
action_descriptions: str = ""
# 可用动作信息
available_actions: Optional[Dict[str, Any]] = None
def validate(self) -> List[str]:
"""统一的参数验证"""
errors = []
if not self.chat_id:
errors.append("chat_id不能为空")
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
errors.append("prompt_mode必须是's4u''normal''minimal'")
if self.max_context_messages <= 0:
errors.append("max_context_messages必须大于0")
return errors
def get_needed_build_tasks(self) -> List[str]:
"""获取需要执行的任务列表"""
tasks = []
if self.enable_expression and not self.expression_habits_block:
tasks.append("expression_habits")
if self.enable_memory and not self.memory_block:
tasks.append("memory_block")
if self.enable_relation and not self.relation_info_block:
tasks.append("relation_info")
if self.enable_tool and not self.tool_info_block:
tasks.append("tool_info")
if self.enable_knowledge and not self.knowledge_prompt:
tasks.append("knowledge_info")
if self.enable_cross_context and not self.cross_context_block:
tasks.append("cross_context")
return tasks
@classmethod
def from_legacy_params(cls, **kwargs) -> "SmartPromptParameters":
"""
从旧版参数创建新参数对象
Args:
**kwargs: 旧版参数
Returns:
SmartPromptParameters: 新参数对象
"""
return cls(
# 基础参数
chat_id=kwargs.get("chat_id", ""),
is_group_chat=kwargs.get("is_group_chat", False),
sender=kwargs.get("sender", ""),
target=kwargs.get("target", ""),
reply_to=kwargs.get("reply_to", ""),
extra_info=kwargs.get("extra_info", ""),
prompt_mode=kwargs.get("current_prompt_mode", "s4u"),
# 功能开关
enable_tool=kwargs.get("enable_tool", True),
enable_memory=kwargs.get("enable_memory", True),
enable_expression=kwargs.get("enable_expression", True),
enable_relation=kwargs.get("enable_relation", True),
enable_cross_context=kwargs.get("enable_cross_context", True),
enable_knowledge=kwargs.get("enable_knowledge", True),
# 性能控制
max_context_messages=kwargs.get("max_context_messages", 50),
debug_mode=kwargs.get("debug_mode", False),
# 聊天历史和上下文
chat_target_info=kwargs.get("chat_target_info"),
message_list_before_now_long=kwargs.get("message_list_before_now_long", []),
message_list_before_short=kwargs.get("message_list_before_short", []),
chat_talking_prompt_short=kwargs.get("chat_talking_prompt_short", ""),
target_user_info=kwargs.get("target_user_info"),
# 已构建的内容块
expression_habits_block=kwargs.get("expression_habits_block", ""),
relation_info_block=kwargs.get("relation_info", ""),
memory_block=kwargs.get("memory_block", ""),
tool_info_block=kwargs.get("tool_info", ""),
knowledge_prompt=kwargs.get("knowledge_prompt", ""),
cross_context_block=kwargs.get("cross_context_block", ""),
# 其他内容块
keywords_reaction_prompt=kwargs.get("keywords_reaction_prompt", ""),
extra_info_block=kwargs.get("extra_info_block", ""),
time_block=kwargs.get("time_block", ""),
identity_block=kwargs.get("identity_block", ""),
schedule_block=kwargs.get("schedule_block", ""),
moderation_prompt_block=kwargs.get("moderation_prompt_block", ""),
reply_target_block=kwargs.get("reply_target_block", ""),
mood_prompt=kwargs.get("mood_prompt", ""),
action_descriptions=kwargs.get("action_descriptions", ""),
# 可用动作信息
available_actions=kwargs.get("available_actions", None),
)

View File

@@ -1,132 +0,0 @@
"""
共享提示词工具模块 - 消除重复代码
提供统一的工具函数供DefaultReplyer和SmartPrompt使用
"""
import re
from typing import Dict, Any, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.apis import cross_context_api
logger = get_logger("prompt_utils")
class PromptUtils:
"""提示词工具类 - 提供共享功能,移除缓存相关功能和依赖检查"""
@staticmethod
def parse_reply_target(target_message: str) -> Tuple[str, str]:
"""
解析回复目标消息 - 统一实现
Args:
target_message: 目标消息,格式为 "发送者:消息内容""发送者:消息内容"
Returns:
Tuple[str, str]: (发送者名称, 消息内容)
"""
sender = ""
target = ""
# 添加None检查防止NoneType错误
if target_message is None:
return sender, target
if ":" in target_message or "" in target_message:
# 使用正则表达式匹配中文或英文冒号
parts = re.split(pattern=r"[:]", string=target_message, maxsplit=1)
if len(parts) == 2:
sender = parts[0].strip()
target = parts[1].strip()
return sender, target
@staticmethod
async def build_relation_info(chat_id: str, reply_to: str) -> str:
"""
构建关系信息 - 统一实现
Args:
chat_id: 聊天ID
reply_to: 回复目标字符串
Returns:
str: 关系信息字符串
"""
if not global_config.relationship.enable_relationship:
return ""
from src.person_info.relationship_fetcher import relationship_fetcher_manager
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
if not reply_to:
return ""
sender, text = PromptUtils.parse_reply_target(reply_to)
if not sender or not text:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if not person_id:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
@staticmethod
async def build_cross_context(
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
) -> str:
"""
构建跨群聊上下文 - 统一实现完全继承DefaultReplyer功能
"""
if not global_config.cross_context.enable:
return ""
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
if not other_chat_raw_ids:
return ""
chat_stream = get_chat_manager().get_stream(chat_id)
if not chat_stream:
return ""
if current_prompt_mode == "normal":
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
elif current_prompt_mode == "s4u":
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
return ""
@staticmethod
def parse_reply_target_id(reply_to: str) -> str:
"""
解析回复目标中的用户ID
Args:
reply_to: 回复目标字符串
Returns:
str: 用户ID
"""
if not reply_to:
return ""
# 复用parse_reply_target方法的逻辑
sender, _ = PromptUtils.parse_reply_target(reply_to)
if not sender:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if person_id:
user_id = person_info_manager.get_value_sync(person_id, "user_id")
return str(user_id) if user_id else ""
return ""

View File

@@ -1,938 +0,0 @@
"""
智能Prompt系统 - 完全重构版本
基于原有DefaultReplyer的完整功能集成使用新的参数结构
解决实现质量不高、功能集成不完整和错误处理不足的问题
"""
import asyncio
import time
from datetime import datetime
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Tuple
from src.chat.utils.prompt_builder import global_prompt_manager, Prompt
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import (
build_readable_messages,
)
from src.person_info.person_info import get_person_info_manager
from src.chat.utils.prompt_utils import PromptUtils
from src.chat.utils.prompt_parameters import SmartPromptParameters
logger = get_logger("smart_prompt")
@dataclass
class ChatContext:
"""聊天上下文信息"""
chat_id: str = ""
platform: str = ""
is_group: bool = False
user_id: str = ""
user_nickname: str = ""
group_id: Optional[str] = None
timestamp: datetime = field(default_factory=datetime.now)
class SmartPromptBuilder:
"""重构的智能提示词构建器 - 统一错误处理和功能集成,移除缓存机制和依赖检查"""
def __init__(self):
# 移除缓存相关初始化
pass
async def build_context_data(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""并行构建完整的上下文数据 - 移除缓存机制和依赖检查"""
# 并行执行所有构建任务
start_time = time.time()
timing_logs = {}
try:
# 准备构建任务
tasks = []
task_names = []
# 初始化预构建参数,使用新的结构
pre_built_params = {}
if params.expression_habits_block:
pre_built_params["expression_habits_block"] = params.expression_habits_block
if params.relation_info_block:
pre_built_params["relation_info_block"] = params.relation_info_block
if params.memory_block:
pre_built_params["memory_block"] = params.memory_block
if params.tool_info_block:
pre_built_params["tool_info_block"] = params.tool_info_block
if params.knowledge_prompt:
pre_built_params["knowledge_prompt"] = params.knowledge_prompt
if params.cross_context_block:
pre_built_params["cross_context_block"] = params.cross_context_block
# 根据新的参数结构确定要构建的项
if params.enable_expression and not pre_built_params.get("expression_habits_block"):
tasks.append(self._build_expression_habits(params))
task_names.append("expression_habits")
if params.enable_memory and not pre_built_params.get("memory_block"):
tasks.append(self._build_memory_block(params))
task_names.append("memory_block")
if params.enable_relation and not pre_built_params.get("relation_info_block"):
tasks.append(self._build_relation_info(params))
task_names.append("relation_info")
# 添加mai_think上下文构建任务
if not pre_built_params.get("mai_think"):
tasks.append(self._build_mai_think_context(params))
task_names.append("mai_think_context")
if params.enable_tool and not pre_built_params.get("tool_info_block"):
tasks.append(self._build_tool_info(params))
task_names.append("tool_info")
if params.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
tasks.append(self._build_knowledge_info(params))
task_names.append("knowledge_info")
if params.enable_cross_context and not pre_built_params.get("cross_context_block"):
tasks.append(self._build_cross_context(params))
task_names.append("cross_context")
# 性能优化:根据任务数量动态调整超时时间
base_timeout = 10.0 # 基础超时时间
task_timeout = 2.0 # 每个任务的超时时间
timeout_seconds = min(
max(base_timeout, len(tasks) * task_timeout), # 根据任务数量计算超时
30.0, # 最大超时时间
)
# 性能优化:限制并发任务数量,避免资源耗尽
max_concurrent_tasks = 5 # 最大并发任务数
if len(tasks) > max_concurrent_tasks:
# 分批执行任务
results = []
for i in range(0, len(tasks), max_concurrent_tasks):
batch_tasks = tasks[i : i + max_concurrent_tasks]
batch_names = task_names[i : i + max_concurrent_tasks]
batch_results = await asyncio.wait_for(
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
)
results.extend(batch_results)
else:
# 一次性执行所有任务
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
)
# 处理结果并收集性能数据
context_data = {}
for i, result in enumerate(results):
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
if isinstance(result, Exception):
logger.error(f"构建任务{task_name}失败: {str(result)}")
elif isinstance(result, dict):
# 结果格式: {component_name: value}
context_data.update(result)
# 记录耗时过长的任务
if task_name in timing_logs and timing_logs[task_name] > 8.0:
logger.warning(f"构建任务{task_name}耗时过长: {timing_logs[task_name]:.2f}s")
# 添加预构建的参数
for key, value in pre_built_params.items():
if value:
context_data[key] = value
except asyncio.TimeoutError:
logger.error(f"构建超时 ({timeout_seconds}s)")
context_data = {}
# 添加预构建的参数,即使在超时情况下
for key, value in pre_built_params.items():
if value:
context_data[key] = value
# 构建聊天历史 - 根据模式不同
if params.prompt_mode == "s4u":
await self._build_s4u_chat_context(context_data, params)
else:
await self._build_normal_chat_context(context_data, params)
# 补充基础信息
context_data.update(
{
"keywords_reaction_prompt": params.keywords_reaction_prompt,
"extra_info_block": params.extra_info_block,
"time_block": params.time_block or f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
"identity": params.identity_block,
"schedule_block": params.schedule_block,
"moderation_prompt": params.moderation_prompt_block,
"reply_target_block": params.reply_target_block,
"mood_state": params.mood_prompt,
"action_descriptions": params.action_descriptions,
}
)
total_time = time.time() - start_time
if timing_logs:
timing_str = "; ".join([f"{name}: {time:.2f}s" for name, time in timing_logs.items()])
logger.info(f"构建任务耗时: {timing_str}")
logger.debug(f"构建完成,总耗时: {total_time:.2f}s")
return context_data
async def _build_s4u_chat_context(self, context_data: Dict[str, Any], params: SmartPromptParameters) -> None:
"""构建S4U模式的聊天上下文 - 使用新参数结构"""
if not params.message_list_before_now_long:
return
# 使用共享工具构建分离历史
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
params.message_list_before_now_long,
params.target_user_info.get("user_id") if params.target_user_info else "",
params.sender,
)
context_data["core_dialogue_prompt"] = core_dialogue
context_data["background_dialogue_prompt"] = background_dialogue
async def _build_normal_chat_context(self, context_data: Dict[str, Any], params: SmartPromptParameters) -> None:
"""构建normal模式的聊天上下文 - 使用新参数结构"""
if not params.chat_talking_prompt_short:
return
context_data["chat_info"] = f"""群里的聊天内容:
{params.chat_talking_prompt_short}"""
async def _build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
) -> Tuple[str, str]:
"""构建S4U风格的分离对话prompt - 完整实现"""
core_dialogue_list = []
bot_id = str(global_config.bot.qq_account)
# 过滤消息分离bot和目标用户的对话 vs 其他用户的对话
for msg_dict in message_list_before_now:
try:
msg_user_id = str(msg_dict.get("user_id"))
reply_to = msg_dict.get("reply_to", "")
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
# bot 和目标用户的对话
core_dialogue_list.append(msg_dict)
except Exception as e:
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
# 构建背景对话 prompt
all_dialogue_prompt = ""
if message_list_before_now:
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
all_dialogue_prompt_str = build_readable_messages(
latest_25_msgs,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True,
)
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
# 构建核心对话 prompt
core_dialogue_prompt = ""
if core_dialogue_list:
# 检查最新五条消息中是否包含bot自己说的消息
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
# logger.info(f"最新五条消息:{latest_5_messages}")
# logger.info(f"最新五条消息中是否包含bot自己说的消息{has_bot_message}")
# 如果最新五条消息中不包含bot的消息则返回空字符串
if not has_bot_message:
core_dialogue_prompt = ""
else:
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
core_dialogue_prompt_str = build_readable_messages(
core_dialogue_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
core_dialogue_prompt = f"""--------------------------------
这是你和{sender}的对话,你们正在交流中:
{core_dialogue_prompt_str}
--------------------------------
"""
return core_dialogue_prompt, all_dialogue_prompt
async def _build_mai_think_context(self, params: SmartPromptParameters) -> Any:
"""构建mai_think上下文 - 完全继承DefaultReplyer功能"""
from src.mais4u.mai_think import mai_thinking_manager
# 获取mai_think实例
mai_think = mai_thinking_manager.get_mai_think(params.chat_id)
# 设置mai_think的上下文信息
mai_think.memory_block = params.memory_block or ""
mai_think.relation_info_block = params.relation_info_block or ""
mai_think.time_block = params.time_block or f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# 设置聊天目标信息
if params.is_group_chat:
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
else:
chat_target_name = "对方"
if params.chat_target_info:
chat_target_name = (
params.chat_target_info.get("person_name") or params.chat_target_info.get("user_nickname") or "对方"
)
chat_target_1 = await global_prompt_manager.format_prompt(
"chat_target_private1", sender_name=chat_target_name
)
chat_target_2 = await global_prompt_manager.format_prompt(
"chat_target_private2", sender_name=chat_target_name
)
mai_think.chat_target = chat_target_1
mai_think.chat_target_2 = chat_target_2
mai_think.chat_info = params.chat_talking_prompt_short or ""
mai_think.mood_state = params.mood_prompt or ""
mai_think.identity = params.identity_block or ""
mai_think.sender = params.sender
mai_think.target = params.target
# 返回mai_think实例以便后续使用
return mai_think
def _parse_reply_target_id(self, reply_to: str) -> str:
"""解析回复目标中的用户ID"""
if not reply_to:
return ""
# 复用_parse_reply_target方法的逻辑
sender, _ = self._parse_reply_target(reply_to)
if not sender:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if person_id:
user_id = person_info_manager.get_value_sync(person_id, "user_id")
return str(user_id) if user_id else ""
async def _build_expression_habits(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建表达习惯 - 使用共享工具类完全继承DefaultReplyer功能"""
# 检查是否允许在此聊天流中使用表达
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(params.chat_id)
if not use_expression:
return {"expression_habits_block": ""}
from src.chat.express.expression_selector import expression_selector
style_habits = []
grammar_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
try:
selected_expressions = await expression_selector.select_suitable_expressions_llm(
params.chat_id, params.chat_talking_prompt_short, max_num=8, min_num=2, target_message=params.target
)
except Exception as e:
logger.error(f"选择表达方式失败: {e}")
selected_expressions = []
if selected_expressions:
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
for expr in selected_expressions:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
expr_type = expr.get("type", "style")
if expr_type == "grammar":
grammar_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
style_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
# 不再在replyer中进行随机选择全部交给处理器处理
style_habits_str = "\n".join(style_habits)
grammar_habits_str = "\n".join(grammar_habits)
# 动态构建expression habits块
expression_habits_block = ""
expression_habits_title = ""
if style_habits_str.strip():
expression_habits_title = (
"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:"
)
expression_habits_block += f"{style_habits_str}\n"
if grammar_habits_str.strip():
expression_habits_title = (
"你可以选择下面的句法进行回复,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式使用:"
)
expression_habits_block += f"{grammar_habits_str}\n"
if style_habits_str.strip() and grammar_habits_str.strip():
expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中。"
return {"expression_habits_block": f"{expression_habits_title}\n{expression_habits_block}"}
async def _build_memory_block(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建记忆块 - 使用共享工具类完全继承DefaultReplyer功能"""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
from src.chat.memory_system.memory_activator import MemoryActivator
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
instant_memory = None
# 初始化记忆激活器
try:
memory_activator = MemoryActivator()
# 获取长期记忆
running_memories = await memory_activator.activate_memory_with_chat_history(
target_message=params.target, chat_history_prompt=params.chat_talking_prompt_short
)
except Exception as e:
logger.error(f"激活记忆失败: {e}")
running_memories = []
# 处理瞬时记忆
if global_config.memory.enable_instant_memory:
# 使用异步记忆包装器(最优化的非阻塞模式)
try:
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
# 获取异步记忆包装器
async_memory = get_async_instant_memory(params.chat_id)
# 后台存储聊天历史(完全非阻塞)
async_memory.store_memory_background(params.chat_talking_prompt_short)
# 快速检索记忆最大超时2秒
instant_memory = await async_memory.get_memory_with_fallback(params.target, max_timeout=2.0)
logger.info(f"异步瞬时记忆:{instant_memory}")
except ImportError:
# 如果异步包装器不可用,尝试使用异步记忆管理器
try:
from src.chat.memory_system.async_memory_optimizer import (
retrieve_memory_nonblocking,
store_memory_nonblocking,
)
# 异步存储聊天历史(非阻塞)
asyncio.create_task(
store_memory_nonblocking(chat_id=params.chat_id, content=params.chat_talking_prompt_short)
)
# 尝试从缓存获取瞬时记忆
instant_memory = await retrieve_memory_nonblocking(chat_id=params.chat_id, query=params.target)
# 如果没有缓存结果,快速检索一次
if instant_memory is None:
try:
# 使用VectorInstantMemoryV2实例
instant_memory_system = VectorInstantMemoryV2(chat_id=params.chat_id, retention_hours=1)
instant_memory = await asyncio.wait_for(
instant_memory_system.get_memory_for_context(params.target), timeout=1.5
)
except asyncio.TimeoutError:
logger.warning("瞬时记忆检索超时,使用空结果")
instant_memory = ""
logger.info(f"向量瞬时记忆:{instant_memory}")
except ImportError:
# 最后的fallback使用原有逻辑但加上超时控制
logger.warning("异步记忆系统不可用,使用带超时的同步方式")
# 使用VectorInstantMemoryV2实例
instant_memory_system = VectorInstantMemoryV2(chat_id=params.chat_id, retention_hours=1)
# 异步存储聊天历史
asyncio.create_task(instant_memory_system.store_message(params.chat_talking_prompt_short))
# 带超时的记忆检索
try:
instant_memory = await asyncio.wait_for(
instant_memory_system.get_memory_for_context(params.target),
timeout=1.0, # 最保守的1秒超时
)
except asyncio.TimeoutError:
logger.warning("瞬时记忆检索超时,跳过记忆获取")
instant_memory = ""
except Exception as e:
logger.error(f"瞬时记忆检索失败: {e}")
instant_memory = ""
logger.info(f"同步瞬时记忆:{instant_memory}")
except Exception as e:
logger.error(f"瞬时记忆系统异常: {e}")
instant_memory = ""
# 构建记忆字符串,即使某种记忆为空也要继续
memory_str = ""
has_any_memory = False
# 添加长期记忆
if running_memories:
if not memory_str:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memories:
memory_str += f"- {running_memory['content']}\n"
has_any_memory = True
# 添加瞬时记忆
if instant_memory:
if not memory_str:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
memory_str += f"- {instant_memory}\n"
has_any_memory = True
# 注入视频分析结果引导语
memory_str = self._inject_video_prompt_if_needed(params.target, memory_str)
# 只有当完全没有任何记忆时才返回空字符串
return {"memory_block": memory_str if has_any_memory else ""}
def _inject_video_prompt_if_needed(self, target: str, memory_str: str) -> str:
"""统一视频分析结果注入逻辑"""
if target and ("[视频内容]" in target or "好的,我将根据您提供的" in target):
video_prompt_injection = (
"\n请注意,以上内容是你刚刚观看的视频,请以第一人称分享你的观后感,而不是在分析一份报告。"
)
return memory_str + video_prompt_injection
return memory_str
async def _build_relation_info(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建关系信息 - 使用共享工具类"""
try:
relation_info = await PromptUtils.build_relation_info(params.chat_id, params.reply_to)
return {"relation_info_block": relation_info}
except Exception as e:
logger.error(f"构建关系信息失败: {e}")
return {"relation_info_block": ""}
async def _build_tool_info(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建工具信息 - 使用共享工具类完全继承DefaultReplyer功能"""
if not params.enable_tool:
return {"tool_info_block": ""}
if not params.reply_to:
return {"tool_info_block": ""}
sender, text = PromptUtils.parse_reply_target(params.reply_to)
if not text:
return {"tool_info_block": ""}
from src.plugin_system.core.tool_use import ToolExecutor
# 使用工具执行器获取信息
try:
tool_executor = ToolExecutor(chat_id=params.chat_id)
tool_results, _, _ = await tool_executor.execute_from_chat_message(
sender=sender, target_message=text, chat_history=params.chat_talking_prompt_short, return_details=False
)
if tool_results:
tool_info_str = "以下是你通过工具获取到的实时信息:\n"
for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}{result_type}: {content}\n"
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
logger.info(f"获取到 {len(tool_results)} 个工具结果")
return {"tool_info_block": tool_info_str}
else:
logger.debug("未获取到任何工具结果")
return {"tool_info_block": ""}
except Exception as e:
logger.error(f"工具信息获取失败: {e}")
return {"tool_info_block": ""}
async def _build_knowledge_info(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建知识信息 - 使用共享工具类完全继承DefaultReplyer功能"""
if not params.reply_to:
logger.debug("没有回复对象,跳过获取知识库内容")
return {"knowledge_prompt": ""}
sender, content = PromptUtils.parse_reply_target(params.reply_to)
if not content:
logger.debug("回复对象内容为空,跳过获取知识库内容")
return {"knowledge_prompt": ""}
logger.debug(
f"获取知识库内容,元消息:{params.chat_talking_prompt_short[:30]}...,消息长度: {len(params.chat_talking_prompt_short)}"
)
# 从LPMM知识库获取知识
try:
# 检查LPMM知识库是否启用
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用跳过获取知识库内容")
return {"knowledge_prompt": ""}
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
from src.plugin_system.apis import llm_api
from src.config.config import model_config
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname
prompt = await global_prompt_manager.format_prompt(
"lpmm_get_knowledge_prompt",
bot_name=bot_name,
time_now=time_now,
chat_history=params.chat_talking_prompt_short,
sender=sender,
target_message=content,
)
_, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
)
if tool_calls:
from src.plugin_system.core.tool_use import ToolExecutor
tool_executor = ToolExecutor(chat_id=params.chat_id)
result = await tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
if not result or not result.get("content"):
logger.debug("从LPMM知识库获取知识失败返回空知识...")
return {"knowledge_prompt": ""}
found_knowledge_from_lpmm = result.get("content", "")
logger.debug(
f"从LPMM知识库获取知识相关信息{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
)
return {
"knowledge_prompt": f"你有以下这些**知识**\n{found_knowledge_from_lpmm}\n请你**记住上面的知识**,之后可能会用到。\n"
}
else:
logger.debug("从LPMM知识库获取知识失败可能是从未导入过知识返回空知识...")
return {"knowledge_prompt": ""}
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return {"knowledge_prompt": ""}
async def _build_cross_context(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建跨群上下文 - 使用共享工具类"""
try:
cross_context = await PromptUtils.build_cross_context(
params.chat_id, params.prompt_mode, params.target_user_info
)
return {"cross_context_block": cross_context}
except Exception as e:
logger.error(f"构建跨群上下文失败: {e}")
return {"cross_context_block": ""}
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息 - 使用共享工具类"""
return PromptUtils.parse_reply_target(target_message)
class SmartPrompt:
"""重构的智能提示词核心类 - 移除缓存机制和依赖检查,简化架构"""
def __init__(
self,
template_name: Optional[str] = None,
parameters: Optional[SmartPromptParameters] = None,
):
self.parameters = parameters or SmartPromptParameters()
self.template_name = template_name or self._get_default_template()
self.builder = SmartPromptBuilder()
def _get_default_template(self) -> str:
"""根据模式选择默认模板"""
if self.parameters.prompt_mode == "s4u":
return "s4u_style_prompt"
elif self.parameters.prompt_mode == "normal":
return "normal_style_prompt"
else:
return "default_expressor_prompt"
async def build_prompt(self) -> str:
"""构建最终的Prompt文本 - 移除缓存机制和依赖检查"""
# 参数验证
errors = self.parameters.validate()
if errors:
logger.error(f"参数验证失败: {', '.join(errors)}")
raise ValueError(f"参数验证失败: {', '.join(errors)}")
start_time = time.time()
try:
# 构建基础上下文的完整映射
context_data = await self.builder.build_context_data(self.parameters)
# 检查关键上下文数据
if not context_data or not isinstance(context_data, dict):
logger.error("构建的上下文数据无效")
raise ValueError("构建的上下文数据无效")
# 获取模板
template = await self._get_template()
if template is None:
logger.error("无法获取模板")
raise ValueError("无法获取模板")
# 根据模式传递不同的参数
if self.parameters.prompt_mode == "s4u":
result = await self._build_s4u_prompt(template, context_data)
elif self.parameters.prompt_mode == "normal":
result = await self._build_normal_prompt(template, context_data)
else:
result = await self._build_default_prompt(template, context_data)
# 记录性能数据
total_time = time.time() - start_time
logger.debug(f"SmartPrompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
return result
except asyncio.TimeoutError as e:
logger.error(f"构建Prompt超时: {e}")
raise TimeoutError(f"构建Prompt超时: {e}")
except Exception as e:
logger.error(f"构建Prompt失败: {e}")
raise RuntimeError(f"构建Prompt失败: {e}")
async def _get_template(self) -> Optional[Prompt]:
"""获取模板"""
try:
return await global_prompt_manager.get_prompt_async(self.template_name)
except Exception as e:
logger.error(f"获取模板 {self.template_name} 失败: {e}")
raise RuntimeError(f"获取模板 {self.template_name} 失败: {e}")
async def _build_s4u_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str:
"""构建S4U模式的完整Prompt - 使用新参数结构"""
params = {
**context_data,
"expression_habits_block": context_data.get("expression_habits_block", ""),
"tool_info_block": context_data.get("tool_info_block", ""),
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"sender_name": self.parameters.sender,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""),
"time_block": context_data.get("time_block", ""),
"core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
return await global_prompt_manager.format_prompt(self.template_name, **params)
async def _build_normal_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str:
"""构建Normal模式的完整Prompt - 使用新参数结构"""
params = {
**context_data,
"expression_habits_block": context_data.get("expression_habits_block", ""),
"tool_info_block": context_data.get("tool_info_block", ""),
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""),
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"config_expression_style": global_config.personality.reply_style,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
return await global_prompt_manager.format_prompt(self.template_name, **params)
async def _build_default_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str:
"""构建默认模式的Prompt - 使用新参数结构"""
params = {
"expression_habits_block": context_data.get("expression_habits_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"chat_target": "",
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"chat_target_2": "",
"reply_target_block": context_data.get("reply_target_block", ""),
"raw_reply": self.parameters.target,
"reason": "",
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
return await global_prompt_manager.format_prompt(self.template_name, **params)
# 工厂函数 - 简化创建 - 更新参数结构
def create_smart_prompt(
chat_id: str = "", sender_name: str = "", target_message: str = "", reply_to: str = "", **kwargs
) -> SmartPrompt:
"""快速创建智能Prompt实例的工厂函数 - 使用新参数结构"""
# 使用新的参数结构
parameters = SmartPromptParameters(
chat_id=chat_id, sender=sender_name, target=target_message, reply_to=reply_to, **kwargs
)
return SmartPrompt(parameters=parameters)
class SmartPromptHealthChecker:
"""SmartPrompt健康检查器 - 移除依赖检查"""
@staticmethod
async def check_system_health() -> Dict[str, Any]:
"""检查系统健康状态 - 移除依赖检查"""
health_status = {"status": "healthy", "components": {}, "issues": []}
try:
# 检查配置
try:
from src.config.config import global_config
health_status["components"]["config"] = "ok"
# 检查关键配置项
if not hasattr(global_config, "personality") or not hasattr(global_config.personality, "prompt_mode"):
health_status["issues"].append("缺少personality.prompt_mode配置")
health_status["status"] = "degraded"
if not hasattr(global_config, "memory") or not hasattr(global_config.memory, "enable_memory"):
health_status["issues"].append("缺少memory.enable_memory配置")
except Exception as e:
health_status["components"]["config"] = f"failed: {str(e)}"
health_status["issues"].append("配置加载失败")
health_status["status"] = "unhealthy"
# 检查Prompt模板
try:
required_templates = ["s4u_style_prompt", "normal_style_prompt", "default_expressor_prompt"]
for template_name in required_templates:
try:
await global_prompt_manager.get_prompt_async(template_name)
health_status["components"][f"template_{template_name}"] = "ok"
except Exception as e:
health_status["components"][f"template_{template_name}"] = f"failed: {str(e)}"
health_status["issues"].append(f"模板{template_name}加载失败")
health_status["status"] = "degraded"
except Exception as e:
health_status["components"]["prompt_templates"] = f"failed: {str(e)}"
health_status["issues"].append("Prompt模板检查失败")
health_status["status"] = "unhealthy"
return health_status
except Exception as e:
return {"status": "unhealthy", "components": {}, "issues": [f"健康检查异常: {str(e)}"]}
@staticmethod
async def run_performance_test() -> Dict[str, Any]:
"""运行性能测试"""
test_results = {"status": "completed", "tests": {}, "summary": {}}
try:
# 创建测试参数
test_params = SmartPromptParameters(
chat_id="test_chat",
sender="test_user",
target="test_message",
reply_to="test_user:test_message",
prompt_mode="s4u",
)
# 测试不同模式下的构建性能
modes = ["s4u", "normal", "minimal"]
for mode in modes:
test_params.prompt_mode = mode
smart_prompt = SmartPrompt(parameters=test_params)
# 运行多次测试取平均值
times = []
for _ in range(3):
start_time = time.time()
try:
await smart_prompt.build_prompt()
end_time = time.time()
times.append(end_time - start_time)
except Exception as e:
times.append(float("inf"))
logger.error(f"性能测试失败 (模式: {mode}): {e}")
# 计算统计信息
valid_times = [t for t in times if t != float("inf")]
if valid_times:
avg_time = sum(valid_times) / len(valid_times)
min_time = min(valid_times)
max_time = max(valid_times)
test_results["tests"][mode] = {
"avg_time": avg_time,
"min_time": min_time,
"max_time": max_time,
"success_rate": len(valid_times) / len(times),
}
else:
test_results["tests"][mode] = {
"avg_time": float("inf"),
"min_time": float("inf"),
"max_time": float("inf"),
"success_rate": 0,
}
# 计算总体统计
all_avg_times = [
test["avg_time"] for test in test_results["tests"].values() if test["avg_time"] != float("inf")
]
if all_avg_times:
test_results["summary"] = {
"overall_avg_time": sum(all_avg_times) / len(all_avg_times),
"fastest_mode": min(test_results["tests"].items(), key=lambda x: x[1]["avg_time"])[0],
"slowest_mode": max(test_results["tests"].items(), key=lambda x: x[1]["avg_time"])[0],
}
return test_results
except Exception as e:
return {"status": "failed", "tests": {}, "summary": {}, "error": str(e)}

View File

@@ -1,6 +1,6 @@
from src.chat.message_receive.chat_stream import get_chat_manager
import time
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.chat.message_receive.message import MessageRecvS4U

View File

@@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api

View File

@@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
from src.mais4u.constant_s4u import ENABLE_S4U

View File

@@ -1,6 +1,6 @@
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
import time
from src.chat.utils.utils import get_recent_group_speaker

View File

@@ -6,7 +6,7 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.llm_models.utils_model import LLMRequest
from src.manager.async_task_manager import AsyncTask, async_task_manager

View File

@@ -9,7 +9,7 @@ from json_repair import repair_json
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager

View File

@@ -23,17 +23,20 @@ class BaseEventHandler(ABC):
"""是否拦截消息,默认为否"""
init_subscribe: List[Union[EventType, str]] = [EventType.UNKNOWN]
"""初始化时订阅的事件名称"""
plugin_name = None
def __init__(self):
self.log_prefix = "[EventHandler]"
"""对应插件名"""
self.plugin_config: Optional[Dict] = None
"""插件配置字典"""
self.subscribed_events = []
"""订阅的事件列表"""
if EventType.UNKNOWN in self.init_subscribe:
raise NotImplementedError("事件处理器必须指定 event_type")
from src.plugin_system.core.component_registry import component_registry
self.plugin_config = component_registry.get_plugin_config(self.plugin_name)
@abstractmethod
async def execute(self, kwargs: dict | None) -> Tuple[bool, bool, Optional[str]]:
"""执行事件处理的抽象方法,子类必须实现
@@ -89,15 +92,7 @@ class BaseEventHandler(ABC):
weight=cls.weight,
intercept_message=cls.intercept_message,
)
def set_plugin_config(self, plugin_config: Dict) -> None:
"""设置插件配置
Args:
plugin_config (dict): 插件配置字典
"""
self.plugin_config = plugin_config
def set_plugin_name(self, plugin_name: str) -> None:
"""设置插件名称

View File

@@ -248,6 +248,7 @@ class ComponentRegistry:
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
return False
handler_class.plugin_name = handler_info.plugin_name
self._event_handler_registry[handler_name] = handler_class
if not handler_info.enabled:

View File

@@ -145,11 +145,12 @@ class EventManager:
logger.info(f"事件 {event_name} 已禁用")
return True
def register_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool:
"""注册事件处理器
Args:
handler_class (Type[BaseEventHandler]): 事件处理器类
plugin_config (Optional[dict]): 插件配置字典默认为None
Returns:
bool: 注册成功返回True已存在返回False
@@ -163,7 +164,12 @@ class EventManager:
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
return False
self._event_handlers[handler_name] = handler_class()
# 创建事件处理器实例,传递插件配置
handler_instance = handler_class()
if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'):
handler_instance.set_plugin_config(plugin_config)
self._event_handlers[handler_name] = handler_instance
# 处理init_subscribe缓存失败的订阅
if self._event_handlers[handler_name].init_subscribe:

View File

@@ -6,7 +6,7 @@ from src.plugin_system.core.global_announcement_manager import global_announceme
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
import inspect
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger

View File

@@ -39,8 +39,9 @@ class EmojiAction(BaseAction):
llm_judge_prompt = """
判定是否需要使用表情动作的条件:
1. 用户明确要求使用表情包
2. 这是一个适合表达强烈情绪的场合
3. 不要发送太多表情包,如果你已经发送过多个表情包则回答""
2. 这是一个适合表达情绪的场合
3. 发表情包能使当前对话更有趣
4. 不要发送太多表情包,如果你已经发送过多个表情包则回答""
请回答""""
"""

View File

@@ -1746,3 +1746,32 @@ class SetGroupSignHandler(BaseEventHandler):
else:
logger.error("事件 napcat_set_group_sign 请求失败!")
return HandlerResult(False, False, {"status": "error"})
# ===PERSONAL===
class SetInputStatusHandler(BaseEventHandler):
handler_name: str = "napcat_set_input_status_handler"
handler_description: str = "设置输入状态"
weight: int = 100
intercept_message: bool = False
init_subscribe = [NapcatEvent.PERSONAL.SET_INPUT_STATUS]
async def execute(self, params: dict):
raw = params.get("raw", {})
user_id = params.get("user_id", "")
event_type = params.get("event_type", 0)
if params.get("raw", ""):
user_id = raw.get("user_id", "")
event_type = raw.get("event_type", 0)
if not user_id or event_type is None:
logger.error("事件 napcat_set_input_status 缺少必要参数: user_id 或 event_type")
return HandlerResult(False, False, {"status": "error"})
payload = {"user_id": str(user_id), "event_type": int(event_type)}
response = await send_handler.send_message_to_napcat(action="set_input_status", params=payload)
if response.get("status", "") == "ok":
return HandlerResult(True, True, response)
else:
logger.error("事件 napcat_set_input_status 请求失败!")
return HandlerResult(False, False, {"status": "error"})

View File

@@ -1816,3 +1816,27 @@ class NapcatEvent:
"""
class FILE(Enum): ...
class PERSONAL(Enum):
SET_INPUT_STATUS = "napcat_set_input_status"
"""
设置输入状态
Args:
user_id (Optional[str|int]): 用户id(必需)
event_type (Optional[int]): 输入状态id(必需)
raw (Optional[dict]): 原始请求体
Returns:
dict: {
"status": "ok",
"retcode": 0,
"data": {
"result": 0,
"errMsg": "string"
},
"message": "string",
"wording": "string",
"echo": "string"
}
"""

View File

@@ -8,6 +8,7 @@ from typing import List
from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.apis import config_api
from src.common.logger import get_logger
@@ -17,8 +18,6 @@ from .src.recv_handler.meta_event_handler import meta_event_handler
from .src.recv_handler.notice_handler import notice_handler
from .src.recv_handler.message_sending import message_send_instance
from .src.send_handler import send_handler
from .src.config import global_config
from .src.config.features_config import features_manager
from .src.config.migrate_features import auto_migrate_features
from .src.mmc_com_layer import mmc_start_com, router, mmc_stop_com
from .src.response_pool import put_response, check_timeout_response
@@ -134,13 +133,14 @@ async def message_process():
logger.debug(f"清理消息队列时出错: {e}")
async def napcat_server():
async def napcat_server(plugin_config: dict):
"""启动 Napcat WebSocket 连接(支持正向和反向连接)"""
mode = global_config.napcat_server.mode
# 使用插件系统配置API获取配置
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
logger.info(f"正在启动 adapter连接模式: {mode}")
try:
await websocket_manager.start_connection(message_recv)
await websocket_manager.start_connection(message_recv, plugin_config)
except Exception as e:
logger.error(f"启动 WebSocket 连接失败: {e}")
raise
@@ -157,11 +157,7 @@ async def graceful_shutdown():
except Exception as e:
logger.warning(f"停止消息重组器清理任务时出错: {e}")
# 停止功能管理器文件监控
try:
await features_manager.stop_file_watcher()
except Exception as e:
logger.warning(f"停止功能管理器文件监控时出错: {e}")
# 停止功能管理器文件监控(已迁移到插件系统配置,无需操作)
# 关闭消息处理器(包括消息缓冲器)
try:
@@ -233,16 +229,28 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
logger.info("启动消息重组器...")
await reassembler.start_cleanup_task()
# 初始化功能管理器
logger.info("正在初始化功能管理器...")
features_manager.load_config()
await features_manager.start_file_watcher(check_interval=2.0)
logger.info("功能管理器初始化完成")
# 功能管理器已迁移到插件系统配置
logger.info("功能配置已迁移到插件系统")
logger.info("开始启动Napcat Adapter")
message_send_instance.maibot_router = router
# 设置插件配置
message_send_instance.set_plugin_config(self.plugin_config)
# 设置chunker的插件配置
chunker.set_plugin_config(self.plugin_config)
# 设置response_pool的插件配置
from .src.response_pool import set_plugin_config as set_response_pool_config
set_response_pool_config(self.plugin_config)
# 设置send_handler的插件配置
send_handler.set_plugin_config(self.plugin_config)
# 设置message_handler的插件配置
message_handler.set_plugin_config(self.plugin_config)
# 设置notice_handler的插件配置
notice_handler.set_plugin_config(self.plugin_config)
# 设置meta_event_handler的插件配置
meta_event_handler.set_plugin_config(self.plugin_config)
# 创建单独的异步任务,防止阻塞主线程
asyncio.create_task(napcat_server())
asyncio.create_task(mmc_start_com())
asyncio.create_task(napcat_server(self.plugin_config))
asyncio.create_task(mmc_start_com(self.plugin_config))
asyncio.create_task(message_process())
asyncio.create_task(check_timeout_response())
@@ -277,10 +285,52 @@ class NapcatAdapterPlugin(BasePlugin):
"plugin": {
"name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"),
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
"config_version": ConfigField(type=str, default="1.2.0", description="配置文件版本"),
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
},
"inner": {
"version": ConfigField(type=str, default="0.2.1", description="配置版本号,请勿修改"),
},
"nickname": {
"nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"),
},
"napcat_server": {
"mode": ConfigField(type=str, default="reverse", description="连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]),
"host": ConfigField(type=str, default="localhost", description="主机地址"),
"port": ConfigField(type=int, default=8095, description="端口号"),
"url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)"),
"access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"),
"heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"),
},
"maibot_server": {
"host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址即HOST字段"),
"port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口即PORT字段"),
"platform_name": ConfigField(type=str, default="napcat", description="平台名称,用于消息路由"),
},
"voice": {
"use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音请确保你配置了tts并有对应的adapter"),
},
"slicing": {
"max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小单位为字节默认64KB"),
"delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"),
},
"debug": {
"level": ConfigField(type=str, default="INFO", description="日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
}
}
# 配置节描述
config_section_descriptions = {
"plugin": "插件基本信息",
"inner": "内部配置信息(请勿修改)",
"nickname": "昵称配置(目前未使用)",
"napcat_server": "Napcat连接的ws服务设置",
"maibot_server": "连接麦麦的ws服务设置",
"voice": "发送语音设置",
"slicing": "WebSocket消息切片设置",
"debug": "调试设置"
}
def register_events(self):
# 注册事件
for e in event_types.NapcatEvent.ON_RECEIVED:

View File

@@ -0,0 +1,2 @@
# 配置已迁移到插件系统,此文件不再需要
# 所有配置访问应通过插件系统的 config_api 进行

View File

@@ -7,7 +7,7 @@ from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from .config.features_config import features_manager
from src.plugin_system.apis import config_api
from .recv_handler import RealMessageType
@@ -43,6 +43,11 @@ class SimpleMessageBuffer:
self.lock = asyncio.Lock()
self.merge_callback = merge_callback
self._shutdown = False
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
def get_session_id(self, event_data: Dict[str, Any]) -> str:
"""根据事件数据生成会话ID"""
@@ -97,8 +102,7 @@ class SimpleMessageBuffer:
return True
# 检查屏蔽前缀
config = features_manager.get_config()
block_prefixes = tuple(config.message_buffer_block_prefixes)
block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", []))
text = text.strip()
if text.startswith(block_prefixes):
@@ -124,15 +128,15 @@ class SimpleMessageBuffer:
if self._shutdown:
return False
config = features_manager.get_config()
if not config.enable_message_buffer:
# 检查是否启用消息缓冲
if not config_api.get_plugin_config(self.plugin_config, "features.enable_message_buffer", False):
return False
# 检查是否启用对应类型的缓冲
message_type = event_data.get("message_type", "")
if message_type == "group" and not config.message_buffer_enable_group:
if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False):
return False
elif message_type == "private" and not config.message_buffer_enable_private:
elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False):
return False
# 提取文本
@@ -154,7 +158,7 @@ class SimpleMessageBuffer:
session = self.buffer_pool[session_id]
# 检查是否超过最大组件数量
if len(session.messages) >= config.message_buffer_max_components:
if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5):
logger.info(f"会话 {session_id} 消息数量达到上限,强制合并")
asyncio.create_task(self._force_merge_session(session_id))
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)
@@ -187,8 +191,8 @@ class SimpleMessageBuffer:
async def _wait_and_start_merge(self, session_id: str):
"""等待初始延迟后开始合并定时器"""
config = features_manager.get_config()
await asyncio.sleep(config.message_buffer_initial_delay)
initial_delay = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_initial_delay", 0.5)
await asyncio.sleep(initial_delay)
async with self.lock:
session = self.buffer_pool.get(session_id)
@@ -206,8 +210,8 @@ class SimpleMessageBuffer:
async def _wait_and_merge(self, session_id: str):
"""等待合并间隔后执行合并"""
config = features_manager.get_config()
await asyncio.sleep(config.message_buffer_interval)
interval = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_interval", 2.0)
await asyncio.sleep(interval)
await self._merge_session(session_id)
async def _force_merge_session(self, session_id: str):

View File

@@ -9,7 +9,7 @@ import uuid
import asyncio
import time
from typing import List, Dict, Any, Optional, Union
from .config import global_config
from src.plugin_system.apis import config_api
from src.common.logger import get_logger
@@ -20,7 +20,15 @@ class MessageChunker:
"""消息切片器,用于处理大消息的分片发送"""
def __init__(self):
self.max_chunk_size = global_config.slicing.max_frame_size * 1024
self.max_chunk_size = 64 * 1024 # 默认值,将在设置配置时更新
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
if plugin_config:
max_frame_size = config_api.get_plugin_config(plugin_config, "slicing.max_frame_size", 64)
self.max_chunk_size = max_frame_size * 1024
def should_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
"""判断消息是否需要切片"""

View File

@@ -0,0 +1,44 @@
from maim_message import Router, RouteConfig, TargetConfig
from src.common.logger import get_logger
from .send_handler import send_handler
from src.plugin_system.apis import config_api
logger = get_logger("napcat_adapter")
router = None
def create_router(plugin_config: dict):
"""创建路由器实例"""
global router
platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "napcat")
host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost")
port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000)
route_config = RouteConfig(
route_config={
platform_name: TargetConfig(
url=f"ws://{host}:{port}/ws",
token=None,
)
}
)
router = Router(route_config)
return router
async def mmc_start_com(plugin_config: dict = None):
"""启动MaiBot连接"""
logger.info("正在连接MaiBot")
if plugin_config:
create_router(plugin_config)
if router:
router.register_class_handler(send_handler.handle_message)
await router.run()
async def mmc_stop_com():
"""停止MaiBot连接"""
if router:
await router.stop()

View File

@@ -5,8 +5,7 @@ from ...CONSTS import PLUGIN_NAME
logger = get_logger("napcat_adapter")
from ..config import global_config
from ..config.features_config import features_manager
from src.plugin_system.apis import config_api
from ..message_buffer import SimpleMessageBuffer
from ..utils import (
get_group_info,
@@ -48,9 +47,17 @@ class MessageHandler:
def __init__(self):
self.server_connection: Server.ServerConnection = None
self.bot_id_list: Dict[int, bool] = {}
self.plugin_config = None
# 初始化简化消息缓冲器,传入回调函数
self.message_buffer = SimpleMessageBuffer(merge_callback=self._send_buffered_message)
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
# 将配置传递给消息缓冲器
if self.message_buffer:
self.message_buffer.set_plugin_config(plugin_config)
async def shutdown(self):
"""关闭消息处理器,清理资源"""
if self.message_buffer:
@@ -90,21 +97,21 @@ class MessageHandler:
# 使用新的权限管理器检查权限
if group_id:
if not features_manager.is_group_allowed(group_id):
if not config_api.get_plugin_config(self.plugin_config, f"features.group_allowed.{group_id}", True):
logger.warning("群聊不在聊天权限范围内,消息被丢弃")
return False
else:
if not features_manager.is_private_allowed(user_id):
if not config_api.get_plugin_config(self.plugin_config, f"features.private_allowed.{user_id}", True):
logger.warning("私聊不在聊天权限范围内,消息被丢弃")
return False
# 检查全局禁止名单
if not ignore_global_list and features_manager.is_user_banned(user_id):
if not ignore_global_list and config_api.get_plugin_config(self.plugin_config, f"features.user_banned.{user_id}", False):
logger.warning("用户在全局黑名单中,消息被丢弃")
return False
# 检查QQ官方机器人
if features_manager.is_qq_bot_banned() and group_id and not ignore_bot:
if config_api.get_plugin_config(self.plugin_config, "features.qq_bot_banned", False) and group_id and not ignore_bot:
logger.debug("开始判断是否为机器人")
member_info = await get_member_info(self.get_server_connection(), group_id, user_id)
if member_info:
@@ -149,7 +156,7 @@ class MessageHandler:
# 发送者用户信息
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
user_id=sender_info.get("user_id"),
user_nickname=sender_info.get("nickname"),
user_cardname=sender_info.get("card"),
@@ -175,7 +182,7 @@ class MessageHandler:
nickname = fetched_member_info.get("nickname") if fetched_member_info else None
# 发送者用户信息
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
user_id=sender_info.get("user_id"),
user_nickname=nickname,
user_cardname=None,
@@ -192,7 +199,7 @@ class MessageHandler:
group_name = fetched_group_info.get("group_name")
group_info: GroupInfo = GroupInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
group_id=raw_message.get("group_id"),
group_name=group_name,
)
@@ -210,7 +217,7 @@ class MessageHandler:
# 发送者用户信息
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
user_id=sender_info.get("user_id"),
user_nickname=sender_info.get("nickname"),
user_cardname=sender_info.get("card"),
@@ -223,7 +230,7 @@ class MessageHandler:
group_name = fetched_group_info.get("group_name")
group_info: GroupInfo = GroupInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
group_id=raw_message.get("group_id"),
group_name=group_name,
)
@@ -233,12 +240,12 @@ class MessageHandler:
return None
additional_config: dict = {}
if global_config.voice.use_tts:
if config_api.get_plugin_config(self.plugin_config, "voice.use_tts"):
additional_config["allow_tts"] = True
# 消息信息
message_info: BaseMessageInfo = BaseMessageInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
message_id=message_id,
time=message_time,
user_info=user_info,
@@ -260,14 +267,14 @@ class MessageHandler:
return None
# 检查是否需要使用消息缓冲
if features_manager.is_message_buffer_enabled():
if config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enabled", False):
# 检查消息类型是否启用缓冲
message_type = raw_message.get("message_type")
should_use_buffer = False
if message_type == "group" and features_manager.is_message_buffer_group_enabled():
if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_group_enabled", False):
should_use_buffer = True
elif message_type == "private" and features_manager.is_message_buffer_private_enabled():
elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_private_enabled", False):
should_use_buffer = True
if should_use_buffer:

View File

@@ -2,7 +2,7 @@ import asyncio
from src.common.logger import get_logger
from ..message_chunker import chunker
from ..config import global_config
from src.plugin_system.apis import config_api
logger = get_logger("napcat_adapter")
from maim_message import MessageBase, Router
@@ -14,10 +14,15 @@ class MessageSending:
"""
maibot_router: Router = None
plugin_config = None
def __init__(self):
pass
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
async def message_send(self, message_base: MessageBase) -> bool:
"""
发送消息Ada -> MMC 方向需要实现切片
@@ -52,9 +57,10 @@ class MessageSending:
return False
# 使用配置中的延迟时间
if i < len(chunks) - 1:
delay_seconds = global_config.slicing.delay_ms / 1000.0
logger.debug(f"切片发送延迟: {global_config.slicing.delay_ms}毫秒")
if i < len(chunks) - 1 and self.plugin_config:
delay_ms = config_api.get_plugin_config(self.plugin_config, "slicing.delay_ms", 10)
delay_seconds = delay_ms / 1000.0
logger.debug(f"切片发送延迟: {delay_ms}毫秒")
await asyncio.sleep(delay_seconds)
logger.debug("所有切片发送完成")

View File

@@ -1,7 +1,7 @@
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from ..config import global_config
from src.plugin_system.apis import config_api
import time
import asyncio
@@ -14,8 +14,15 @@ class MetaEventHandler:
"""
def __init__(self):
self.interval = global_config.napcat_server.heartbeat_interval
self.interval = 5.0 # 默认值稍后通过set_plugin_config设置
self._interval_checking = False
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
# 更新interval值
self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
async def handle_meta_event(self, message: dict) -> None:
event_type = message.get("meta_event_type")

View File

@@ -8,8 +8,7 @@ from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from ..config import global_config
from ..config.features_config import features_manager
from src.plugin_system.apis import config_api
from ..database import BanUser, db_manager, is_identical
from . import NoticeType, ACCEPT_FORMAT
from .message_sending import message_send_instance
@@ -38,6 +37,11 @@ class NoticeHandler:
def __init__(self):
self.server_connection: Server.ServerConnection | None = None
self.last_poke_time: float = 0.0 # 记录最后一次针对机器人的戳一戳时间
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
"""设置Napcat连接"""
@@ -112,7 +116,7 @@ class NoticeHandler:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.Notify.poke:
if features_manager.is_poke_enabled() and await message_handler.check_allow_to_chat(
if config_api.get_plugin_config(self.plugin_config, "features.poke_enabled", True) and await message_handler.check_allow_to_chat(
user_id, group_id, False, False
):
logger.info("处理戳一戳消息")
@@ -159,13 +163,13 @@ class NoticeHandler:
else:
logger.warning("无法获取notice消息所在群的名称")
group_info = GroupInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
group_id=group_id,
group_name=group_name,
)
message_info: BaseMessageInfo = BaseMessageInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
message_id="notice",
time=message_time,
user_info=user_info,
@@ -206,7 +210,7 @@ class NoticeHandler:
# 防抖检查:如果是针对机器人的戳一戳,检查防抖时间
if self_id == target_id:
current_time = time.time()
debounce_seconds = features_manager.get_config().poke_debounce_seconds
debounce_seconds = config_api.get_plugin_config(self.plugin_config, "features.poke_debounce_seconds", 2.0)
if self.last_poke_time > 0:
time_diff = current_time - self.last_poke_time
@@ -243,7 +247,7 @@ class NoticeHandler:
else:
# 如果配置为忽略不是针对自己的戳一戳则直接返回None
if features_manager.is_non_self_poke_ignored():
if config_api.get_plugin_config(self.plugin_config, "features.non_self_poke_ignored", False):
logger.info("忽略不是针对自己的戳一戳消息")
return None, None
@@ -268,7 +272,7 @@ class NoticeHandler:
logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本")
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_name,
user_cardname=user_cardname,
@@ -299,7 +303,7 @@ class NoticeHandler:
operator_nickname = "QQ用户"
operator_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=operator_id,
user_nickname=operator_nickname,
user_cardname=operator_cardname,
@@ -328,7 +332,7 @@ class NoticeHandler:
user_nickname = fetched_member_info.get("nickname")
user_cardname = fetched_member_info.get("card")
banned_user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
@@ -367,7 +371,7 @@ class NoticeHandler:
operator_nickname = "QQ用户"
operator_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=operator_id,
user_nickname=operator_nickname,
user_cardname=operator_cardname,
@@ -393,7 +397,7 @@ class NoticeHandler:
else:
logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效")
lifted_user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
@@ -436,13 +440,13 @@ class NoticeHandler:
else:
logger.warning("无法获取notice消息所在群的名称")
group_info = GroupInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
group_id=group_id,
group_name=group_name,
)
message_info: BaseMessageInfo = BaseMessageInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
message_id="notice",
time=time.time(),
user_info=None, # 自然解除禁言没有操作者
@@ -493,7 +497,7 @@ class NoticeHandler:
user_cardname = fetched_member_info.get("card")
lifted_user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,

View File

@@ -1,13 +1,20 @@
import asyncio
import time
from typing import Dict
from .config import global_config
from src.common.logger import get_logger
from src.plugin_system.apis import config_api
logger = get_logger("napcat_adapter")
response_dict: Dict = {}
response_time_dict: Dict = {}
plugin_config = None
def set_plugin_config(config: dict):
"""设置插件配置"""
global plugin_config
plugin_config = config
async def get_response(request_id: str, timeout: int = 10) -> dict:
@@ -38,11 +45,17 @@ async def check_timeout_response() -> None:
while True:
cleaned_message_count: int = 0
now_time = time.time()
# 获取心跳间隔配置
heartbeat_interval = 30 # 默认值
if plugin_config:
heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30)
for echo_id, response_time in list(response_time_dict.items()):
if now_time - response_time > global_config.napcat_server.heartbeat_interval:
if now_time - response_time > heartbeat_interval:
cleaned_message_count += 1
response_dict.pop(echo_id)
response_time_dict.pop(echo_id)
logger.warning(f"响应消息 {echo_id} 超时,已删除")
logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
await asyncio.sleep(global_config.napcat_server.heartbeat_interval)
await asyncio.sleep(heartbeat_interval)

View File

@@ -12,9 +12,9 @@ from maim_message import (
MessageBase,
)
from typing import Dict, Any, Tuple, Optional
from src.plugin_system.apis import config_api
from . import CommandType
from .config import global_config
from .response_pool import get_response
from src.common.logger import get_logger
@@ -22,12 +22,16 @@ logger = get_logger("napcat_adapter")
from .utils import get_image_format, convert_image_to_gif
from .recv_handler.message_sending import message_send_instance
from .websocket_manager import websocket_manager
from .config.features_config import features_manager
class SendHandler:
def __init__(self):
self.server_connection: Optional[Server.ServerConnection] = None
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
"""设置Napcat连接"""
@@ -287,11 +291,8 @@ class SendHandler:
"""处理回复消息"""
reply_seg = {"type": "reply", "data": {"id": id}}
# 获取功能配置
ft_config = features_manager.get_config()
# 检查是否启用引用艾特功能
if not ft_config.enable_reply_at:
if not config_api.get_plugin_config(self.plugin_config, "features.enable_reply_at", False):
return reply_seg
try:
@@ -310,7 +311,7 @@ class SendHandler:
return reply_seg
# 根据概率决定是否艾特用户
if random.random() < ft_config.reply_at_rate:
if random.random() < config_api.get_plugin_config(self.plugin_config, "features.reply_at_rate", 0.5):
at_seg = {"type": "at", "data": {"qq": str(replied_user_id)}}
# 在艾特后面添加一个空格
text_seg = {"type": "text", "data": {"text": " "}}
@@ -354,7 +355,11 @@ class SendHandler:
def handle_voice_message(self, encoded_voice: str) -> dict:
"""处理语音消息"""
if not global_config.voice.use_tts:
use_tts = False
if self.plugin_config:
use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False)
if not use_tts:
logger.warning("未启用语音消息处理")
return {}
if not encoded_voice:

View File

@@ -2,9 +2,9 @@ import asyncio
import websockets as Server
from typing import Optional, Callable, Any
from src.common.logger import get_logger
from src.plugin_system.apis import config_api
logger = get_logger("napcat_adapter")
from .config import global_config
class WebSocketManager:
@@ -16,10 +16,12 @@ class WebSocketManager:
self.is_running = False
self.reconnect_interval = 5 # 重连间隔(秒)
self.max_reconnect_attempts = 10 # 最大重连次数
self.plugin_config = None
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None:
"""根据配置启动 WebSocket 连接"""
mode = global_config.napcat_server.mode
self.plugin_config = plugin_config
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
if mode == "reverse":
await self._start_reverse_connection(message_handler)
@@ -30,8 +32,8 @@ class WebSocketManager:
async def _start_reverse_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
"""启动反向连接(作为服务器)"""
host = global_config.napcat_server.host
port = global_config.napcat_server.port
host = config_api.get_plugin_config(self.plugin_config, "napcat_server.host")
port = config_api.get_plugin_config(self.plugin_config, "napcat_server.port")
logger.info(f"正在启动反向连接模式,监听地址: ws://{host}:{port}")
@@ -68,9 +70,10 @@ class WebSocketManager:
connect_kwargs = {"max_size": 2**26}
# 如果配置了访问令牌,添加到请求头
if global_config.napcat_server.access_token:
access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token")
if access_token:
connect_kwargs["additional_headers"] = {
"Authorization": f"Bearer {global_config.napcat_server.access_token}"
"Authorization": f"Bearer {access_token}"
}
logger.info("已添加访问令牌到连接请求头")
@@ -112,15 +115,14 @@ class WebSocketManager:
def _get_forward_url(self) -> str:
"""获取正向连接的 URL"""
config = global_config.napcat_server
# 如果配置了完整的 URL直接使用
if config.url:
return config.url
url = config_api.get_plugin_config(self.plugin_config, "napcat_server.url")
if url:
return url
# 否则根据 host 和 port 构建 URL
host = config.host
port = config.port
host = config_api.get_plugin_config(self.plugin_config, "napcat_server.host")
port = config_api.get_plugin_config(self.plugin_config, "napcat_server.port")
return f"ws://{host}:{port}"
async def stop_connection(self) -> None: