This commit is contained in:
minecraft1024a
2025-09-06 10:43:10 +08:00
61 changed files with 1250 additions and 2209 deletions

View File

@@ -149,7 +149,7 @@ class CycleProcessor:
logger.info(f"{self.log_prefix} 开始第{self.context.cycle_counter}次思考")
if ENABLE_S4U:
await send_typing()
await send_typing(self.context.chat_stream.user_info.user_id)
loop_start_time = time.time()

View File

@@ -121,7 +121,7 @@ class CycleDetail:
self.loop_action_info = loop_info["loop_action_info"]
async def send_typing():
async def send_typing(user_id):
"""
发送打字状态指示
@@ -139,6 +139,11 @@ async def send_typing():
group_info=group_info,
)
from plugin_system.core.event_manager import event_manager
from src.plugins.built_in.napcat_adapter_plugin.event_types import NapcatEvent
# 设置正在输入状态
await event_manager.trigger_event(NapcatEvent.PERSONAL.SET_INPUT_STATUS,user_id=user_id,event_type=1)
await send_api.custom_to_stream(
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
)

View File

@@ -12,7 +12,7 @@ from .hfc_context import HfcContext
# 导入反注入系统
from src.chat.antipromptinjector import get_anti_injector
from src.chat.antipromptinjector.types import ProcessResult
from src.chat.utils.prompt_builder import Prompt
from src.chat.utils.prompt import Prompt
logger = get_logger("hfc")
anti_injector_logger = get_logger("anti_injector")

View File

@@ -13,7 +13,7 @@ from src.common.database.sqlalchemy_models import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager

View File

@@ -11,7 +11,7 @@ from src.config.config import global_config, model_config
from src.common.logger import get_logger
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_db_session
logger = get_logger("expression_selector")

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager

View File

@@ -12,7 +12,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor

View File

@@ -9,7 +9,7 @@ from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_actions,
get_actions_by_timestamp_with_chat,

View File

@@ -1,6 +1,6 @@
"""
默认回复生成器 - 集成SmartPrompt系统
使用重构后的SmartPrompt系统替换原有的复杂提示词构建逻辑
默认回复生成器 - 集成统一Prompt系统
使用重构后的统一Prompt系统替换原有的复杂提示词构建逻辑
"""
import traceback
@@ -11,7 +11,6 @@ import re
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.chat.utils.prompt_utils import PromptUtils
from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
@@ -22,7 +21,7 @@ from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
@@ -37,8 +36,8 @@ from src.person_info.person_info import get_person_info_manager
from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
# 导入新的智能Prompt系统
from src.chat.utils.smart_prompt import SmartPrompt, SmartPromptParameters
# 导入新的统一Prompt系统
from src.chat.utils.prompt import Prompt, PromptParameters
logger = get_logger("replyer")
@@ -598,7 +597,8 @@ class DefaultReplyer:
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息 - 使用共享工具"""
return PromptUtils.parse_reply_target(target_message)
from src.chat.utils.prompt import Prompt
return Prompt.parse_reply_target(target_message)
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
"""构建关键词反应提示
@@ -873,7 +873,8 @@ class DefaultReplyer:
target_user_info = None
if sender:
target_user_info = await person_info_manager.get_person_info_by_name(sender)
from src.chat.utils.prompt import Prompt
# 并行执行六个构建任务
task_results = await asyncio.gather(
self._time_and_run_task(
@@ -886,7 +887,7 @@ class DefaultReplyer:
),
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
self._time_and_run_task(
PromptUtils.build_cross_context(chat_id, target_user_info, global_config.personality.prompt_mode),
Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info),
"cross_context",
),
)
@@ -970,8 +971,8 @@ class DefaultReplyer:
# 根据配置选择模板
current_prompt_mode = global_config.personality.prompt_mode
# 使用重构后的SmartPrompt系统
prompt_params = SmartPromptParameters(
# 使用新的统一Prompt系统 - 创建PromptParameters
prompt_parameters = PromptParameters(
chat_id=chat_id,
is_group_chat=is_group_chat,
sender=sender,
@@ -1004,12 +1005,19 @@ class DefaultReplyer:
action_descriptions=action_descriptions,
)
# 使用重构后的SmartPrompt系统
smart_prompt = SmartPrompt(
template_name=None, # 由current_prompt_mode自动选择
parameters=prompt_params,
)
prompt_text = await smart_prompt.build_prompt()
# 使用新的统一Prompt系统 - 使用正确的模板名称
template_name = None
if current_prompt_mode == "s4u":
template_name = "s4u_style_prompt"
elif current_prompt_mode == "normal":
template_name = "normal_style_prompt"
elif current_prompt_mode == "minimal":
template_name = "default_expressor_prompt"
# 获取模板内容
template_prompt = await global_prompt_manager.get_prompt_async(template_name)
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
prompt_text = await prompt.build()
return prompt_text
@@ -1110,8 +1118,8 @@ class DefaultReplyer:
template_name = "default_expressor_prompt"
# 使用重构后的SmartPrompt系统 - Expressor模式
prompt_params = SmartPromptParameters(
# 使用新的统一Prompt系统 - Expressor模式创建PromptParameters
prompt_parameters = PromptParameters(
chat_id=chat_id,
is_group_chat=is_group_chat,
sender=sender,
@@ -1131,8 +1139,10 @@ class DefaultReplyer:
relation_info_block=relation_info,
)
smart_prompt = SmartPrompt(parameters=prompt_params)
prompt_text = await smart_prompt.build_prompt()
# 使用新的统一Prompt系统 - Expressor模式
template_prompt = await global_prompt_manager.get_prompt_async("default_expressor_prompt")
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
prompt_text = await prompt.build()
return prompt_text

823
src/chat/utils/prompt.py Normal file
View File

@@ -0,0 +1,823 @@
"""
统一提示词系统 - 合并模板管理和智能构建功能
将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类
"""
import re
import asyncio
import time
import contextvars
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Literal, Tuple
from contextlib import asynccontextmanager
from rich.traceback import install
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
install(extra_lines=3)
logger = get_logger("unified_prompt")
@dataclass
class PromptParameters:
"""统一提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
sender: str = ""
target: str = ""
reply_to: str = ""
extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
enable_expression: bool = True
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: Optional[Dict[str, Any]] = None
# 已构建的内容块
expression_habits_block: str = ""
relation_info_block: str = ""
memory_block: str = ""
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
# 其他内容块
keywords_reaction_prompt: str = ""
extra_info_block: str = ""
time_block: str = ""
identity_block: str = ""
schedule_block: str = ""
moderation_prompt_block: str = ""
reply_target_block: str = ""
mood_prompt: str = ""
action_descriptions: str = ""
# 可用动作信息
available_actions: Optional[Dict[str, Any]] = None
def validate(self) -> List[str]:
"""参数验证"""
errors = []
if not self.chat_id:
errors.append("chat_id不能为空")
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
errors.append("prompt_mode必须是's4u''normal''minimal'")
if self.max_context_messages <= 0:
errors.append("max_context_messages必须大于0")
return errors
class PromptContext:
"""提示词上下文管理器"""
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock()
@property
def _current_context(self) -> Optional[str]:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value) # type: ignore
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
"""创建一个异步的临时提示模板作用域"""
if context_id is not None:
try:
await asyncio.wait_for(self._context_lock.acquire(), timeout=5.0)
try:
if context_id not in self._context_prompts:
self._context_prompts[context_id] = {}
finally:
self._context_lock.release()
except asyncio.TimeoutError:
logger.warning(f"获取上下文锁超时context_id: {context_id}")
context_id = None
previous_context = self._current_context
token = self._current_context_var.set(context_id) if context_id else None
else:
previous_context = self._current_context
token = None
try:
yield self
finally:
if context_id is not None and token is not None:
try:
self._current_context_var.reset(token)
except Exception as e:
logger.warning(f"恢复上下文时出错: {e}")
try:
self._current_context = previous_context
except Exception:
...
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板"""
async with self._context_lock:
current_context = self._current_context
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
if (
current_context
and current_context in self._context_prompts
and name in self._context_prompts[current_context]
):
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
if target_context := context_id or self._current_context:
if prompt.name:
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
class PromptManager:
"""统一提示词管理器"""
def __init__(self):
self._prompts = {}
self._counter = 0
self._context = PromptContext()
self._lock = asyncio.Lock()
@asynccontextmanager
async def async_message_scope(self, message_id: Optional[str] = None):
"""为消息处理创建异步临时作用域"""
async with self._context.async_scope(message_id):
yield self
async def get_prompt_async(self, name: str) -> "Prompt":
"""异步获取提示模板"""
context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt
async with self._lock:
if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]
def generate_name(self, template: str) -> str:
"""为未命名的prompt生成名称"""
self._counter += 1
return f"prompt_{self._counter}"
def register(self, prompt: "Prompt") -> None:
"""注册一个prompt"""
if not prompt.name:
prompt.name = self.generate_name(prompt.template)
self._prompts[prompt.name] = prompt
def add_prompt(self, name: str, fstr: str) -> "Prompt":
"""添加新提示模板"""
prompt = Prompt(fstr, name=name)
if prompt.name:
self._prompts[prompt.name] = prompt
return prompt
async def format_prompt(self, name: str, **kwargs) -> str:
"""格式化提示模板"""
prompt = await self.get_prompt_async(name)
result = prompt.format(**kwargs)
return result
# 全局单例
global_prompt_manager = PromptManager()
class Prompt:
"""
统一提示词类 - 合并模板管理和智能构建功能
真正的Prompt类支持模板管理和智能上下文构建
"""
# 临时标记,作为类常量
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
def __init__(
self,
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
should_register: bool = True
):
"""
初始化统一提示词
Args:
template: 提示词模板字符串
name: 提示词名称
parameters: 构建参数
should_register: 是否自动注册到全局管理器
"""
self.template = template
self.name = name
self.parameters = parameters or PromptParameters()
self.args = self._parse_template_args(template)
self._formatted_result = ""
# 预处理模板中的转义花括号
self._processed_template = self._process_escaped_braces(template)
# 自动注册
if should_register and not global_prompt_manager._context._current_context:
global_prompt_manager.register(self)
@staticmethod
def _process_escaped_braces(template) -> str:
"""处理模板中的转义花括号"""
if isinstance(template, list):
template = "\n".join(str(item) for item in template)
elif not isinstance(template, str):
template = str(template)
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
@staticmethod
def _restore_escaped_braces(template: str) -> str:
"""将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def _parse_template_args(self, template: str) -> List[str]:
"""解析模板参数"""
template_args = []
processed_template = self._process_escaped_braces(template)
result = re.findall(r"\{(.*?)}", processed_template)
for expr in result:
if expr and expr not in template_args:
template_args.append(expr)
return template_args
async def build(self) -> str:
"""
构建完整的提示词,包含智能上下文
Returns:
str: 构建完成的提示词文本
"""
# 参数验证
errors = self.parameters.validate()
if errors:
logger.error(f"参数验证失败: {', '.join(errors)}")
raise ValueError(f"参数验证失败: {', '.join(errors)}")
start_time = time.time()
try:
# 构建上下文数据
context_data = await self._build_context_data()
# 格式化模板
result = await self._format_with_context(context_data)
total_time = time.time() - start_time
logger.debug(f"Prompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
self._formatted_result = result
return result
except asyncio.TimeoutError as e:
logger.error(f"构建Prompt超时: {e}")
raise TimeoutError(f"构建Prompt超时: {e}")
except Exception as e:
logger.error(f"构建Prompt失败: {e}")
raise RuntimeError(f"构建Prompt失败: {e}")
async def _build_context_data(self) -> Dict[str, Any]:
"""构建智能上下文数据"""
# 并行执行所有构建任务
start_time = time.time()
timing_logs = {}
try:
# 准备构建任务
tasks = []
task_names = []
# 初始化预构建参数
pre_built_params = {}
if self.parameters.expression_habits_block:
pre_built_params["expression_habits_block"] = self.parameters.expression_habits_block
if self.parameters.relation_info_block:
pre_built_params["relation_info_block"] = self.parameters.relation_info_block
if self.parameters.memory_block:
pre_built_params["memory_block"] = self.parameters.memory_block
if self.parameters.tool_info_block:
pre_built_params["tool_info_block"] = self.parameters.tool_info_block
if self.parameters.knowledge_prompt:
pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt
if self.parameters.cross_context_block:
pre_built_params["cross_context_block"] = self.parameters.cross_context_block
# 根据参数确定要构建的项
if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"):
tasks.append(self._build_expression_habits())
task_names.append("expression_habits")
if self.parameters.enable_memory and not pre_built_params.get("memory_block"):
tasks.append(self._build_memory_block())
task_names.append("memory_block")
if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"):
tasks.append(self._build_relation_info())
task_names.append("relation_info")
if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"):
tasks.append(self._build_tool_info())
task_names.append("tool_info")
if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
tasks.append(self._build_knowledge_info())
task_names.append("knowledge_info")
if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"):
tasks.append(self._build_cross_context())
task_names.append("cross_context")
# 性能优化
base_timeout = 10.0
task_timeout = 2.0
timeout_seconds = min(
max(base_timeout, len(tasks) * task_timeout),
30.0,
)
max_concurrent_tasks = 5
if len(tasks) > max_concurrent_tasks:
results = []
for i in range(0, len(tasks), max_concurrent_tasks):
batch_tasks = tasks[i : i + max_concurrent_tasks]
batch_names = task_names[i : i + max_concurrent_tasks]
batch_results = await asyncio.wait_for(
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
)
results.extend(batch_results)
else:
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
)
# 处理结果
context_data = {}
for i, result in enumerate(results):
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
if isinstance(result, Exception):
logger.error(f"构建任务{task_name}失败: {str(result)}")
elif isinstance(result, dict):
context_data.update(result)
# 添加预构建的参数
for key, value in pre_built_params.items():
if value:
context_data[key] = value
except asyncio.TimeoutError:
logger.error(f"构建超时 ({timeout_seconds}s)")
context_data = {}
for key, value in pre_built_params.items():
if value:
context_data[key] = value
# 构建聊天历史
if self.parameters.prompt_mode == "s4u":
await self._build_s4u_chat_context(context_data)
else:
await self._build_normal_chat_context(context_data)
# 补充基础信息
context_data.update({
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
"extra_info_block": self.parameters.extra_info_block,
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
"identity": self.parameters.identity_block,
"schedule_block": self.parameters.schedule_block,
"moderation_prompt": self.parameters.moderation_prompt_block,
"reply_target_block": self.parameters.reply_target_block,
"mood_state": self.parameters.mood_prompt,
"action_descriptions": self.parameters.action_descriptions,
})
total_time = time.time() - start_time
logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s")
return context_data
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
"""构建S4U模式的聊天上下文"""
if not self.parameters.message_list_before_now_long:
return
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
self.parameters.message_list_before_now_long,
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
self.parameters.sender
)
context_data["core_dialogue_prompt"] = core_dialogue
context_data["background_dialogue_prompt"] = background_dialogue
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
"""构建normal模式的聊天上下文"""
if not self.parameters.chat_talking_prompt_short:
return
context_data["chat_info"] = f"""群里的聊天内容:
{self.parameters.chat_talking_prompt_short}"""
async def _build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
) -> Tuple[str, str]:
"""构建S4U风格的分离对话prompt"""
# 实现逻辑与原有SmartPromptBuilder相同
core_dialogue_list = []
bot_id = str(global_config.bot.qq_account)
for msg_dict in message_list_before_now:
try:
msg_user_id = str(msg_dict.get("user_id"))
reply_to = msg_dict.get("reply_to", "")
platform, reply_to_user_id = Prompt.parse_reply_target(reply_to)
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
core_dialogue_list.append(msg_dict)
except Exception as e:
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
# 构建背景对话 prompt
all_dialogue_prompt = ""
if message_list_before_now:
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
all_dialogue_prompt_str = build_readable_messages(
latest_25_msgs,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True,
)
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
# 构建核心对话 prompt
core_dialogue_prompt = ""
if core_dialogue_list:
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
if not has_bot_message:
core_dialogue_prompt = ""
else:
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :]
core_dialogue_prompt_str = build_readable_messages(
core_dialogue_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
core_dialogue_prompt = f"""--------------------------------
这是你和{sender}的对话,你们正在交流中:
{core_dialogue_prompt_str}
--------------------------------
"""
return core_dialogue_prompt, all_dialogue_prompt
async def _build_expression_habits(self) -> Dict[str, Any]:
"""构建表达习惯"""
# 简化的实现,完整实现需要导入相关模块
return {"expression_habits_block": ""}
async def _build_memory_block(self) -> Dict[str, Any]:
"""构建记忆块"""
# 简化的实现
return {"memory_block": ""}
async def _build_relation_info(self) -> Dict[str, Any]:
"""构建关系信息"""
try:
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
return {"relation_info_block": relation_info}
except Exception as e:
logger.error(f"构建关系信息失败: {e}")
return {"relation_info_block": ""}
async def _build_tool_info(self) -> Dict[str, Any]:
"""构建工具信息"""
# 简化的实现
return {"tool_info_block": ""}
async def _build_knowledge_info(self) -> Dict[str, Any]:
"""构建知识信息"""
# 简化的实现
return {"knowledge_prompt": ""}
async def _build_cross_context(self) -> Dict[str, Any]:
"""构建跨群上下文"""
try:
cross_context = await Prompt.build_cross_context(
self.parameters.chat_id, self.parameters.prompt_mode, self.parameters.target_user_info
)
return {"cross_context_block": cross_context}
except Exception as e:
logger.error(f"构建跨群上下文失败: {e}")
return {"cross_context_block": ""}
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
"""使用上下文数据格式化模板"""
if self.parameters.prompt_mode == "s4u":
params = self._prepare_s4u_params(context_data)
elif self.parameters.prompt_mode == "normal":
params = self._prepare_normal_params(context_data)
else:
params = self._prepare_default_params(context_data)
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备S4U模式的参数"""
return {
**context_data,
"expression_habits_block": context_data.get("expression_habits_block", ""),
"tool_info_block": context_data.get("tool_info_block", ""),
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"sender_name": self.parameters.sender,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""),
"time_block": context_data.get("time_block", ""),
"core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备Normal模式的参数"""
return {
**context_data,
"expression_habits_block": context_data.get("expression_habits_block", ""),
"tool_info_block": context_data.get("tool_info_block", ""),
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""),
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"config_expression_style": global_config.personality.reply_style,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备默认模式的参数"""
return {
"expression_habits_block": context_data.get("expression_habits_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"chat_target": "",
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"chat_target_2": "",
"reply_target_block": context_data.get("reply_target_block", ""),
"raw_reply": self.parameters.target,
"reason": "",
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
def format(self, *args, **kwargs) -> str:
"""格式化模板,支持位置参数和关键字参数"""
try:
# 先用位置参数格式化
if args:
formatted_args = {}
for i in range(len(args)):
if i < len(self.args):
formatted_args[self.args[i]] = args[i]
processed_template = self._processed_template.format(**formatted_args)
else:
processed_template = self._processed_template
# 再用关键字参数格式化
if kwargs:
processed_template = processed_template.format(**kwargs)
# 将临时标记还原为实际的花括号
result = self._restore_escaped_braces(processed_template)
return result
except (IndexError, KeyError) as e:
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
def __str__(self) -> str:
"""返回格式化后的结果或原始模板"""
return self._formatted_result if self._formatted_result else self.template
def __repr__(self) -> str:
"""返回提示词的表示形式"""
return f"Prompt(template='{self.template}', name='{self.name}')"
# =============================================================================
# PromptUtils功能迁移 - 静态工具方法
# 这些方法原来在PromptUtils类中现在作为Prompt类的静态方法
# 解决循环导入问题
# =============================================================================
@staticmethod
def parse_reply_target(target_message: str) -> Tuple[str, str]:
"""
解析回复目标消息 - 统一实现
Args:
target_message: 目标消息,格式为 "发送者:消息内容""发送者:消息内容"
Returns:
Tuple[str, str]: (发送者名称, 消息内容)
"""
sender = ""
target = ""
# 添加None检查防止NoneType错误
if target_message is None:
return sender, target
if ":" in target_message or "" in target_message:
# 使用正则表达式匹配中文或英文冒号
parts = re.split(pattern=r"[:]", string=target_message, maxsplit=1)
if len(parts) == 2:
sender = parts[0].strip()
target = parts[1].strip()
return sender, target
@staticmethod
async def build_relation_info(chat_id: str, reply_to: str) -> str:
"""
构建关系信息 - 统一实现
Args:
chat_id: 聊天ID
reply_to: 回复目标字符串
Returns:
str: 关系信息字符串
"""
if not global_config.relationship.enable_relationship:
return ""
from src.person_info.relationship_fetcher import relationship_fetcher_manager
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
if not reply_to:
return ""
sender, text = Prompt.parse_reply_target(reply_to)
if not sender or not text:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if not person_id:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
@staticmethod
async def build_cross_context(
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
) -> str:
"""
构建跨群聊上下文 - 统一实现
Args:
chat_id: 聊天ID
prompt_mode: 当前提示词模式
target_user_info: 目标用户信息
Returns:
str: 跨群聊上下文字符串
"""
if not global_config.cross_context.enable:
return ""
from src.plugin_system.apis import cross_context_api
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
if not other_chat_raw_ids:
return ""
chat_stream = get_chat_manager().get_stream(chat_id)
if not chat_stream:
return ""
if prompt_mode == "normal":
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
elif prompt_mode == "s4u":
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
return ""
@staticmethod
def parse_reply_target_id(reply_to: str) -> str:
"""
解析回复目标中的用户ID
Args:
reply_to: 回复目标字符串
Returns:
str: 用户ID
"""
if not reply_to:
return ""
# 复用parse_reply_target方法的逻辑
sender, _ = Prompt.parse_reply_target(reply_to)
if not sender:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if person_id:
user_id = person_info_manager.get_value_sync(person_id, "user_id")
return str(user_id) if user_id else ""
return ""
# 工厂函数
def create_prompt(
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
) -> Prompt:
"""快速创建Prompt实例的工厂函数"""
if parameters is None:
parameters = PromptParameters(**kwargs)
return Prompt(template, name, parameters)
async def create_prompt_async(
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
) -> Prompt:
"""异步创建Prompt实例"""
prompt = create_prompt(template, name, parameters, **kwargs)
if global_prompt_manager._context._current_context:
await global_prompt_manager._context.register_async(prompt)
return prompt

View File

@@ -1,299 +0,0 @@
import re
import asyncio
import contextvars
from rich.traceback import install
from contextlib import asynccontextmanager
from typing import Dict, Any, Optional, List, Union
from src.common.logger import get_logger
install(extra_lines=3)
logger = get_logger("prompt_build")
class PromptContext:
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
# 使用contextvars创建协程上下文变量
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock() # 保留锁用于其他操作
@property
def _current_context(self) -> Optional[str]:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value) # type: ignore
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
# sourcery skip: hoist-statement-from-if, use-contextlib-suppress
"""创建一个异步的临时提示模板作用域"""
# 保存当前上下文并设置新上下文
if context_id is not None:
try:
# 添加超时保护,避免长时间等待锁
await asyncio.wait_for(self._context_lock.acquire(), timeout=5.0)
try:
if context_id not in self._context_prompts:
self._context_prompts[context_id] = {}
finally:
self._context_lock.release()
except asyncio.TimeoutError:
logger.warning(f"获取上下文锁超时context_id: {context_id}")
# 超时时直接进入,不设置上下文
context_id = None
# 保存当前协程的上下文值,不影响其他协程
previous_context = self._current_context
# 设置当前协程的新上下文
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
else:
# 如果没有提供新上下文,保持当前上下文不变
previous_context = self._current_context
token = None
try:
yield self
finally:
# 恢复之前的上下文,添加异常保护
if context_id is not None and token is not None:
try:
self._current_context_var.reset(token)
except Exception as e:
logger.warning(f"恢复上下文时出错: {e}")
# 如果reset失败尝试直接设置
try:
self._current_context = previous_context
except Exception:
...
# 静默忽略恢复失败
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板"""
async with self._context_lock:
current_context = self._current_context
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
if (
current_context
and current_context in self._context_prompts
and name in self._context_prompts[current_context]
):
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
if target_context := context_id or self._current_context:
if prompt.name:
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
class PromptManager:
def __init__(self):
self._prompts = {}
self._counter = 0
self._context = PromptContext()
self._lock = asyncio.Lock()
@asynccontextmanager
async def async_message_scope(self, message_id: Optional[str] = None):
"""为消息处理创建异步临时作用域,支持 message_id 为 None 的情况"""
async with self._context.async_scope(message_id):
yield self
async def get_prompt_async(self, name: str) -> "Prompt":
# 首先尝试从当前上下文获取
context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt
# 如果上下文中不存在,则使用全局提示模板
async with self._lock:
# logger.debug(f"从全局获取提示词: {name}")
if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]
def generate_name(self, template: str) -> str:
"""为未命名的prompt生成名称"""
self._counter += 1
return f"prompt_{self._counter}"
def register(self, prompt: "Prompt") -> None:
"""注册一个prompt"""
if not prompt.name:
prompt.name = self.generate_name(prompt.template)
self._prompts[prompt.name] = prompt
def add_prompt(self, name: str, fstr: str) -> "Prompt":
prompt = Prompt(fstr, name=name)
if prompt.name:
self._prompts[prompt.name] = prompt
return prompt
async def format_prompt(self, name: str, **kwargs) -> str:
# 获取当前提示词
prompt = await self.get_prompt_async(name)
# 获取基本格式化结果
result = prompt.format(**kwargs)
return result
# 全局单例
global_prompt_manager = PromptManager()
class Prompt(str):
template: str
name: Optional[str]
args: List[str]
_args: List[Any]
_kwargs: Dict[str, Any]
# 临时标记,作为类常量
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
@staticmethod
def _process_escaped_braces(template) -> str:
"""处理模板中的转义花括号,将 \\{\\} 替换为临时标记""" # type: ignore
# 如果传入的是列表,将其转换为字符串
if isinstance(template, list):
template = "\n".join(str(item) for item in template)
elif not isinstance(template, str):
template = str(template)
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
@staticmethod
def _restore_escaped_braces(template: str) -> str:
"""将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def __new__(
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
):
# 如果传入的是元组,转换为列表
if isinstance(args, tuple):
args = list(args)
should_register = kwargs.pop("_should_register", True)
# 预处理模板中的转义花括号
processed_fstr = cls._process_escaped_braces(fstr)
# 解析模板
template_args = []
result = re.findall(r"\{(.*?)}", processed_fstr)
for expr in result:
if expr and expr not in template_args:
template_args.append(expr)
# 如果提供了初始参数,立即格式化
if kwargs or args:
formatted = cls._format_template(fstr, args=args, kwargs=kwargs)
obj = super().__new__(cls, formatted)
else:
obj = super().__new__(cls, "")
obj.template = fstr
obj.name = name
obj.args = template_args
obj._args = args or []
obj._kwargs = kwargs
# 修改自动注册逻辑
if should_register and not global_prompt_manager._context._current_context:
global_prompt_manager.register(obj)
return obj
@classmethod
async def create_async(
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
):
"""异步创建Prompt实例"""
prompt = cls(fstr, name, args, **kwargs)
if global_prompt_manager._context._current_context:
await global_prompt_manager._context.register_async(prompt)
return prompt
@classmethod
def _format_template(
cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None
) -> str:
if kwargs is None:
kwargs = {}
# 预处理模板中的转义花括号
processed_template = cls._process_escaped_braces(template)
template_args = []
result = re.findall(r"\{(.*?)}", processed_template)
for expr in result:
if expr and expr not in template_args:
template_args.append(expr)
formatted_args = {}
formatted_kwargs = {}
# 处理位置参数
if args:
# print(len(template_args), len(args), template_args, args)
for i in range(len(args)):
if i < len(template_args):
arg = args[i]
if isinstance(arg, Prompt):
formatted_args[template_args[i]] = arg.format(**kwargs)
else:
formatted_args[template_args[i]] = arg
else:
logger.error(
f"构建提示词模板失败,解析到的参数列表{template_args},长度为{len(template_args)},输入的参数列表为{args},提示词模板为{template}"
)
raise ValueError("格式化模板失败")
# 处理关键字参数
if kwargs:
for key, value in kwargs.items():
if isinstance(value, Prompt):
remaining_kwargs = {k: v for k, v in kwargs.items() if k != key}
formatted_kwargs[key] = value.format(**remaining_kwargs)
else:
formatted_kwargs[key] = value
try:
# 先用位置参数格式化
if args:
processed_template = processed_template.format(**formatted_args)
# 再用关键字参数格式化
if kwargs:
processed_template = processed_template.format(**formatted_kwargs)
# 将临时标记还原为实际的花括号
result = cls._restore_escaped_braces(processed_template)
return result
except (IndexError, KeyError) as e:
raise ValueError(
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
) from e
def format(self, *args, **kwargs) -> "str":
"""支持位置参数和关键字参数的格式化,使用"""
ret = type(self)(
self.template,
self.name,
args=list(args) if args else self._args,
_should_register=False,
**kwargs or self._kwargs,
)
# print(f"prompt build result: {ret} name: {ret.name} ")
return str(ret)
def __str__(self) -> str:
return super().__str__() if self._kwargs or self._args else self.template
def __repr__(self) -> str:
return f"Prompt(template='{self.template}', name='{self.name}')"

View File

@@ -1,156 +0,0 @@
"""
智能提示词参数模块 - 优化参数结构
简化SmartPromptParameters减少冗余和重复
"""
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Literal
@dataclass
class SmartPromptParameters:
"""简化的智能提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
sender: str = ""
target: str = ""
reply_to: str = ""
extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
enable_expression: bool = True
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: Optional[Dict[str, Any]] = None
# 已构建的内容块
expression_habits_block: str = ""
relation_info_block: str = ""
memory_block: str = ""
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
# 其他内容块
keywords_reaction_prompt: str = ""
extra_info_block: str = ""
time_block: str = ""
identity_block: str = ""
schedule_block: str = ""
moderation_prompt_block: str = ""
reply_target_block: str = ""
mood_prompt: str = ""
action_descriptions: str = ""
# 可用动作信息
available_actions: Optional[Dict[str, Any]] = None
def validate(self) -> List[str]:
"""统一的参数验证"""
errors = []
if not self.chat_id:
errors.append("chat_id不能为空")
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
errors.append("prompt_mode必须是's4u''normal''minimal'")
if self.max_context_messages <= 0:
errors.append("max_context_messages必须大于0")
return errors
def get_needed_build_tasks(self) -> List[str]:
"""获取需要执行的任务列表"""
tasks = []
if self.enable_expression and not self.expression_habits_block:
tasks.append("expression_habits")
if self.enable_memory and not self.memory_block:
tasks.append("memory_block")
if self.enable_relation and not self.relation_info_block:
tasks.append("relation_info")
if self.enable_tool and not self.tool_info_block:
tasks.append("tool_info")
if self.enable_knowledge and not self.knowledge_prompt:
tasks.append("knowledge_info")
if self.enable_cross_context and not self.cross_context_block:
tasks.append("cross_context")
return tasks
@classmethod
def from_legacy_params(cls, **kwargs) -> "SmartPromptParameters":
"""
从旧版参数创建新参数对象
Args:
**kwargs: 旧版参数
Returns:
SmartPromptParameters: 新参数对象
"""
return cls(
# 基础参数
chat_id=kwargs.get("chat_id", ""),
is_group_chat=kwargs.get("is_group_chat", False),
sender=kwargs.get("sender", ""),
target=kwargs.get("target", ""),
reply_to=kwargs.get("reply_to", ""),
extra_info=kwargs.get("extra_info", ""),
prompt_mode=kwargs.get("current_prompt_mode", "s4u"),
# 功能开关
enable_tool=kwargs.get("enable_tool", True),
enable_memory=kwargs.get("enable_memory", True),
enable_expression=kwargs.get("enable_expression", True),
enable_relation=kwargs.get("enable_relation", True),
enable_cross_context=kwargs.get("enable_cross_context", True),
enable_knowledge=kwargs.get("enable_knowledge", True),
# 性能控制
max_context_messages=kwargs.get("max_context_messages", 50),
debug_mode=kwargs.get("debug_mode", False),
# 聊天历史和上下文
chat_target_info=kwargs.get("chat_target_info"),
message_list_before_now_long=kwargs.get("message_list_before_now_long", []),
message_list_before_short=kwargs.get("message_list_before_short", []),
chat_talking_prompt_short=kwargs.get("chat_talking_prompt_short", ""),
target_user_info=kwargs.get("target_user_info"),
# 已构建的内容块
expression_habits_block=kwargs.get("expression_habits_block", ""),
relation_info_block=kwargs.get("relation_info", ""),
memory_block=kwargs.get("memory_block", ""),
tool_info_block=kwargs.get("tool_info", ""),
knowledge_prompt=kwargs.get("knowledge_prompt", ""),
cross_context_block=kwargs.get("cross_context_block", ""),
# 其他内容块
keywords_reaction_prompt=kwargs.get("keywords_reaction_prompt", ""),
extra_info_block=kwargs.get("extra_info_block", ""),
time_block=kwargs.get("time_block", ""),
identity_block=kwargs.get("identity_block", ""),
schedule_block=kwargs.get("schedule_block", ""),
moderation_prompt_block=kwargs.get("moderation_prompt_block", ""),
reply_target_block=kwargs.get("reply_target_block", ""),
mood_prompt=kwargs.get("mood_prompt", ""),
action_descriptions=kwargs.get("action_descriptions", ""),
# 可用动作信息
available_actions=kwargs.get("available_actions", None),
)

View File

@@ -1,132 +0,0 @@
"""
共享提示词工具模块 - 消除重复代码
提供统一的工具函数供DefaultReplyer和SmartPrompt使用
"""
import re
from typing import Dict, Any, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.apis import cross_context_api
logger = get_logger("prompt_utils")
class PromptUtils:
"""提示词工具类 - 提供共享功能,移除缓存相关功能和依赖检查"""
@staticmethod
def parse_reply_target(target_message: str) -> Tuple[str, str]:
"""
解析回复目标消息 - 统一实现
Args:
target_message: 目标消息,格式为 "发送者:消息内容""发送者:消息内容"
Returns:
Tuple[str, str]: (发送者名称, 消息内容)
"""
sender = ""
target = ""
# 添加None检查防止NoneType错误
if target_message is None:
return sender, target
if ":" in target_message or "" in target_message:
# 使用正则表达式匹配中文或英文冒号
parts = re.split(pattern=r"[:]", string=target_message, maxsplit=1)
if len(parts) == 2:
sender = parts[0].strip()
target = parts[1].strip()
return sender, target
@staticmethod
async def build_relation_info(chat_id: str, reply_to: str) -> str:
"""
构建关系信息 - 统一实现
Args:
chat_id: 聊天ID
reply_to: 回复目标字符串
Returns:
str: 关系信息字符串
"""
if not global_config.relationship.enable_relationship:
return ""
from src.person_info.relationship_fetcher import relationship_fetcher_manager
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
if not reply_to:
return ""
sender, text = PromptUtils.parse_reply_target(reply_to)
if not sender or not text:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if not person_id:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
@staticmethod
async def build_cross_context(
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
) -> str:
"""
构建跨群聊上下文 - 统一实现完全继承DefaultReplyer功能
"""
if not global_config.cross_context.enable:
return ""
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
if not other_chat_raw_ids:
return ""
chat_stream = get_chat_manager().get_stream(chat_id)
if not chat_stream:
return ""
if current_prompt_mode == "normal":
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
elif current_prompt_mode == "s4u":
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
return ""
@staticmethod
def parse_reply_target_id(reply_to: str) -> str:
"""
解析回复目标中的用户ID
Args:
reply_to: 回复目标字符串
Returns:
str: 用户ID
"""
if not reply_to:
return ""
# 复用parse_reply_target方法的逻辑
sender, _ = PromptUtils.parse_reply_target(reply_to)
if not sender:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if person_id:
user_id = person_info_manager.get_value_sync(person_id, "user_id")
return str(user_id) if user_id else ""
return ""

View File

@@ -1,938 +0,0 @@
"""
智能Prompt系统 - 完全重构版本
基于原有DefaultReplyer的完整功能集成使用新的参数结构
解决实现质量不高、功能集成不完整和错误处理不足的问题
"""
import asyncio
import time
from datetime import datetime
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Tuple
from src.chat.utils.prompt_builder import global_prompt_manager, Prompt
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import (
build_readable_messages,
)
from src.person_info.person_info import get_person_info_manager
from src.chat.utils.prompt_utils import PromptUtils
from src.chat.utils.prompt_parameters import SmartPromptParameters
logger = get_logger("smart_prompt")
@dataclass
class ChatContext:
"""聊天上下文信息"""
chat_id: str = ""
platform: str = ""
is_group: bool = False
user_id: str = ""
user_nickname: str = ""
group_id: Optional[str] = None
timestamp: datetime = field(default_factory=datetime.now)
class SmartPromptBuilder:
"""重构的智能提示词构建器 - 统一错误处理和功能集成,移除缓存机制和依赖检查"""
def __init__(self):
# 移除缓存相关初始化
pass
async def build_context_data(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""并行构建完整的上下文数据 - 移除缓存机制和依赖检查"""
# 并行执行所有构建任务
start_time = time.time()
timing_logs = {}
try:
# 准备构建任务
tasks = []
task_names = []
# 初始化预构建参数,使用新的结构
pre_built_params = {}
if params.expression_habits_block:
pre_built_params["expression_habits_block"] = params.expression_habits_block
if params.relation_info_block:
pre_built_params["relation_info_block"] = params.relation_info_block
if params.memory_block:
pre_built_params["memory_block"] = params.memory_block
if params.tool_info_block:
pre_built_params["tool_info_block"] = params.tool_info_block
if params.knowledge_prompt:
pre_built_params["knowledge_prompt"] = params.knowledge_prompt
if params.cross_context_block:
pre_built_params["cross_context_block"] = params.cross_context_block
# 根据新的参数结构确定要构建的项
if params.enable_expression and not pre_built_params.get("expression_habits_block"):
tasks.append(self._build_expression_habits(params))
task_names.append("expression_habits")
if params.enable_memory and not pre_built_params.get("memory_block"):
tasks.append(self._build_memory_block(params))
task_names.append("memory_block")
if params.enable_relation and not pre_built_params.get("relation_info_block"):
tasks.append(self._build_relation_info(params))
task_names.append("relation_info")
# 添加mai_think上下文构建任务
if not pre_built_params.get("mai_think"):
tasks.append(self._build_mai_think_context(params))
task_names.append("mai_think_context")
if params.enable_tool and not pre_built_params.get("tool_info_block"):
tasks.append(self._build_tool_info(params))
task_names.append("tool_info")
if params.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
tasks.append(self._build_knowledge_info(params))
task_names.append("knowledge_info")
if params.enable_cross_context and not pre_built_params.get("cross_context_block"):
tasks.append(self._build_cross_context(params))
task_names.append("cross_context")
# 性能优化:根据任务数量动态调整超时时间
base_timeout = 10.0 # 基础超时时间
task_timeout = 2.0 # 每个任务的超时时间
timeout_seconds = min(
max(base_timeout, len(tasks) * task_timeout), # 根据任务数量计算超时
30.0, # 最大超时时间
)
# 性能优化:限制并发任务数量,避免资源耗尽
max_concurrent_tasks = 5 # 最大并发任务数
if len(tasks) > max_concurrent_tasks:
# 分批执行任务
results = []
for i in range(0, len(tasks), max_concurrent_tasks):
batch_tasks = tasks[i : i + max_concurrent_tasks]
batch_names = task_names[i : i + max_concurrent_tasks]
batch_results = await asyncio.wait_for(
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
)
results.extend(batch_results)
else:
# 一次性执行所有任务
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
)
# 处理结果并收集性能数据
context_data = {}
for i, result in enumerate(results):
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
if isinstance(result, Exception):
logger.error(f"构建任务{task_name}失败: {str(result)}")
elif isinstance(result, dict):
# 结果格式: {component_name: value}
context_data.update(result)
# 记录耗时过长的任务
if task_name in timing_logs and timing_logs[task_name] > 8.0:
logger.warning(f"构建任务{task_name}耗时过长: {timing_logs[task_name]:.2f}s")
# 添加预构建的参数
for key, value in pre_built_params.items():
if value:
context_data[key] = value
except asyncio.TimeoutError:
logger.error(f"构建超时 ({timeout_seconds}s)")
context_data = {}
# 添加预构建的参数,即使在超时情况下
for key, value in pre_built_params.items():
if value:
context_data[key] = value
# 构建聊天历史 - 根据模式不同
if params.prompt_mode == "s4u":
await self._build_s4u_chat_context(context_data, params)
else:
await self._build_normal_chat_context(context_data, params)
# 补充基础信息
context_data.update(
{
"keywords_reaction_prompt": params.keywords_reaction_prompt,
"extra_info_block": params.extra_info_block,
"time_block": params.time_block or f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
"identity": params.identity_block,
"schedule_block": params.schedule_block,
"moderation_prompt": params.moderation_prompt_block,
"reply_target_block": params.reply_target_block,
"mood_state": params.mood_prompt,
"action_descriptions": params.action_descriptions,
}
)
total_time = time.time() - start_time
if timing_logs:
timing_str = "; ".join([f"{name}: {time:.2f}s" for name, time in timing_logs.items()])
logger.info(f"构建任务耗时: {timing_str}")
logger.debug(f"构建完成,总耗时: {total_time:.2f}s")
return context_data
async def _build_s4u_chat_context(self, context_data: Dict[str, Any], params: SmartPromptParameters) -> None:
"""构建S4U模式的聊天上下文 - 使用新参数结构"""
if not params.message_list_before_now_long:
return
# 使用共享工具构建分离历史
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
params.message_list_before_now_long,
params.target_user_info.get("user_id") if params.target_user_info else "",
params.sender,
)
context_data["core_dialogue_prompt"] = core_dialogue
context_data["background_dialogue_prompt"] = background_dialogue
async def _build_normal_chat_context(self, context_data: Dict[str, Any], params: SmartPromptParameters) -> None:
"""构建normal模式的聊天上下文 - 使用新参数结构"""
if not params.chat_talking_prompt_short:
return
context_data["chat_info"] = f"""群里的聊天内容:
{params.chat_talking_prompt_short}"""
async def _build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
) -> Tuple[str, str]:
"""构建S4U风格的分离对话prompt - 完整实现"""
core_dialogue_list = []
bot_id = str(global_config.bot.qq_account)
# 过滤消息分离bot和目标用户的对话 vs 其他用户的对话
for msg_dict in message_list_before_now:
try:
msg_user_id = str(msg_dict.get("user_id"))
reply_to = msg_dict.get("reply_to", "")
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
# bot 和目标用户的对话
core_dialogue_list.append(msg_dict)
except Exception as e:
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
# 构建背景对话 prompt
all_dialogue_prompt = ""
if message_list_before_now:
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
all_dialogue_prompt_str = build_readable_messages(
latest_25_msgs,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True,
)
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
# 构建核心对话 prompt
core_dialogue_prompt = ""
if core_dialogue_list:
# 检查最新五条消息中是否包含bot自己说的消息
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
# logger.info(f"最新五条消息:{latest_5_messages}")
# logger.info(f"最新五条消息中是否包含bot自己说的消息{has_bot_message}")
# 如果最新五条消息中不包含bot的消息则返回空字符串
if not has_bot_message:
core_dialogue_prompt = ""
else:
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
core_dialogue_prompt_str = build_readable_messages(
core_dialogue_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
core_dialogue_prompt = f"""--------------------------------
这是你和{sender}的对话,你们正在交流中:
{core_dialogue_prompt_str}
--------------------------------
"""
return core_dialogue_prompt, all_dialogue_prompt
async def _build_mai_think_context(self, params: SmartPromptParameters) -> Any:
"""构建mai_think上下文 - 完全继承DefaultReplyer功能"""
from src.mais4u.mai_think import mai_thinking_manager
# 获取mai_think实例
mai_think = mai_thinking_manager.get_mai_think(params.chat_id)
# 设置mai_think的上下文信息
mai_think.memory_block = params.memory_block or ""
mai_think.relation_info_block = params.relation_info_block or ""
mai_think.time_block = params.time_block or f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# 设置聊天目标信息
if params.is_group_chat:
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
else:
chat_target_name = "对方"
if params.chat_target_info:
chat_target_name = (
params.chat_target_info.get("person_name") or params.chat_target_info.get("user_nickname") or "对方"
)
chat_target_1 = await global_prompt_manager.format_prompt(
"chat_target_private1", sender_name=chat_target_name
)
chat_target_2 = await global_prompt_manager.format_prompt(
"chat_target_private2", sender_name=chat_target_name
)
mai_think.chat_target = chat_target_1
mai_think.chat_target_2 = chat_target_2
mai_think.chat_info = params.chat_talking_prompt_short or ""
mai_think.mood_state = params.mood_prompt or ""
mai_think.identity = params.identity_block or ""
mai_think.sender = params.sender
mai_think.target = params.target
# 返回mai_think实例以便后续使用
return mai_think
def _parse_reply_target_id(self, reply_to: str) -> str:
"""解析回复目标中的用户ID"""
if not reply_to:
return ""
# 复用_parse_reply_target方法的逻辑
sender, _ = self._parse_reply_target(reply_to)
if not sender:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if person_id:
user_id = person_info_manager.get_value_sync(person_id, "user_id")
return str(user_id) if user_id else ""
async def _build_expression_habits(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建表达习惯 - 使用共享工具类完全继承DefaultReplyer功能"""
# 检查是否允许在此聊天流中使用表达
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(params.chat_id)
if not use_expression:
return {"expression_habits_block": ""}
from src.chat.express.expression_selector import expression_selector
style_habits = []
grammar_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
try:
selected_expressions = await expression_selector.select_suitable_expressions_llm(
params.chat_id, params.chat_talking_prompt_short, max_num=8, min_num=2, target_message=params.target
)
except Exception as e:
logger.error(f"选择表达方式失败: {e}")
selected_expressions = []
if selected_expressions:
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
for expr in selected_expressions:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
expr_type = expr.get("type", "style")
if expr_type == "grammar":
grammar_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
style_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
# 不再在replyer中进行随机选择全部交给处理器处理
style_habits_str = "\n".join(style_habits)
grammar_habits_str = "\n".join(grammar_habits)
# 动态构建expression habits块
expression_habits_block = ""
expression_habits_title = ""
if style_habits_str.strip():
expression_habits_title = (
"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:"
)
expression_habits_block += f"{style_habits_str}\n"
if grammar_habits_str.strip():
expression_habits_title = (
"你可以选择下面的句法进行回复,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式使用:"
)
expression_habits_block += f"{grammar_habits_str}\n"
if style_habits_str.strip() and grammar_habits_str.strip():
expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中。"
return {"expression_habits_block": f"{expression_habits_title}\n{expression_habits_block}"}
async def _build_memory_block(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建记忆块 - 使用共享工具类完全继承DefaultReplyer功能"""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
from src.chat.memory_system.memory_activator import MemoryActivator
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
instant_memory = None
# 初始化记忆激活器
try:
memory_activator = MemoryActivator()
# 获取长期记忆
running_memories = await memory_activator.activate_memory_with_chat_history(
target_message=params.target, chat_history_prompt=params.chat_talking_prompt_short
)
except Exception as e:
logger.error(f"激活记忆失败: {e}")
running_memories = []
# 处理瞬时记忆
if global_config.memory.enable_instant_memory:
# 使用异步记忆包装器(最优化的非阻塞模式)
try:
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
# 获取异步记忆包装器
async_memory = get_async_instant_memory(params.chat_id)
# 后台存储聊天历史(完全非阻塞)
async_memory.store_memory_background(params.chat_talking_prompt_short)
# 快速检索记忆最大超时2秒
instant_memory = await async_memory.get_memory_with_fallback(params.target, max_timeout=2.0)
logger.info(f"异步瞬时记忆:{instant_memory}")
except ImportError:
# 如果异步包装器不可用,尝试使用异步记忆管理器
try:
from src.chat.memory_system.async_memory_optimizer import (
retrieve_memory_nonblocking,
store_memory_nonblocking,
)
# 异步存储聊天历史(非阻塞)
asyncio.create_task(
store_memory_nonblocking(chat_id=params.chat_id, content=params.chat_talking_prompt_short)
)
# 尝试从缓存获取瞬时记忆
instant_memory = await retrieve_memory_nonblocking(chat_id=params.chat_id, query=params.target)
# 如果没有缓存结果,快速检索一次
if instant_memory is None:
try:
# 使用VectorInstantMemoryV2实例
instant_memory_system = VectorInstantMemoryV2(chat_id=params.chat_id, retention_hours=1)
instant_memory = await asyncio.wait_for(
instant_memory_system.get_memory_for_context(params.target), timeout=1.5
)
except asyncio.TimeoutError:
logger.warning("瞬时记忆检索超时,使用空结果")
instant_memory = ""
logger.info(f"向量瞬时记忆:{instant_memory}")
except ImportError:
# 最后的fallback使用原有逻辑但加上超时控制
logger.warning("异步记忆系统不可用,使用带超时的同步方式")
# 使用VectorInstantMemoryV2实例
instant_memory_system = VectorInstantMemoryV2(chat_id=params.chat_id, retention_hours=1)
# 异步存储聊天历史
asyncio.create_task(instant_memory_system.store_message(params.chat_talking_prompt_short))
# 带超时的记忆检索
try:
instant_memory = await asyncio.wait_for(
instant_memory_system.get_memory_for_context(params.target),
timeout=1.0, # 最保守的1秒超时
)
except asyncio.TimeoutError:
logger.warning("瞬时记忆检索超时,跳过记忆获取")
instant_memory = ""
except Exception as e:
logger.error(f"瞬时记忆检索失败: {e}")
instant_memory = ""
logger.info(f"同步瞬时记忆:{instant_memory}")
except Exception as e:
logger.error(f"瞬时记忆系统异常: {e}")
instant_memory = ""
# 构建记忆字符串,即使某种记忆为空也要继续
memory_str = ""
has_any_memory = False
# 添加长期记忆
if running_memories:
if not memory_str:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memories:
memory_str += f"- {running_memory['content']}\n"
has_any_memory = True
# 添加瞬时记忆
if instant_memory:
if not memory_str:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
memory_str += f"- {instant_memory}\n"
has_any_memory = True
# 注入视频分析结果引导语
memory_str = self._inject_video_prompt_if_needed(params.target, memory_str)
# 只有当完全没有任何记忆时才返回空字符串
return {"memory_block": memory_str if has_any_memory else ""}
def _inject_video_prompt_if_needed(self, target: str, memory_str: str) -> str:
"""统一视频分析结果注入逻辑"""
if target and ("[视频内容]" in target or "好的,我将根据您提供的" in target):
video_prompt_injection = (
"\n请注意,以上内容是你刚刚观看的视频,请以第一人称分享你的观后感,而不是在分析一份报告。"
)
return memory_str + video_prompt_injection
return memory_str
async def _build_relation_info(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建关系信息 - 使用共享工具类"""
try:
relation_info = await PromptUtils.build_relation_info(params.chat_id, params.reply_to)
return {"relation_info_block": relation_info}
except Exception as e:
logger.error(f"构建关系信息失败: {e}")
return {"relation_info_block": ""}
async def _build_tool_info(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建工具信息 - 使用共享工具类完全继承DefaultReplyer功能"""
if not params.enable_tool:
return {"tool_info_block": ""}
if not params.reply_to:
return {"tool_info_block": ""}
sender, text = PromptUtils.parse_reply_target(params.reply_to)
if not text:
return {"tool_info_block": ""}
from src.plugin_system.core.tool_use import ToolExecutor
# 使用工具执行器获取信息
try:
tool_executor = ToolExecutor(chat_id=params.chat_id)
tool_results, _, _ = await tool_executor.execute_from_chat_message(
sender=sender, target_message=text, chat_history=params.chat_talking_prompt_short, return_details=False
)
if tool_results:
tool_info_str = "以下是你通过工具获取到的实时信息:\n"
for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}{result_type}: {content}\n"
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
logger.info(f"获取到 {len(tool_results)} 个工具结果")
return {"tool_info_block": tool_info_str}
else:
logger.debug("未获取到任何工具结果")
return {"tool_info_block": ""}
except Exception as e:
logger.error(f"工具信息获取失败: {e}")
return {"tool_info_block": ""}
async def _build_knowledge_info(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建知识信息 - 使用共享工具类完全继承DefaultReplyer功能"""
if not params.reply_to:
logger.debug("没有回复对象,跳过获取知识库内容")
return {"knowledge_prompt": ""}
sender, content = PromptUtils.parse_reply_target(params.reply_to)
if not content:
logger.debug("回复对象内容为空,跳过获取知识库内容")
return {"knowledge_prompt": ""}
logger.debug(
f"获取知识库内容,元消息:{params.chat_talking_prompt_short[:30]}...,消息长度: {len(params.chat_talking_prompt_short)}"
)
# 从LPMM知识库获取知识
try:
# 检查LPMM知识库是否启用
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用跳过获取知识库内容")
return {"knowledge_prompt": ""}
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
from src.plugin_system.apis import llm_api
from src.config.config import model_config
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname
prompt = await global_prompt_manager.format_prompt(
"lpmm_get_knowledge_prompt",
bot_name=bot_name,
time_now=time_now,
chat_history=params.chat_talking_prompt_short,
sender=sender,
target_message=content,
)
_, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
)
if tool_calls:
from src.plugin_system.core.tool_use import ToolExecutor
tool_executor = ToolExecutor(chat_id=params.chat_id)
result = await tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
if not result or not result.get("content"):
logger.debug("从LPMM知识库获取知识失败返回空知识...")
return {"knowledge_prompt": ""}
found_knowledge_from_lpmm = result.get("content", "")
logger.debug(
f"从LPMM知识库获取知识相关信息{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
)
return {
"knowledge_prompt": f"你有以下这些**知识**\n{found_knowledge_from_lpmm}\n请你**记住上面的知识**,之后可能会用到。\n"
}
else:
logger.debug("从LPMM知识库获取知识失败可能是从未导入过知识返回空知识...")
return {"knowledge_prompt": ""}
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return {"knowledge_prompt": ""}
async def _build_cross_context(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建跨群上下文 - 使用共享工具类"""
try:
cross_context = await PromptUtils.build_cross_context(
params.chat_id, params.prompt_mode, params.target_user_info
)
return {"cross_context_block": cross_context}
except Exception as e:
logger.error(f"构建跨群上下文失败: {e}")
return {"cross_context_block": ""}
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息 - 使用共享工具类"""
return PromptUtils.parse_reply_target(target_message)
class SmartPrompt:
"""重构的智能提示词核心类 - 移除缓存机制和依赖检查,简化架构"""
def __init__(
self,
template_name: Optional[str] = None,
parameters: Optional[SmartPromptParameters] = None,
):
self.parameters = parameters or SmartPromptParameters()
self.template_name = template_name or self._get_default_template()
self.builder = SmartPromptBuilder()
def _get_default_template(self) -> str:
"""根据模式选择默认模板"""
if self.parameters.prompt_mode == "s4u":
return "s4u_style_prompt"
elif self.parameters.prompt_mode == "normal":
return "normal_style_prompt"
else:
return "default_expressor_prompt"
async def build_prompt(self) -> str:
"""构建最终的Prompt文本 - 移除缓存机制和依赖检查"""
# 参数验证
errors = self.parameters.validate()
if errors:
logger.error(f"参数验证失败: {', '.join(errors)}")
raise ValueError(f"参数验证失败: {', '.join(errors)}")
start_time = time.time()
try:
# 构建基础上下文的完整映射
context_data = await self.builder.build_context_data(self.parameters)
# 检查关键上下文数据
if not context_data or not isinstance(context_data, dict):
logger.error("构建的上下文数据无效")
raise ValueError("构建的上下文数据无效")
# 获取模板
template = await self._get_template()
if template is None:
logger.error("无法获取模板")
raise ValueError("无法获取模板")
# 根据模式传递不同的参数
if self.parameters.prompt_mode == "s4u":
result = await self._build_s4u_prompt(template, context_data)
elif self.parameters.prompt_mode == "normal":
result = await self._build_normal_prompt(template, context_data)
else:
result = await self._build_default_prompt(template, context_data)
# 记录性能数据
total_time = time.time() - start_time
logger.debug(f"SmartPrompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
return result
except asyncio.TimeoutError as e:
logger.error(f"构建Prompt超时: {e}")
raise TimeoutError(f"构建Prompt超时: {e}")
except Exception as e:
logger.error(f"构建Prompt失败: {e}")
raise RuntimeError(f"构建Prompt失败: {e}")
async def _get_template(self) -> Optional[Prompt]:
"""获取模板"""
try:
return await global_prompt_manager.get_prompt_async(self.template_name)
except Exception as e:
logger.error(f"获取模板 {self.template_name} 失败: {e}")
raise RuntimeError(f"获取模板 {self.template_name} 失败: {e}")
async def _build_s4u_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str:
"""构建S4U模式的完整Prompt - 使用新参数结构"""
params = {
**context_data,
"expression_habits_block": context_data.get("expression_habits_block", ""),
"tool_info_block": context_data.get("tool_info_block", ""),
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"sender_name": self.parameters.sender,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""),
"time_block": context_data.get("time_block", ""),
"core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
return await global_prompt_manager.format_prompt(self.template_name, **params)
async def _build_normal_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str:
"""构建Normal模式的完整Prompt - 使用新参数结构"""
params = {
**context_data,
"expression_habits_block": context_data.get("expression_habits_block", ""),
"tool_info_block": context_data.get("tool_info_block", ""),
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""),
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"config_expression_style": global_config.personality.reply_style,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
return await global_prompt_manager.format_prompt(self.template_name, **params)
async def _build_default_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str:
"""构建默认模式的Prompt - 使用新参数结构"""
params = {
"expression_habits_block": context_data.get("expression_habits_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"chat_target": "",
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"chat_target_2": "",
"reply_target_block": context_data.get("reply_target_block", ""),
"raw_reply": self.parameters.target,
"reason": "",
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
return await global_prompt_manager.format_prompt(self.template_name, **params)
# 工厂函数 - 简化创建 - 更新参数结构
def create_smart_prompt(
chat_id: str = "", sender_name: str = "", target_message: str = "", reply_to: str = "", **kwargs
) -> SmartPrompt:
"""快速创建智能Prompt实例的工厂函数 - 使用新参数结构"""
# 使用新的参数结构
parameters = SmartPromptParameters(
chat_id=chat_id, sender=sender_name, target=target_message, reply_to=reply_to, **kwargs
)
return SmartPrompt(parameters=parameters)
class SmartPromptHealthChecker:
"""SmartPrompt健康检查器 - 移除依赖检查"""
@staticmethod
async def check_system_health() -> Dict[str, Any]:
"""检查系统健康状态 - 移除依赖检查"""
health_status = {"status": "healthy", "components": {}, "issues": []}
try:
# 检查配置
try:
from src.config.config import global_config
health_status["components"]["config"] = "ok"
# 检查关键配置项
if not hasattr(global_config, "personality") or not hasattr(global_config.personality, "prompt_mode"):
health_status["issues"].append("缺少personality.prompt_mode配置")
health_status["status"] = "degraded"
if not hasattr(global_config, "memory") or not hasattr(global_config.memory, "enable_memory"):
health_status["issues"].append("缺少memory.enable_memory配置")
except Exception as e:
health_status["components"]["config"] = f"failed: {str(e)}"
health_status["issues"].append("配置加载失败")
health_status["status"] = "unhealthy"
# 检查Prompt模板
try:
required_templates = ["s4u_style_prompt", "normal_style_prompt", "default_expressor_prompt"]
for template_name in required_templates:
try:
await global_prompt_manager.get_prompt_async(template_name)
health_status["components"][f"template_{template_name}"] = "ok"
except Exception as e:
health_status["components"][f"template_{template_name}"] = f"failed: {str(e)}"
health_status["issues"].append(f"模板{template_name}加载失败")
health_status["status"] = "degraded"
except Exception as e:
health_status["components"]["prompt_templates"] = f"failed: {str(e)}"
health_status["issues"].append("Prompt模板检查失败")
health_status["status"] = "unhealthy"
return health_status
except Exception as e:
return {"status": "unhealthy", "components": {}, "issues": [f"健康检查异常: {str(e)}"]}
@staticmethod
async def run_performance_test() -> Dict[str, Any]:
"""运行性能测试"""
test_results = {"status": "completed", "tests": {}, "summary": {}}
try:
# 创建测试参数
test_params = SmartPromptParameters(
chat_id="test_chat",
sender="test_user",
target="test_message",
reply_to="test_user:test_message",
prompt_mode="s4u",
)
# 测试不同模式下的构建性能
modes = ["s4u", "normal", "minimal"]
for mode in modes:
test_params.prompt_mode = mode
smart_prompt = SmartPrompt(parameters=test_params)
# 运行多次测试取平均值
times = []
for _ in range(3):
start_time = time.time()
try:
await smart_prompt.build_prompt()
end_time = time.time()
times.append(end_time - start_time)
except Exception as e:
times.append(float("inf"))
logger.error(f"性能测试失败 (模式: {mode}): {e}")
# 计算统计信息
valid_times = [t for t in times if t != float("inf")]
if valid_times:
avg_time = sum(valid_times) / len(valid_times)
min_time = min(valid_times)
max_time = max(valid_times)
test_results["tests"][mode] = {
"avg_time": avg_time,
"min_time": min_time,
"max_time": max_time,
"success_rate": len(valid_times) / len(times),
}
else:
test_results["tests"][mode] = {
"avg_time": float("inf"),
"min_time": float("inf"),
"max_time": float("inf"),
"success_rate": 0,
}
# 计算总体统计
all_avg_times = [
test["avg_time"] for test in test_results["tests"].values() if test["avg_time"] != float("inf")
]
if all_avg_times:
test_results["summary"] = {
"overall_avg_time": sum(all_avg_times) / len(all_avg_times),
"fastest_mode": min(test_results["tests"].items(), key=lambda x: x[1]["avg_time"])[0],
"slowest_mode": max(test_results["tests"].items(), key=lambda x: x[1]["avg_time"])[0],
}
return test_results
except Exception as e:
return {"status": "failed", "tests": {}, "summary": {}, "error": str(e)}

View File

@@ -1,6 +1,6 @@
from src.chat.message_receive.chat_stream import get_chat_manager
import time
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.chat.message_receive.message import MessageRecvS4U

View File

@@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api

View File

@@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
from src.mais4u.constant_s4u import ENABLE_S4U

View File

@@ -1,6 +1,6 @@
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
import time
from src.chat.utils.utils import get_recent_group_speaker

View File

@@ -6,7 +6,7 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.llm_models.utils_model import LLMRequest
from src.manager.async_task_manager import AsyncTask, async_task_manager

View File

@@ -9,7 +9,7 @@ from json_repair import repair_json
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager

View File

@@ -23,17 +23,20 @@ class BaseEventHandler(ABC):
"""是否拦截消息,默认为否"""
init_subscribe: List[Union[EventType, str]] = [EventType.UNKNOWN]
"""初始化时订阅的事件名称"""
plugin_name = None
def __init__(self):
self.log_prefix = "[EventHandler]"
"""对应插件名"""
self.plugin_config: Optional[Dict] = None
"""插件配置字典"""
self.subscribed_events = []
"""订阅的事件列表"""
if EventType.UNKNOWN in self.init_subscribe:
raise NotImplementedError("事件处理器必须指定 event_type")
from src.plugin_system.core.component_registry import component_registry
self.plugin_config = component_registry.get_plugin_config(self.plugin_name)
@abstractmethod
async def execute(self, kwargs: dict | None) -> Tuple[bool, bool, Optional[str]]:
"""执行事件处理的抽象方法,子类必须实现
@@ -89,15 +92,7 @@ class BaseEventHandler(ABC):
weight=cls.weight,
intercept_message=cls.intercept_message,
)
def set_plugin_config(self, plugin_config: Dict) -> None:
"""设置插件配置
Args:
plugin_config (dict): 插件配置字典
"""
self.plugin_config = plugin_config
def set_plugin_name(self, plugin_name: str) -> None:
"""设置插件名称

View File

@@ -248,6 +248,7 @@ class ComponentRegistry:
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
return False
handler_class.plugin_name = handler_info.plugin_name
self._event_handler_registry[handler_name] = handler_class
if not handler_info.enabled:

View File

@@ -145,11 +145,12 @@ class EventManager:
logger.info(f"事件 {event_name} 已禁用")
return True
def register_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool:
"""注册事件处理器
Args:
handler_class (Type[BaseEventHandler]): 事件处理器类
plugin_config (Optional[dict]): 插件配置字典默认为None
Returns:
bool: 注册成功返回True已存在返回False
@@ -163,7 +164,12 @@ class EventManager:
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
return False
self._event_handlers[handler_name] = handler_class()
# 创建事件处理器实例,传递插件配置
handler_instance = handler_class()
if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'):
handler_instance.set_plugin_config(plugin_config)
self._event_handlers[handler_name] = handler_instance
# 处理init_subscribe缓存失败的订阅
if self._event_handlers[handler_name].init_subscribe:

View File

@@ -6,7 +6,7 @@ from src.plugin_system.core.global_announcement_manager import global_announceme
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
import inspect
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger

View File

@@ -39,8 +39,9 @@ class EmojiAction(BaseAction):
llm_judge_prompt = """
判定是否需要使用表情动作的条件:
1. 用户明确要求使用表情包
2. 这是一个适合表达强烈情绪的场合
3. 不要发送太多表情包,如果你已经发送过多个表情包则回答""
2. 这是一个适合表达情绪的场合
3. 发表情包能使当前对话更有趣
4. 不要发送太多表情包,如果你已经发送过多个表情包则回答""
请回答""""
"""

View File

@@ -0,0 +1,279 @@
log/
logs/
out/
.env
.env.*
.cursor
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
uv.lock
llm_statistics.txt
mongodb
napcat
run_dev.bat
elua.confirmed
# C extensions
*.so
/results
config_backup/
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# PyPI configuration file
.pypirc
# jieba
jieba.cache
# .vscode
!.vscode/settings.json
# direnv
/.direnv
# JetBrains
.idea
*.iml
*.ipr
# PyEnv
# If using PyEnv and configured to use a specific Python version locally
# a .local-version file will be created in the root of the project to specify the version.
.python-version
OtherRes.txt
/eula.confirmed
/privacy.confirmed
logs
.ruff_cache
.vscode
/config/*
config/old/bot_config_20250405_212257.toml
temp/
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk
config.toml
feature.toml
config.toml.back
test
data/NapcatAdapter.db
data/NapcatAdapter.db-shm
data/NapcatAdapter.db-wal

View File

@@ -0,0 +1 @@
PLUGIN_NAME = "napcat_adapter"

View File

@@ -0,0 +1,42 @@
{
"manifest_version": 1,
"name": "napcat_plugin",
"version": "1.0.0",
"description": "基于OneBot 11协议的NapCat QQ协议插件提供完整的QQ机器人API接口使用现有adapter连接",
"author": {
"name": "Windpicker_owo",
"url": "https://github.com/Windpicker-owo"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.10.0",
"max_version": "0.10.0"
},
"homepage_url": "https://github.com/Windpicker-owo/InternetSearchPlugin",
"repository_url": "https://github.com/Windpicker-owo/InternetSearchPlugin",
"keywords": ["qq", "bot", "napcat", "onebot", "api", "websocket"],
"categories": ["protocol"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"components": [
{
"type": "tool",
"name": "napcat_tool",
"description": "NapCat QQ协议综合工具提供消息发送、群管理、好友管理、文件操作等完整功能"
}
],
"features": [
"消息发送与接收",
"群管理功能",
"好友管理功能",
"文件上传下载",
"AI语音功能",
"群签到与戳一戳",
"现有adapter连接"
]
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,357 @@
import asyncio
import json
import inspect
import websockets as Server
from . import event_types, CONSTS, event_handlers
from typing import List
from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.apis import config_api
from src.common.logger import get_logger
from .src.message_chunker import chunker, reassembler
from .src.recv_handler.message_handler import message_handler
from .src.recv_handler.meta_event_handler import meta_event_handler
from .src.recv_handler.notice_handler import notice_handler
from .src.recv_handler.message_sending import message_send_instance
from .src.send_handler import send_handler
from .src.config.migrate_features import auto_migrate_features
from .src.mmc_com_layer import mmc_start_com, router, mmc_stop_com
from .src.response_pool import put_response, check_timeout_response
from .src.websocket_manager import websocket_manager
logger = get_logger("napcat_adapter")
message_queue = asyncio.Queue()
def get_classes_in_module(module):
classes = []
for name, member in inspect.getmembers(module):
if inspect.isclass(member):
classes.append(member)
return classes
async def message_recv(server_connection: Server.ServerConnection):
await message_handler.set_server_connection(server_connection)
asyncio.create_task(notice_handler.set_server_connection(server_connection))
await send_handler.set_server_connection(server_connection)
async for raw_message in server_connection:
logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
decoded_raw_message: dict = json.loads(raw_message)
try:
# 首先尝试解析原始消息
decoded_raw_message: dict = json.loads(raw_message)
# 检查是否是切片消息 (来自 MMC)
if chunker.is_chunk_message(decoded_raw_message):
logger.debug("接收到切片消息,尝试重组")
# 尝试重组消息
reassembled_message = await reassembler.add_chunk(decoded_raw_message)
if reassembled_message:
# 重组完成,处理完整消息
logger.debug("消息重组完成,处理完整消息")
decoded_raw_message = reassembled_message
else:
# 切片尚未完整,继续等待更多切片
logger.debug("等待更多切片...")
continue
# 处理完整消息(可能是重组后的,也可能是原本就完整的)
post_type = decoded_raw_message.get("post_type")
if post_type in ["meta_event", "message", "notice"]:
await message_queue.put(decoded_raw_message)
elif post_type is None:
await put_response(decoded_raw_message)
except json.JSONDecodeError as e:
logger.error(f"消息解析失败: {e}")
logger.debug(f"原始消息: {raw_message[:500]}...")
except Exception as e:
logger.error(f"处理消息时出错: {e}")
logger.debug(f"原始消息: {raw_message[:500]}...")
async def message_process():
"""消息处理主循环"""
logger.info("消息处理器已启动")
try:
while True:
try:
# 使用超时等待,以便能够响应取消请求
message = await asyncio.wait_for(message_queue.get(), timeout=1.0)
post_type = message.get("post_type")
if post_type == "message":
await message_handler.handle_raw_message(message)
elif post_type == "meta_event":
await meta_event_handler.handle_meta_event(message)
elif post_type == "notice":
await notice_handler.handle_notice(message)
else:
logger.warning(f"未知的post_type: {post_type}")
message_queue.task_done()
await asyncio.sleep(0.05)
except asyncio.TimeoutError:
# 超时是正常的,继续循环
continue
except asyncio.CancelledError:
logger.info("消息处理器收到取消信号")
break
except Exception as e:
logger.error(f"处理消息时出错: {e}")
# 即使出错也标记任务完成,避免队列阻塞
try:
message_queue.task_done()
except ValueError:
pass
await asyncio.sleep(0.1)
except asyncio.CancelledError:
logger.info("消息处理器已停止")
raise
except Exception as e:
logger.error(f"消息处理器异常: {e}")
raise
finally:
logger.info("消息处理器正在清理...")
# 清空剩余的队列项目
try:
while not message_queue.empty():
try:
message_queue.get_nowait()
message_queue.task_done()
except asyncio.QueueEmpty:
break
except Exception as e:
logger.debug(f"清理消息队列时出错: {e}")
async def napcat_server(plugin_config: dict):
"""启动 Napcat WebSocket 连接(支持正向和反向连接)"""
# 使用插件系统配置API获取配置
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
logger.info(f"正在启动 adapter连接模式: {mode}")
try:
await websocket_manager.start_connection(message_recv, plugin_config)
except Exception as e:
logger.error(f"启动 WebSocket 连接失败: {e}")
raise
async def graceful_shutdown():
"""优雅关闭所有组件"""
try:
logger.info("正在关闭adapter...")
# 停止消息重组器的清理任务
try:
await reassembler.stop_cleanup_task()
except Exception as e:
logger.warning(f"停止消息重组器清理任务时出错: {e}")
# 停止功能管理器文件监控(已迁移到插件系统配置,无需操作)
# 关闭消息处理器(包括消息缓冲器)
try:
await message_handler.shutdown()
except Exception as e:
logger.warning(f"关闭消息处理器时出错: {e}")
# 关闭 WebSocket 连接
try:
await websocket_manager.stop_connection()
except Exception as e:
logger.warning(f"关闭WebSocket连接时出错: {e}")
# 关闭 MaiBot 连接
try:
await mmc_stop_com()
except Exception as e:
logger.warning(f"关闭MaiBot连接时出错: {e}")
# 取消所有剩余任务
current_task = asyncio.current_task()
tasks = [t for t in asyncio.all_tasks() if t is not current_task and not t.done()]
if tasks:
logger.info(f"正在取消 {len(tasks)} 个剩余任务...")
for task in tasks:
task.cancel()
# 等待任务取消完成,忽略 CancelledError
try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logger.warning("部分任务取消超时")
except Exception as e:
logger.debug(f"任务取消过程中的异常(可忽略): {e}")
logger.info("Adapter已成功关闭")
except Exception as e:
logger.error(f"Adapter关闭中出现错误: {e}")
finally:
# 确保消息队列被清空
try:
while not message_queue.empty():
try:
message_queue.get_nowait()
message_queue.task_done()
except asyncio.QueueEmpty:
break
except Exception:
pass
class LauchNapcatAdapterHandler(BaseEventHandler):
"""自动启动Adapter"""
handler_name: str = "launch_napcat_adapter_handler"
handler_description: str = "自动启动napcat adapter"
weight: int = 100
intercept_message: bool = False
init_subscribe = [EventType.ON_START]
async def execute(self, kwargs):
# 执行功能配置迁移(如果需要)
logger.info("检查功能配置迁移...")
auto_migrate_features()
# 启动消息重组器的清理任务
logger.info("启动消息重组器...")
await reassembler.start_cleanup_task()
# 功能管理器已迁移到插件系统配置
logger.info("功能配置已迁移到插件系统")
logger.info("开始启动Napcat Adapter")
message_send_instance.maibot_router = router
# 设置插件配置
message_send_instance.set_plugin_config(self.plugin_config)
# 设置chunker的插件配置
chunker.set_plugin_config(self.plugin_config)
# 设置response_pool的插件配置
from .src.response_pool import set_plugin_config as set_response_pool_config
set_response_pool_config(self.plugin_config)
# 设置send_handler的插件配置
send_handler.set_plugin_config(self.plugin_config)
# 设置message_handler的插件配置
message_handler.set_plugin_config(self.plugin_config)
# 设置notice_handler的插件配置
notice_handler.set_plugin_config(self.plugin_config)
# 设置meta_event_handler的插件配置
meta_event_handler.set_plugin_config(self.plugin_config)
# 创建单独的异步任务,防止阻塞主线程
asyncio.create_task(napcat_server(self.plugin_config))
asyncio.create_task(mmc_start_com(self.plugin_config))
asyncio.create_task(message_process())
asyncio.create_task(check_timeout_response())
class StopNapcatAdapterHandler(BaseEventHandler):
"""关闭Adapter"""
handler_name: str = "stop_napcat_adapter_handler"
handler_description: str = "关闭napcat adapter"
weight: int = 100
intercept_message: bool = False
init_subscribe = [EventType.ON_STOP]
async def execute(self, kwargs):
await graceful_shutdown()
return
@register_plugin
class NapcatAdapterPlugin(BasePlugin):
plugin_name = CONSTS.PLUGIN_NAME
enable_plugin: bool = True
dependencies: List[str] = [] # 插件依赖列表
python_dependencies: List[str] = [] # Python包依赖列表
config_file_name: str = "config.toml" # 配置文件名
# 配置节描述
config_section_descriptions = {"plugin": "插件基本信息"}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"),
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
"config_version": ConfigField(type=str, default="1.2.0", description="配置文件版本"),
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
},
"inner": {
"version": ConfigField(type=str, default="0.2.1", description="配置版本号,请勿修改"),
},
"nickname": {
"nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"),
},
"napcat_server": {
"mode": ConfigField(type=str, default="reverse", description="连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]),
"host": ConfigField(type=str, default="localhost", description="主机地址"),
"port": ConfigField(type=int, default=8095, description="端口号"),
"url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)"),
"access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"),
"heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"),
},
"maibot_server": {
"host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址即HOST字段"),
"port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口即PORT字段"),
"platform_name": ConfigField(type=str, default="napcat", description="平台名称,用于消息路由"),
},
"voice": {
"use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音请确保你配置了tts并有对应的adapter"),
},
"slicing": {
"max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小单位为字节默认64KB"),
"delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"),
},
"debug": {
"level": ConfigField(type=str, default="INFO", description="日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
}
}
# 配置节描述
config_section_descriptions = {
"plugin": "插件基本信息",
"inner": "内部配置信息(请勿修改)",
"nickname": "昵称配置(目前未使用)",
"napcat_server": "Napcat连接的ws服务设置",
"maibot_server": "连接麦麦的ws服务设置",
"voice": "发送语音设置",
"slicing": "WebSocket消息切片设置",
"debug": "调试设置"
}
def register_events(self):
# 注册事件
for e in event_types.NapcatEvent.ON_RECEIVED:
event_manager.register_event(e, allowed_triggers=[self.plugin_name])
for e in event_types.NapcatEvent.ACCOUNT:
event_manager.register_event(e, allowed_subscribers=[f"{e.value}_handler"])
for e in event_types.NapcatEvent.GROUP:
event_manager.register_event(e, allowed_subscribers=[f"{e.value}_handler"])
for e in event_types.NapcatEvent.MESSAGE:
event_manager.register_event(e, allowed_subscribers=[f"{e.value}_handler"])
def get_plugin_components(self):
self.register_events()
components = []
components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler))
components.append((StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler))
for handler in get_classes_in_module(event_handlers):
if issubclass(handler, BaseEventHandler):
components.append((handler.get_handler_info(), handler))
return components

View File

@@ -0,0 +1,47 @@
[project]
name = "MaiBotNapcatAdapter"
version = "0.4.8"
description = "A MaiBot adapter for Napcat"
dependencies = [
"ruff>=0.12.9",
]
[tool.ruff]
include = ["*.py"]
# 行长度设置
line-length = 120
[tool.ruff.lint]
fixable = ["ALL"]
unfixable = []
# 启用的规则
select = [
"E", # pycodestyle 错误
"F", # pyflakes
"B", # flake8-bugbear
]
ignore = ["E711","E501"]
[tool.ruff.format]
docstring-code-format = true
indent-style = "space"
# 使用双引号表示字符串
quote-style = "double"
# 尊重魔法尾随逗号
# 例如:
# items = [
# "apple",
# "banana",
# "cherry",
# ]
skip-magic-trailing-comma = false
# 自动检测合适的换行符
line-ending = "auto"

View File

@@ -0,0 +1,30 @@
from enum import Enum
import tomlkit
import os
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
class CommandType(Enum):
"""命令类型"""
GROUP_BAN = "set_group_ban" # 禁言用户
GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
GROUP_KICK = "set_group_kick" # 踢出群聊
SEND_POKE = "send_poke" # 戳一戳
DELETE_MSG = "delete_msg" # 撤回消息
AI_VOICE_SEND = "send_group_ai_record" # 发送群AI语音
SET_EMOJI_LIKE = "set_emoji_like" # 设置表情回应
SEND_AT_MESSAGE = "send_at_message" # 艾特用户并发送消息
SEND_LIKE = "send_like" # 点赞
def __str__(self) -> str:
return self.value
pyproject_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "pyproject.toml")
toml_data = tomlkit.parse(open(pyproject_path, "r", encoding="utf-8").read())
project_data = toml_data.get("project", {})
version = project_data.get("version", "unknown")
logger.info(f"版本\n\nMaiBot-Napcat-Adapter 版本: {version}\n喜欢的话点个star喵~\n")

View File

@@ -0,0 +1,2 @@
# 配置已迁移到插件系统,此文件不再需要
# 所有配置访问应通过插件系统的 config_api 进行

View File

@@ -0,0 +1,136 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union
T = TypeVar("T", bound="ConfigBase")
TOML_DICT_TYPE = {
int,
float,
str,
bool,
list,
dict,
}
@dataclass
class ConfigBase:
"""配置类的基类"""
@classmethod
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
"""从字典加载配置字段"""
if not isinstance(data, dict):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
init_args: Dict[str, Any] = {}
for f in fields(cls):
field_name = f.name
field_type = f.type
if field_name.startswith("_"):
# 跳过以 _ 开头的字段
continue
if field_name not in data:
if f.default is not MISSING or f.default_factory is not MISSING:
# 跳过未提供且有默认值/默认构造方法的字段
continue
else:
raise ValueError(f"Missing required field: '{field_name}'")
value = data[field_name]
try:
init_args[field_name] = cls._convert_field(value, field_type)
except TypeError as e:
raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e
except Exception as e:
raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e
return cls(**init_args)
@classmethod
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
"""
转换字段值为指定类型
1. 对于嵌套的 dataclass递归调用相应的 from_dict 方法
2. 对于泛型集合类型list, set, tuple递归转换每个元素
3. 对于基础类型int, str, float, bool直接转换
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
"""
# 如果是嵌套的 dataclass递归调用 from_dict 方法
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
return field_type.from_dict(value)
field_origin_type = get_origin(field_type)
field_args_type = get_args(field_type)
# 处理泛型集合类型list, set, tuple
if field_origin_type in {list, set, tuple}:
# 检查提供的value是否为list
if not isinstance(value, list):
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
if field_origin_type is list:
return [cls._convert_field(item, field_args_type[0]) for item in value]
if field_origin_type is set:
return {cls._convert_field(item, field_args_type[0]) for item in value}
if field_origin_type is tuple:
# 检查提供的value长度是否与类型参数一致
if len(value) != len(field_args_type):
raise TypeError(
f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}"
)
return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type))
if field_origin_type is dict:
# 检查提供的value是否为dict
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
# 检查字典的键值类型
if len(field_args_type) != 2:
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
key_type, value_type = field_args_type
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
# 处理Optional类型
if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union
if value is None:
return None
# 如果有数据,检查实际类型
if type(value) not in field_args_type:
raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}")
return cls._convert_field(value, field_args_type[0])
# 处理int, str, float, bool等基础类型
if field_origin_type is None:
if isinstance(value, field_type):
return field_type(value)
else:
raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}")
# 处理Literal类型
if field_origin_type is Literal:
# 获取Literal的允许值
allowed_values = get_args(field_type)
if value in allowed_values:
return value
else:
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
# 处理其他类型
if field_type is Any:
return value
# 其他类型直接转换
try:
return field_type(value)
except (ValueError, TypeError) as e:
raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e
def __str__(self):
"""返回配置类的字符串表示"""
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"

View File

@@ -0,0 +1,145 @@
"""
配置文件工具模块
提供统一的配置文件生成和管理功能
"""
import os
import shutil
from pathlib import Path
from datetime import datetime
from typing import Optional
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
def ensure_config_directories():
"""确保配置目录存在"""
os.makedirs("config", exist_ok=True)
os.makedirs("config/old", exist_ok=True)
def create_config_from_template(
config_path: str, template_path: str, config_name: str = "配置文件", should_exit: bool = True
) -> bool:
"""
从模板创建配置文件的统一函数
Args:
config_path: 配置文件路径
template_path: 模板文件路径
config_name: 配置文件名称(用于日志显示)
should_exit: 创建后是否退出程序
Returns:
bool: 是否成功创建配置文件
"""
try:
# 确保配置目录存在
ensure_config_directories()
config_path_obj = Path(config_path)
template_path_obj = Path(template_path)
# 检查配置文件是否存在
if config_path_obj.exists():
return False # 配置文件已存在,无需创建
logger.info(f"{config_name}不存在,从模板创建新配置")
# 检查模板文件是否存在
if not template_path_obj.exists():
logger.error(f"模板文件不存在: {template_path}")
if should_exit:
logger.critical("无法创建配置文件,程序退出")
quit(1)
return False
# 确保配置文件目录存在
config_path_obj.parent.mkdir(parents=True, exist_ok=True)
# 复制模板文件到配置目录
shutil.copy2(template_path_obj, config_path_obj)
logger.info(f"已创建新{config_name}: {config_path}")
if should_exit:
logger.info("程序将退出,请检查配置文件后重启")
quit(0)
return True
except Exception as e:
logger.error(f"创建{config_name}失败: {e}")
if should_exit:
logger.critical("无法创建配置文件,程序退出")
quit(1)
return False
def create_default_config_dict(default_values: dict, config_path: str, config_name: str = "配置文件") -> bool:
"""
创建默认配置文件(使用字典数据)
Args:
default_values: 默认配置值字典
config_path: 配置文件路径
config_name: 配置文件名称(用于日志显示)
Returns:
bool: 是否成功创建配置文件
"""
try:
import tomlkit
config_path_obj = Path(config_path)
# 确保配置文件目录存在
config_path_obj.parent.mkdir(parents=True, exist_ok=True)
# 写入默认配置
with open(config_path_obj, "w", encoding="utf-8") as f:
tomlkit.dump(default_values, f)
logger.info(f"已创建默认{config_name}: {config_path}")
return True
except Exception as e:
logger.error(f"创建默认{config_name}失败: {e}")
return False
def backup_config_file(config_path: str, backup_dir: str = "config/old") -> Optional[str]:
"""
备份配置文件
Args:
config_path: 要备份的配置文件路径
backup_dir: 备份目录
Returns:
Optional[str]: 备份文件路径失败时返回None
"""
try:
config_path_obj = Path(config_path)
if not config_path_obj.exists():
return None
# 确保备份目录存在
backup_dir_obj = Path(backup_dir)
backup_dir_obj.mkdir(parents=True, exist_ok=True)
# 创建备份文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_filename = f"{config_path_obj.stem}.toml.bak.{timestamp}"
backup_path = backup_dir_obj / backup_filename
# 备份文件
shutil.copy2(config_path_obj, backup_path)
logger.info(f"已备份配置文件到: {backup_path}")
return str(backup_path)
except Exception as e:
logger.error(f"备份配置文件失败: {e}")
return None

View File

@@ -0,0 +1,215 @@
"""
功能配置迁移脚本
用于将旧的配置文件中的聊天、权限、视频处理等设置迁移到新的独立功能配置文件
"""
import os
import shutil
from pathlib import Path
import tomlkit
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
def migrate_features_from_config(
old_config_path: str = "plugins/napcat_adapter_plugin/config/config.toml",
new_features_path: str = "plugins/napcat_adapter_plugin/config/features.toml",
template_path: str = "plugins/napcat_adapter_plugin/template/features_template.toml",
):
"""
从旧配置文件迁移功能设置到新的功能配置文件
Args:
old_config_path: 旧配置文件路径
new_features_path: 新功能配置文件路径
template_path: 功能配置模板路径
"""
try:
# 检查旧配置文件是否存在
if not os.path.exists(old_config_path):
logger.warning(f"旧配置文件不存在: {old_config_path}")
return False
# 读取旧配置文件
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
# 检查是否有chat配置段和video配置段
chat_config = old_config.get("chat", {})
video_config = old_config.get("video", {})
# 检查是否有权限相关配置
permission_keys = [
"group_list_type",
"group_list",
"private_list_type",
"private_list",
"ban_user_id",
"ban_qq_bot",
"enable_poke",
"ignore_non_self_poke",
"poke_debounce_seconds",
]
video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"]
has_permission_config = any(key in chat_config for key in permission_keys)
has_video_config = any(key in video_config for key in video_keys)
if not has_permission_config and not has_video_config:
logger.info("旧配置文件中没有找到功能相关配置,无需迁移")
return False
# 确保新功能配置目录存在
new_features_dir = Path(new_features_path).parent
new_features_dir.mkdir(parents=True, exist_ok=True)
# 如果新功能配置文件已存在,先备份
if os.path.exists(new_features_path):
backup_path = f"{new_features_path}.backup"
shutil.copy2(new_features_path, backup_path)
logger.info(f"已备份现有功能配置文件到: {backup_path}")
# 创建新的功能配置
new_features_config = {
"group_list_type": chat_config.get("group_list_type", "whitelist"),
"group_list": chat_config.get("group_list", []),
"private_list_type": chat_config.get("private_list_type", "whitelist"),
"private_list": chat_config.get("private_list", []),
"ban_user_id": chat_config.get("ban_user_id", []),
"ban_qq_bot": chat_config.get("ban_qq_bot", False),
"enable_poke": chat_config.get("enable_poke", True),
"ignore_non_self_poke": chat_config.get("ignore_non_self_poke", False),
"poke_debounce_seconds": chat_config.get("poke_debounce_seconds", 3),
"enable_video_analysis": video_config.get("enable_video_analysis", True),
"max_video_size_mb": video_config.get("max_video_size_mb", 100),
"download_timeout": video_config.get("download_timeout", 60),
"supported_formats": video_config.get(
"supported_formats", ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"]
),
}
# 写入新的功能配置文件
with open(new_features_path, "w", encoding="utf-8") as f:
tomlkit.dump(new_features_config, f)
logger.info(f"功能配置已成功迁移到: {new_features_path}")
# 显示迁移的配置内容
logger.info("迁移的配置内容:")
for key, value in new_features_config.items():
logger.info(f" {key}: {value}")
return True
except Exception as e:
logger.error(f"功能配置迁移失败: {e}")
return False
def remove_features_from_old_config(config_path: str = "plugins/napcat_adapter_plugin/config/config.toml"):
"""
从旧配置文件中移除功能相关配置,并将旧配置移动到 config/old/ 目录
Args:
config_path: 配置文件路径
"""
try:
if not os.path.exists(config_path):
logger.warning(f"配置文件不存在: {config_path}")
return False
# 确保 config/old 目录存在
old_config_dir = "plugins/napcat_adapter_plugin/config/old"
os.makedirs(old_config_dir, exist_ok=True)
# 备份原配置文件到 config/old 目录
old_config_path = os.path.join(old_config_dir, "config_with_features.toml")
shutil.copy2(config_path, old_config_path)
logger.info(f"已备份包含功能配置的原文件到: {old_config_path}")
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config = tomlkit.load(f)
# 移除chat段中的功能相关配置
removed_keys = []
if "chat" in config:
chat_config = config["chat"]
permission_keys = [
"group_list_type",
"group_list",
"private_list_type",
"private_list",
"ban_user_id",
"ban_qq_bot",
"enable_poke",
"ignore_non_self_poke",
"poke_debounce_seconds",
]
for key in permission_keys:
if key in chat_config:
del chat_config[key]
removed_keys.append(key)
if removed_keys:
logger.info(f"已从chat配置段中移除功能相关配置: {removed_keys}")
# 移除video段中的配置
if "video" in config:
video_config = config["video"]
video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"]
video_removed_keys = []
for key in video_keys:
if key in video_config:
del video_config[key]
video_removed_keys.append(key)
if video_removed_keys:
logger.info(f"已从video配置段中移除配置: {video_removed_keys}")
removed_keys.extend(video_removed_keys)
# 如果video段为空则删除整个段
if not video_config:
del config["video"]
logger.info("已删除空的video配置段")
if removed_keys:
logger.info(f"总共移除的配置项: {removed_keys}")
# 写回配置文件
with open(config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(config))
logger.info(f"已更新配置文件: {config_path}")
return True
except Exception as e:
logger.error(f"移除功能配置失败: {e}")
return False
def auto_migrate_features():
"""
自动执行功能配置迁移
"""
logger.info("开始自动功能配置迁移...")
# 执行迁移
if migrate_features_from_config():
logger.info("功能配置迁移成功")
# 询问是否要从旧配置文件中移除功能配置
logger.info("功能配置已迁移到独立文件,建议从主配置文件中移除相关配置")
# 在实际使用中,这里可以添加用户确认逻辑
# 为了自动化,这里直接执行移除
remove_features_from_old_config()
else:
logger.info("功能配置迁移跳过或失败")
if __name__ == "__main__":
auto_migrate_features()

View File

@@ -0,0 +1,74 @@
from dataclasses import dataclass, field
from typing import Literal
from .config_base import ConfigBase
"""
须知:
1. 本文件中记录了所有的配置项
2. 所有新增的class都需要继承自ConfigBase
3. 所有新增的class都应在config.py中的Config类中添加字段
4. 对于新增的字段若为可选项则应在其后添加field()并设置default_factory或default
"""
ADAPTER_PLATFORM = "qq"
@dataclass
class NicknameConfig(ConfigBase):
nickname: str
"""机器人昵称"""
@dataclass
class NapcatServerConfig(ConfigBase):
mode: Literal["reverse", "forward"] = "reverse"
"""连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)"""
host: str = "localhost"
"""主机地址"""
port: int = 8095
"""端口号"""
url: str = ""
"""正向连接时的完整WebSocket URL如 ws://localhost:8080/ws"""
access_token: str = ""
"""WebSocket 连接的访问令牌,用于身份验证"""
heartbeat_interval: int = 30
"""心跳间隔时间,单位为秒"""
@dataclass
class MaiBotServerConfig(ConfigBase):
platform_name: str = field(default=ADAPTER_PLATFORM, init=False)
"""平台名称“qq”"""
host: str = "localhost"
"""MaiMCore的主机地址"""
port: int = 8000
"""MaiMCore的端口号"""
@dataclass
class VoiceConfig(ConfigBase):
use_tts: bool = False
"""是否启用TTS功能"""
@dataclass
class SlicingConfig(ConfigBase):
max_frame_size: int = 64
"""WebSocket帧的最大大小单位为字节默认64KB"""
delay_ms: int = 10
"""切片发送间隔时间,单位为毫秒"""
@dataclass
class DebugConfig(ConfigBase):
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
"""日志级别默认为INFO"""

View File

@@ -0,0 +1,162 @@
import os
from typing import Optional, List
from dataclasses import dataclass
from sqlmodel import Field, Session, SQLModel, create_engine, select
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
"""
表记录的方式:
| group_id | user_id | lift_time |
|----------|---------|-----------|
其中使用 user_id == 0 表示群全体禁言
"""
@dataclass
class BanUser:
"""
程序处理使用的实例
"""
user_id: int
group_id: int
lift_time: Optional[int] = Field(default=-1)
class DB_BanUser(SQLModel, table=True):
"""
表示数据库中的用户禁言记录。
使用双重主键
"""
user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID
group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID
lift_time: Optional[int] # 禁言解除的时间(时间戳)
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
"""
检查两个 BanUser 对象是否相同。
"""
return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id
class DatabaseManager:
"""
数据库管理类,负责与数据库交互。
"""
def __init__(self):
os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在
DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL
self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎
self._ensure_database() # 确保数据库和表已创建
def _ensure_database(self) -> None:
"""
确保数据库和表已创建。
"""
logger.info("确保数据库文件和表已创建...")
SQLModel.metadata.create_all(self.engine)
logger.info("数据库和表已创建或已存在")
def update_ban_record(self, ban_list: List[BanUser]) -> None:
# sourcery skip: class-extract-method
"""
更新禁言列表到数据库。
支持在不存在时创建新记录,对于多余的项目自动删除。
"""
with Session(self.engine) as session:
all_records = session.exec(select(DB_BanUser)).all()
for ban_user in ban_list:
statement = select(DB_BanUser).where(
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
)
if existing_record := session.exec(statement).first():
if existing_record.lift_time == ban_user.lift_time:
logger.debug(f"禁言记录未变更: {existing_record}")
continue
# 更新现有记录的 lift_time
existing_record.lift_time = ban_user.lift_time
session.add(existing_record)
logger.debug(f"更新禁言记录: {existing_record}")
else:
# 创建新记录
db_record = DB_BanUser(
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
)
session.add(db_record)
logger.debug(f"创建新禁言记录: {ban_user}")
# 删除不在 ban_list 中的记录
for db_record in all_records:
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
if not any(is_identical(record, ban_user) for ban_user in ban_list):
statement = select(DB_BanUser).where(
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
)
if ban_record := session.exec(statement).first():
session.delete(ban_record)
logger.debug(f"删除禁言记录: {ban_record}")
else:
logger.info(f"未找到禁言记录: {ban_record}")
logger.info("禁言记录已更新")
def get_ban_records(self) -> List[BanUser]:
"""
读取所有禁言记录。
"""
with Session(self.engine) as session:
statement = select(DB_BanUser)
records = session.exec(statement).all()
return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
def create_ban_record(self, ban_record: BanUser) -> None:
"""
为特定群组中的用户创建禁言记录。
一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。
其同时还是简化版的更新方式。
"""
with Session(self.engine) as session:
# 检查记录是否已存在
statement = select(DB_BanUser).where(
DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
)
existing_record = session.exec(statement).first()
if existing_record:
# 如果记录已存在,更新 lift_time
existing_record.lift_time = ban_record.lift_time
session.add(existing_record)
logger.debug(f"更新禁言记录: {ban_record}")
else:
# 如果记录不存在,创建新记录
db_record = DB_BanUser(
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
)
session.add(db_record)
logger.debug(f"创建新禁言记录: {ban_record}")
def delete_ban_record(self, ban_record: BanUser):
"""
删除特定用户在特定群组中的禁言记录。
一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。
"""
user_id = ban_record.user_id
group_id = ban_record.group_id
with Session(self.engine) as session:
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
if ban_record := session.exec(statement).first():
session.delete(ban_record)
logger.debug(f"删除禁言记录: {ban_record}")
else:
logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}")
db_manager = DatabaseManager()

View File

@@ -0,0 +1,314 @@
import asyncio
import time
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from src.plugin_system.apis import config_api
from .recv_handler import RealMessageType
@dataclass
class TextMessage:
"""文本消息"""
text: str
timestamp: float = field(default_factory=time.time)
@dataclass
class BufferedSession:
"""缓冲会话数据"""
session_id: str
messages: List[TextMessage] = field(default_factory=list)
timer_task: Optional[asyncio.Task] = None
delay_task: Optional[asyncio.Task] = None
original_event: Any = None
created_at: float = field(default_factory=time.time)
class SimpleMessageBuffer:
def __init__(self, merge_callback=None):
"""
初始化消息缓冲器
Args:
merge_callback: 消息合并后的回调函数,接收(session_id, merged_text, original_event)参数
"""
self.buffer_pool: Dict[str, BufferedSession] = {}
self.lock = asyncio.Lock()
self.merge_callback = merge_callback
self._shutdown = False
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
def get_session_id(self, event_data: Dict[str, Any]) -> str:
"""根据事件数据生成会话ID"""
message_type = event_data.get("message_type", "unknown")
user_id = event_data.get("user_id", "unknown")
if message_type == "private":
return f"private_{user_id}"
elif message_type == "group":
group_id = event_data.get("group_id", "unknown")
return f"group_{group_id}_{user_id}"
else:
return f"{message_type}_{user_id}"
def extract_text_from_message(self, message: List[Dict[str, Any]]) -> Optional[str]:
"""从OneBot消息中提取纯文本如果包含非文本内容则返回None"""
text_parts = []
has_non_text = False
logger.debug(f"正在提取消息文本,消息段数量: {len(message)}")
for msg_seg in message:
msg_type = msg_seg.get("type", "")
logger.debug(f"处理消息段类型: {msg_type}")
if msg_type == RealMessageType.text:
text = msg_seg.get("data", {}).get("text", "").strip()
if text:
text_parts.append(text)
logger.debug(f"提取到文本: {text[:50]}...")
else:
# 发现非文本消息段,标记为包含非文本内容
has_non_text = True
logger.debug(f"发现非文本消息段: {msg_type},跳过缓冲")
# 如果包含非文本内容,则不进行缓冲
if has_non_text:
logger.debug("消息包含非文本内容,不进行缓冲")
return None
if text_parts:
combined_text = " ".join(text_parts).strip()
logger.debug(f"成功提取纯文本: {combined_text[:50]}...")
return combined_text
logger.debug("没有找到有效的文本内容")
return None
def should_skip_message(self, text: str) -> bool:
"""判断消息是否应该跳过缓冲"""
if not text or not text.strip():
return True
# 检查屏蔽前缀
block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", []))
text = text.strip()
if text.startswith(block_prefixes):
logger.debug(f"消息以屏蔽前缀开头,跳过缓冲: {text[:20]}...")
return True
return False
async def add_text_message(
self, event_data: Dict[str, Any], message: List[Dict[str, Any]], original_event: Any = None
) -> bool:
"""
添加文本消息到缓冲区
Args:
event_data: 事件数据
message: OneBot消息数组
original_event: 原始事件对象
Returns:
是否成功添加到缓冲区
"""
if self._shutdown:
return False
# 检查是否启用消息缓冲
if not config_api.get_plugin_config(self.plugin_config, "features.enable_message_buffer", False):
return False
# 检查是否启用对应类型的缓冲
message_type = event_data.get("message_type", "")
if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False):
return False
elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False):
return False
# 提取文本
text = self.extract_text_from_message(message)
if not text:
return False
# 检查是否应该跳过
if self.should_skip_message(text):
return False
session_id = self.get_session_id(event_data)
async with self.lock:
# 获取或创建会话
if session_id not in self.buffer_pool:
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)
session = self.buffer_pool[session_id]
# 检查是否超过最大组件数量
if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5):
logger.info(f"会话 {session_id} 消息数量达到上限,强制合并")
asyncio.create_task(self._force_merge_session(session_id))
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)
session = self.buffer_pool[session_id]
# 添加文本消息
session.messages.append(TextMessage(text=text))
session.original_event = original_event # 更新事件
# 取消之前的定时器
await self._cancel_session_timers(session)
# 设置新的延迟任务
session.delay_task = asyncio.create_task(self._wait_and_start_merge(session_id))
logger.debug(f"文本消息已添加到缓冲器 {session_id}: {text[:50]}...")
return True
async def _cancel_session_timers(self, session: BufferedSession):
"""取消会话的所有定时器"""
for task_name in ["timer_task", "delay_task"]:
task = getattr(session, task_name)
if task and not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
setattr(session, task_name, None)
async def _wait_and_start_merge(self, session_id: str):
"""等待初始延迟后开始合并定时器"""
initial_delay = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_initial_delay", 0.5)
await asyncio.sleep(initial_delay)
async with self.lock:
session = self.buffer_pool.get(session_id)
if session and session.messages:
# 取消旧的定时器
if session.timer_task and not session.timer_task.done():
session.timer_task.cancel()
try:
await session.timer_task
except asyncio.CancelledError:
pass
# 设置合并定时器
session.timer_task = asyncio.create_task(self._wait_and_merge(session_id))
async def _wait_and_merge(self, session_id: str):
"""等待合并间隔后执行合并"""
interval = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_interval", 2.0)
await asyncio.sleep(interval)
await self._merge_session(session_id)
async def _force_merge_session(self, session_id: str):
"""强制合并会话(不等待定时器)"""
await self._merge_session(session_id, force=True)
async def _merge_session(self, session_id: str, force: bool = False):
"""合并会话中的消息"""
async with self.lock:
session = self.buffer_pool.get(session_id)
if not session or not session.messages:
self.buffer_pool.pop(session_id, None)
return
try:
# 合并文本消息
text_parts = []
for msg in session.messages:
if msg.text.strip():
text_parts.append(msg.text.strip())
if not text_parts:
self.buffer_pool.pop(session_id, None)
return
merged_text = "".join(text_parts) # 使用中文逗号连接
message_count = len(session.messages)
logger.info(f"合并会话 {session_id}{message_count} 条文本消息: {merged_text[:100]}...")
# 调用回调函数
if self.merge_callback:
try:
if asyncio.iscoroutinefunction(self.merge_callback):
await self.merge_callback(session_id, merged_text, session.original_event)
else:
self.merge_callback(session_id, merged_text, session.original_event)
except Exception as e:
logger.error(f"消息合并回调执行失败: {e}")
except Exception as e:
logger.error(f"合并会话 {session_id} 时出错: {e}")
finally:
# 清理会话
await self._cancel_session_timers(session)
self.buffer_pool.pop(session_id, None)
async def flush_session(self, session_id: str):
"""强制刷新指定会话的缓冲区"""
await self._force_merge_session(session_id)
async def flush_all(self):
"""强制刷新所有会话的缓冲区"""
session_ids = list(self.buffer_pool.keys())
for session_id in session_ids:
await self._force_merge_session(session_id)
async def get_buffer_stats(self) -> Dict[str, Any]:
"""获取缓冲区统计信息"""
async with self.lock:
stats = {"total_sessions": len(self.buffer_pool), "sessions": {}}
for session_id, session in self.buffer_pool.items():
stats["sessions"][session_id] = {
"message_count": len(session.messages),
"created_at": session.created_at,
"age": time.time() - session.created_at,
}
return stats
async def clear_expired_sessions(self, max_age: float = 300.0):
"""清理过期的会话"""
current_time = time.time()
expired_sessions = []
async with self.lock:
for session_id, session in self.buffer_pool.items():
if current_time - session.created_at > max_age:
expired_sessions.append(session_id)
for session_id in expired_sessions:
logger.info(f"清理过期会话: {session_id}")
await self._force_merge_session(session_id)
async def shutdown(self):
"""关闭消息缓冲器"""
self._shutdown = True
logger.info("正在关闭简化消息缓冲器...")
# 刷新所有缓冲区
await self.flush_all()
# 确保所有任务都被取消
async with self.lock:
for session in list(self.buffer_pool.values()):
await self._cancel_session_timers(session)
self.buffer_pool.clear()
logger.info("简化消息缓冲器已关闭")

View File

@@ -0,0 +1,280 @@
"""
消息切片处理模块
用于在 Ada 发送给 MMC 时进行消息切片,利用 WebSocket 协议的自动重组特性
仅在 Ada -> MMC 方向进行切片其他方向MMC -> AdaAda <-> Napcat不切片
"""
import json
import uuid
import asyncio
import time
from typing import List, Dict, Any, Optional, Union
from src.plugin_system.apis import config_api
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
class MessageChunker:
"""消息切片器,用于处理大消息的分片发送"""
def __init__(self):
self.max_chunk_size = 64 * 1024 # 默认值,将在设置配置时更新
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
if plugin_config:
max_frame_size = config_api.get_plugin_config(plugin_config, "slicing.max_frame_size", 64)
self.max_chunk_size = max_frame_size * 1024
def should_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
"""判断消息是否需要切片"""
try:
if isinstance(message, dict):
message_str = json.dumps(message, ensure_ascii=False)
else:
message_str = message
return len(message_str.encode("utf-8")) > self.max_chunk_size
except Exception as e:
logger.error(f"检查消息大小时出错: {e}")
return False
def chunk_message(
self, message: Union[str, Dict[str, Any]], chunk_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
将消息切片
Args:
message: 要切片的消息(字符串或字典)
chunk_id: 切片组ID如果不提供则自动生成
Returns:
切片后的消息字典列表
"""
try:
# 统一转换为字符串
if isinstance(message, dict):
message_str = json.dumps(message, ensure_ascii=False)
else:
message_str = message
if not self.should_chunk_message(message_str):
# 不需要切片的情况,如果输入是字典则返回字典,如果是字符串则包装成非切片标记的字典
if isinstance(message, dict):
return [message]
else:
return [{"_original_message": message_str}]
if chunk_id is None:
chunk_id = str(uuid.uuid4())
message_bytes = message_str.encode("utf-8")
total_size = len(message_bytes)
# 计算需要多少个切片
num_chunks = (total_size + self.max_chunk_size - 1) // self.max_chunk_size
chunks = []
for i in range(num_chunks):
start_pos = i * self.max_chunk_size
end_pos = min(start_pos + self.max_chunk_size, total_size)
chunk_data = message_bytes[start_pos:end_pos]
# 构建切片消息
chunk_message = {
"__mmc_chunk_info__": {
"chunk_id": chunk_id,
"chunk_index": i,
"total_chunks": num_chunks,
"chunk_size": len(chunk_data),
"total_size": total_size,
"timestamp": time.time(),
},
"__mmc_chunk_data__": chunk_data.decode("utf-8", errors="ignore"),
"__mmc_is_chunked__": True,
}
chunks.append(chunk_message)
logger.debug(f"消息切片完成: {total_size} bytes -> {num_chunks} chunks (ID: {chunk_id})")
return chunks
except Exception as e:
logger.error(f"消息切片时出错: {e}")
# 出错时返回原消息
if isinstance(message, dict):
return [message]
else:
return [{"_original_message": message}]
def is_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
"""判断是否是切片消息"""
try:
if isinstance(message, str):
data = json.loads(message)
else:
data = message
return (
isinstance(data, dict)
and "__mmc_chunk_info__" in data
and "__mmc_chunk_data__" in data
and "__mmc_is_chunked__" in data
)
except (json.JSONDecodeError, TypeError):
return False
class MessageReassembler:
"""消息重组器,用于重组接收到的切片消息"""
def __init__(self, timeout: int = 30):
self.timeout = timeout
self.chunk_buffers: Dict[str, Dict[str, Any]] = {}
self._cleanup_task = None
async def start_cleanup_task(self):
"""启动清理任务"""
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._cleanup_expired_chunks())
async def stop_cleanup_task(self):
"""停止清理任务"""
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
self._cleanup_task = None
async def _cleanup_expired_chunks(self):
"""清理过期的切片缓冲区"""
while True:
try:
await asyncio.sleep(10) # 每10秒检查一次
current_time = time.time()
expired_chunks = []
for chunk_id, buffer_info in self.chunk_buffers.items():
if current_time - buffer_info["timestamp"] > self.timeout:
expired_chunks.append(chunk_id)
for chunk_id in expired_chunks:
logger.warning(f"清理过期的切片缓冲区: {chunk_id}")
del self.chunk_buffers[chunk_id]
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"清理过期切片时出错: {e}")
async def add_chunk(self, message: Union[str, Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""
添加切片,如果切片完整则返回重组后的消息
Args:
message: 切片消息(字符串或字典)
Returns:
如果切片完整则返回重组后的原始消息字典否则返回None
"""
try:
# 统一转换为字典
if isinstance(message, str):
chunk_data = json.loads(message)
else:
chunk_data = message
# 检查是否是切片消息
if not chunker.is_chunk_message(chunk_data):
# 不是切片消息,直接返回
if "_original_message" in chunk_data:
# 这是一个被包装的非切片消息,解包返回
try:
return json.loads(chunk_data["_original_message"])
except json.JSONDecodeError:
return {"text_message": chunk_data["_original_message"]}
else:
return chunk_data
chunk_info = chunk_data["__mmc_chunk_info__"]
chunk_content = chunk_data["__mmc_chunk_data__"]
chunk_id = chunk_info["chunk_id"]
chunk_index = chunk_info["chunk_index"]
total_chunks = chunk_info["total_chunks"]
chunk_timestamp = chunk_info.get("timestamp", time.time())
# 初始化缓冲区
if chunk_id not in self.chunk_buffers:
self.chunk_buffers[chunk_id] = {
"chunks": {},
"total_chunks": total_chunks,
"received_chunks": 0,
"timestamp": chunk_timestamp,
}
buffer = self.chunk_buffers[chunk_id]
# 检查切片是否已经接收过
if chunk_index in buffer["chunks"]:
logger.warning(f"重复接收切片: {chunk_id}#{chunk_index}")
return None
# 添加切片
buffer["chunks"][chunk_index] = chunk_content
buffer["received_chunks"] += 1
buffer["timestamp"] = time.time() # 更新时间戳
logger.debug(f"接收切片: {chunk_id}#{chunk_index} ({buffer['received_chunks']}/{total_chunks})")
# 检查是否接收完整
if buffer["received_chunks"] == total_chunks:
# 重组消息
reassembled_message = ""
for i in range(total_chunks):
if i not in buffer["chunks"]:
logger.error(f"切片 {chunk_id}#{i} 缺失,无法重组")
return None
reassembled_message += buffer["chunks"][i]
# 清理缓冲区
del self.chunk_buffers[chunk_id]
logger.debug(f"消息重组完成: {chunk_id} ({len(reassembled_message)} chars)")
# 尝试反序列化重组后的消息
try:
return json.loads(reassembled_message)
except json.JSONDecodeError:
# 如果不能反序列化为JSON则作为文本消息返回
return {"text_message": reassembled_message}
return None
except (json.JSONDecodeError, KeyError, TypeError) as e:
logger.error(f"处理切片消息时出错: {e}")
return None
def get_pending_chunks_info(self) -> Dict[str, Any]:
"""获取待处理切片信息"""
info = {}
for chunk_id, buffer in self.chunk_buffers.items():
info[chunk_id] = {
"received": buffer["received_chunks"],
"total": buffer["total_chunks"],
"progress": f"{buffer['received_chunks']}/{buffer['total_chunks']}",
"age_seconds": time.time() - buffer["timestamp"],
}
return info
# 全局实例
chunker = MessageChunker()
reassembler = MessageReassembler()

View File

@@ -0,0 +1,44 @@
from maim_message import Router, RouteConfig, TargetConfig
from src.common.logger import get_logger
from .send_handler import send_handler
from src.plugin_system.apis import config_api
logger = get_logger("napcat_adapter")
router = None
def create_router(plugin_config: dict):
"""创建路由器实例"""
global router
platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "napcat")
host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost")
port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000)
route_config = RouteConfig(
route_config={
platform_name: TargetConfig(
url=f"ws://{host}:{port}/ws",
token=None,
)
}
)
router = Router(route_config)
return router
async def mmc_start_com(plugin_config: dict = None):
"""启动MaiBot连接"""
logger.info("正在连接MaiBot")
if plugin_config:
create_router(plugin_config)
if router:
router.register_class_handler(send_handler.handle_message)
await router.run()
async def mmc_stop_com():
"""停止MaiBot连接"""
if router:
await router.stop()

View File

@@ -0,0 +1,89 @@
from enum import Enum
class MetaEventType:
lifecycle = "lifecycle" # 生命周期
class Lifecycle:
connect = "connect" # 生命周期 - WebSocket 连接成功
heartbeat = "heartbeat" # 心跳
class MessageType: # 接受消息大类
private = "private" # 私聊消息
class Private:
friend = "friend" # 私聊消息 - 好友
group = "group" # 私聊消息 - 群临时
group_self = "group_self" # 私聊消息 - 群中自身发送
other = "other" # 私聊消息 - 其他
group = "group" # 群聊消息
class Group:
normal = "normal" # 群聊消息 - 普通
anonymous = "anonymous" # 群聊消息 - 匿名消息
notice = "notice" # 群聊消息 - 系统提示
class NoticeType: # 通知事件
friend_recall = "friend_recall" # 私聊消息撤回
group_recall = "group_recall" # 群聊消息撤回
notify = "notify"
group_ban = "group_ban" # 群禁言
class Notify:
poke = "poke" # 戳一戳
input_status = "input_status" # 正在输入
class GroupBan:
ban = "ban" # 禁言
lift_ban = "lift_ban" # 解除禁言
class RealMessageType: # 实际消息分类
text = "text" # 纯文本
face = "face" # qq表情
image = "image" # 图片
record = "record" # 语音
video = "video" # 视频
at = "at" # @某人
rps = "rps" # 猜拳魔法表情
dice = "dice" # 骰子
shake = "shake" # 私聊窗口抖动(只收)
poke = "poke" # 群聊戳一戳
share = "share" # 链接分享json形式
reply = "reply" # 回复消息
forward = "forward" # 转发消息
node = "node" # 转发消息节点
json = "json" # json消息
class MessageSentType:
private = "private"
class Private:
friend = "friend"
group = "group"
group = "group"
class Group:
normal = "normal"
class CommandType(Enum):
"""命令类型"""
GROUP_BAN = "set_group_ban" # 禁言用户
GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
GROUP_KICK = "set_group_kick" # 踢出群聊
SEND_POKE = "send_poke" # 戳一戳
DELETE_MSG = "delete_msg" # 撤回消息
def __str__(self) -> str:
return self.value
ACCEPT_FORMAT = ["text", "image", "emoji", "reply", "voice", "command", "voiceurl", "music", "videourl", "file"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,81 @@
import asyncio
from src.common.logger import get_logger
from ..message_chunker import chunker
from src.plugin_system.apis import config_api
logger = get_logger("napcat_adapter")
from maim_message import MessageBase, Router
class MessageSending:
"""
负责把消息发送到麦麦
"""
maibot_router: Router = None
plugin_config = None
def __init__(self):
pass
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
async def message_send(self, message_base: MessageBase) -> bool:
"""
发送消息Ada -> MMC 方向,需要实现切片)
Parameters:
message_base: MessageBase: 消息基类,包含发送目标和消息内容等信息
"""
try:
# 检查是否需要切片发送
message_dict = message_base.to_dict()
if chunker.should_chunk_message(message_dict):
logger.info("消息过大,进行切片发送到 MaiBot")
# 切片消息
chunks = chunker.chunk_message(message_dict)
# 逐个发送切片
for i, chunk in enumerate(chunks):
logger.debug(f"发送切片 {i + 1}/{len(chunks)} 到 MaiBot")
# 获取对应的客户端并发送切片
platform = message_base.message_info.platform
if platform not in self.maibot_router.clients:
logger.error(f"平台 {platform} 未连接")
return False
client = self.maibot_router.clients[platform]
send_status = await client.send_message(chunk)
if not send_status:
logger.error(f"发送切片 {i + 1}/{len(chunks)} 失败")
return False
# 使用配置中的延迟时间
if i < len(chunks) - 1 and self.plugin_config:
delay_ms = config_api.get_plugin_config(self.plugin_config, "slicing.delay_ms", 10)
delay_seconds = delay_ms / 1000.0
logger.debug(f"切片发送延迟: {delay_ms}毫秒")
await asyncio.sleep(delay_seconds)
logger.debug("所有切片发送完成")
return True
else:
# 直接发送小消息
send_status = await self.maibot_router.send_message(message_base)
if not send_status:
raise RuntimeError("可能是路由未正确配置或连接异常")
return send_status
except Exception as e:
logger.error(f"发送消息失败: {str(e)}")
logger.error("请检查与MaiBot之间的连接")
return False
message_send_instance = MessageSending()

View File

@@ -0,0 +1,58 @@
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from src.plugin_system.apis import config_api
import time
import asyncio
from . import MetaEventType
class MetaEventHandler:
"""
处理Meta事件
"""
def __init__(self):
self.interval = 5.0 # 默认值稍后通过set_plugin_config设置
self._interval_checking = False
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
# 更新interval值
self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
async def handle_meta_event(self, message: dict) -> None:
event_type = message.get("meta_event_type")
if event_type == MetaEventType.lifecycle:
sub_type = message.get("sub_type")
if sub_type == MetaEventType.Lifecycle.connect:
self_id = message.get("self_id")
self.last_heart_beat = time.time()
logger.info(f"Bot {self_id} 连接成功")
asyncio.create_task(self.check_heartbeat(self_id))
elif event_type == MetaEventType.heartbeat:
if message["status"].get("online") and message["status"].get("good"):
if not self._interval_checking:
asyncio.create_task(self.check_heartbeat())
self.last_heart_beat = time.time()
self.interval = message.get("interval") / 1000
else:
self_id = message.get("self_id")
logger.warning(f"Bot {self_id} Napcat 端异常!")
async def check_heartbeat(self, id: int) -> None:
self._interval_checking = True
while True:
now_time = time.time()
if now_time - self.last_heart_beat > self.interval * 2:
logger.error(f"Bot {id} 可能发生了连接断开被下线或者Napcat卡死")
break
else:
logger.debug("心跳正常")
await asyncio.sleep(self.interval)
meta_event_handler = MetaEventHandler()

View File

@@ -0,0 +1,560 @@
import time
import json
import asyncio
import websockets as Server
from typing import Tuple, Optional
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from src.plugin_system.apis import config_api
from ..database import BanUser, db_manager, is_identical
from . import NoticeType, ACCEPT_FORMAT
from .message_sending import message_send_instance
from .message_handler import message_handler
from maim_message import FormatInfo, UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase
from ..websocket_manager import websocket_manager
from ..utils import (
get_group_info,
get_member_info,
get_self_info,
get_stranger_info,
read_ban_list,
)
from ...CONSTS import PLUGIN_NAME
notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=100)
unsuccessful_notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=3)
class NoticeHandler:
banned_list: list[BanUser] = [] # 当前仍在禁言中的用户列表
lifted_list: list[BanUser] = [] # 已经自然解除禁言
def __init__(self):
self.server_connection: Server.ServerConnection | None = None
self.last_poke_time: float = 0.0 # 记录最后一次针对机器人的戳一戳时间
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
"""设置Napcat连接"""
self.server_connection = server_connection
while self.server_connection.state != Server.State.OPEN:
await asyncio.sleep(0.5)
self.banned_list, self.lifted_list = await read_ban_list(self.server_connection)
asyncio.create_task(self.auto_lift_detect())
asyncio.create_task(self.send_notice())
asyncio.create_task(self.handle_natural_lift())
def get_server_connection(self) -> Server.ServerConnection:
"""获取当前的服务器连接"""
# 优先使用直接设置的连接,否则从 websocket_manager 获取
if self.server_connection:
return self.server_connection
return websocket_manager.get_connection()
def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None:
"""
将用户禁言记录添加到self.banned_list中
如果是全体禁言则user_id为0
"""
if user_id is None:
user_id = 0 # 使用0表示全体禁言
lift_time = -1
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time)
for record in self.banned_list:
if is_identical(record, ban_record):
self.banned_list.remove(record)
self.banned_list.append(ban_record)
db_manager.create_ban_record(ban_record) # 作为更新
return
self.banned_list.append(ban_record)
db_manager.create_ban_record(ban_record) # 添加到数据库
def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None:
"""
从self.lifted_group_list中移除已经解除全体禁言的群
"""
if user_id is None:
user_id = 0 # 使用0表示全体禁言
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1)
self.lifted_list.append(ban_record)
db_manager.delete_ban_record(ban_record) # 删除数据库中的记录
async def handle_notice(self, raw_message: dict) -> None:
notice_type = raw_message.get("notice_type")
# message_time: int = raw_message.get("time")
message_time: float = time.time() # 应可乐要求现在是float了
group_id = raw_message.get("group_id")
user_id = raw_message.get("user_id")
target_id = raw_message.get("target_id")
handled_message: Seg = None
user_info: UserInfo = None
system_notice: bool = False
match notice_type:
case NoticeType.friend_recall:
logger.info("好友撤回一条消息")
logger.info(f"撤回消息ID{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
logger.warning("暂时不支持撤回消息处理")
case NoticeType.group_recall:
logger.info("群内用户撤回一条消息")
logger.info(f"撤回消息ID{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
logger.warning("暂时不支持撤回消息处理")
case NoticeType.notify:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.Notify.poke:
if config_api.get_plugin_config(self.plugin_config, "features.poke_enabled", True) and await message_handler.check_allow_to_chat(
user_id, group_id, False, False
):
logger.info("处理戳一戳消息")
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
else:
logger.warning("戳一戳消息被禁用,取消戳一戳处理")
case NoticeType.Notify.input_status:
from src.plugin_system.core.event_manager import event_manager
from ...event_types import NapcatEvent
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, plugin_name=PLUGIN_NAME)
case _:
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
case NoticeType.group_ban:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.GroupBan.ban:
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
return None
logger.info("处理群禁言")
handled_message, user_info = await self.handle_ban_notify(raw_message, group_id)
system_notice = True
case NoticeType.GroupBan.lift_ban:
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
return None
logger.info("处理解除群禁言")
handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id)
system_notice = True
case _:
logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}")
case _:
logger.warning(f"不支持的notice类型: {notice_type}")
return None
if not handled_message or not user_info:
logger.warning("notice处理失败或不支持")
return None
group_info: GroupInfo = None
if group_id:
fetched_group_info = await get_group_info(self.get_server_connection(), group_id)
group_name: str = None
if fetched_group_info:
group_name = fetched_group_info.get("group_name")
else:
logger.warning("无法获取notice消息所在群的名称")
group_info = GroupInfo(
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
group_id=group_id,
group_name=group_name,
)
message_info: BaseMessageInfo = BaseMessageInfo(
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
message_id="notice",
time=message_time,
user_info=user_info,
group_info=group_info,
template_info=None,
format_info=FormatInfo(
content_format=["text", "notify"],
accept_format=ACCEPT_FORMAT,
),
additional_config={"target_id": target_id}, # 在这里塞了一个target_id方便mmc那边知道被戳的人是谁
)
message_base: MessageBase = MessageBase(
message_info=message_info,
message_segment=handled_message,
raw_message=json.dumps(raw_message),
)
if system_notice:
await self.put_notice(message_base)
else:
logger.info("发送到Maibot处理通知信息")
await message_send_instance.message_send(message_base)
async def handle_poke_notify(
self, raw_message: dict, group_id: int, user_id: int
) -> Tuple[Seg | None, UserInfo | None]:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, remove-redundant-if, remove-unnecessary-else, swap-if-else-branches
self_info: dict = await get_self_info(self.get_server_connection())
if not self_info:
logger.error("自身信息获取失败")
return None, None
self_id = raw_message.get("self_id")
target_id = raw_message.get("target_id")
# 防抖检查:如果是针对机器人的戳一戳,检查防抖时间
if self_id == target_id:
current_time = time.time()
debounce_seconds = config_api.get_plugin_config(self.plugin_config, "features.poke_debounce_seconds", 2.0)
if self.last_poke_time > 0:
time_diff = current_time - self.last_poke_time
if time_diff < debounce_seconds:
logger.info(f"戳一戳防抖:用户 {user_id} 的戳一戳被忽略(距离上次戳一戳 {time_diff:.2f} 秒)")
return None, None
# 记录这次戳一戳的时间
self.last_poke_time = current_time
target_name: str = None
raw_info: list = raw_message.get("raw_info")
if group_id:
user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
else:
user_qq_info: dict = await get_stranger_info(self.get_server_connection(), user_id)
if user_qq_info:
user_name = user_qq_info.get("nickname")
user_cardname = user_qq_info.get("card")
else:
user_name = "QQ用户"
user_cardname = "QQ用户"
logger.info("无法获取戳一戳对方的用户昵称")
# 计算Seg
if self_id == target_id:
display_name = ""
target_name = self_info.get("nickname")
elif self_id == user_id:
# 让ada不发送麦麦戳别人的消息
return None, None
else:
# 如果配置为忽略不是针对自己的戳一戳则直接返回None
if config_api.get_plugin_config(self.plugin_config, "features.non_self_poke_ignored", False):
logger.info("忽略不是针对自己的戳一戳消息")
return None, None
# 老实说这一步判定没啥意义,毕竟私聊是没有其他人之间的戳一戳,但是感觉可以有这个判定来强限制群聊环境
if group_id:
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, target_id)
if fetched_member_info:
target_name = fetched_member_info.get("nickname")
else:
target_name = "QQ用户"
logger.info("无法获取被戳一戳方的用户昵称")
display_name = user_name
else:
return None, None
first_txt: str = "戳了戳"
second_txt: str = ""
try:
first_txt = raw_info[2].get("txt", "戳了戳")
second_txt = raw_info[4].get("txt", "")
except Exception as e:
logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本")
user_info: UserInfo = UserInfo(
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_name,
user_cardname=user_cardname,
)
seg_data: Seg = Seg(
type="text",
data=f"{display_name}{first_txt}{target_name}{second_txt}这是QQ的一个功能用于提及某人但没那么明显",
)
return seg_data, user_info
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
if not group_id:
logger.error("群ID不能为空无法处理禁言通知")
return None, None
# 计算user_info
operator_id = raw_message.get("operator_id")
operator_nickname: str = None
operator_cardname: str = None
member_info: dict = await get_member_info(self.get_server_connection(), group_id, operator_id)
if member_info:
operator_nickname = member_info.get("nickname")
operator_cardname = member_info.get("card")
else:
logger.warning("无法获取禁言执行者的昵称,消息可能会无效")
operator_nickname = "QQ用户"
operator_info: UserInfo = UserInfo(
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=operator_id,
user_nickname=operator_nickname,
user_cardname=operator_cardname,
)
# 计算Seg
user_id = raw_message.get("user_id")
banned_user_info: UserInfo = None
user_nickname: str = "QQ用户"
user_cardname: str = None
sub_type: str = None
duration = raw_message.get("duration")
if duration is None:
logger.error("禁言时长不能为空,无法处理禁言通知")
return None, None
if user_id == 0: # 为全体禁言
sub_type: str = "whole_ban"
self._ban_operation(group_id)
else: # 为单人禁言
# 获取被禁言人的信息
sub_type: str = "ban"
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
if fetched_member_info:
user_nickname = fetched_member_info.get("nickname")
user_cardname = fetched_member_info.get("card")
banned_user_info: UserInfo = UserInfo(
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
)
self._ban_operation(group_id, user_id, int(time.time() + duration))
seg_data: Seg = Seg(
type="notify",
data={
"sub_type": sub_type,
"duration": duration,
"banned_user_info": banned_user_info.to_dict() if banned_user_info else None,
},
)
return seg_data, operator_info
async def handle_lift_ban_notify(
self, raw_message: dict, group_id: int
) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
if not group_id:
logger.error("群ID不能为空无法处理解除禁言通知")
return None, None
# 计算user_info
operator_id = raw_message.get("operator_id")
operator_nickname: str = None
operator_cardname: str = None
member_info: dict = await get_member_info(self.get_server_connection(), group_id, operator_id)
if member_info:
operator_nickname = member_info.get("nickname")
operator_cardname = member_info.get("card")
else:
logger.warning("无法获取解除禁言执行者的昵称,消息可能会无效")
operator_nickname = "QQ用户"
operator_info: UserInfo = UserInfo(
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=operator_id,
user_nickname=operator_nickname,
user_cardname=operator_cardname,
)
# 计算Seg
sub_type: str = None
user_nickname: str = "QQ用户"
user_cardname: str = None
lifted_user_info: UserInfo = None
user_id = raw_message.get("user_id")
if user_id == 0: # 全体禁言解除
sub_type = "whole_lift_ban"
self._lift_operation(group_id)
else: # 单人禁言解除
sub_type = "lift_ban"
# 获取被解除禁言人的信息
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
if fetched_member_info:
user_nickname = fetched_member_info.get("nickname")
user_cardname = fetched_member_info.get("card")
else:
logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效")
lifted_user_info: UserInfo = UserInfo(
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
)
self._lift_operation(group_id, user_id)
seg_data: Seg = Seg(
type="notify",
data={
"sub_type": sub_type,
"lifted_user_info": lifted_user_info.to_dict() if lifted_user_info else None,
},
)
return seg_data, operator_info
async def put_notice(self, message_base: MessageBase) -> None:
"""
将处理后的通知消息放入通知队列
"""
if notice_queue.full() or unsuccessful_notice_queue.full():
logger.warning("通知队列已满,可能是多次发送失败,消息丢弃")
else:
await notice_queue.put(message_base)
async def handle_natural_lift(self) -> None:
while True:
if len(self.lifted_list) != 0:
lift_record = self.lifted_list.pop()
group_id = lift_record.group_id
user_id = lift_record.user_id
db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录
seg_message: Seg = await self.natural_lift(group_id, user_id)
fetched_group_info = await get_group_info(self.get_server_connection(), group_id)
group_name: str = None
if fetched_group_info:
group_name = fetched_group_info.get("group_name")
else:
logger.warning("无法获取notice消息所在群的名称")
group_info = GroupInfo(
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
group_id=group_id,
group_name=group_name,
)
message_info: BaseMessageInfo = BaseMessageInfo(
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
message_id="notice",
time=time.time(),
user_info=None, # 自然解除禁言没有操作者
group_info=group_info,
template_info=None,
format_info=None,
)
message_base: MessageBase = MessageBase(
message_info=message_info,
message_segment=seg_message,
raw_message=json.dumps(
{
"post_type": "notice",
"notice_type": "group_ban",
"sub_type": "lift_ban",
"group_id": group_id,
"user_id": user_id,
"operator_id": None, # 自然解除禁言没有操作者
}
),
)
await self.put_notice(message_base)
await asyncio.sleep(0.5) # 确保队列处理间隔
else:
await asyncio.sleep(5) # 每5秒检查一次
async def natural_lift(self, group_id: int, user_id: int) -> Seg | None:
if not group_id:
logger.error("群ID不能为空无法处理解除禁言通知")
return None
if user_id == 0: # 理论上永远不会触发
return Seg(
type="notify",
data={
"sub_type": "whole_lift_ban",
"lifted_user_info": None,
},
)
user_nickname: str = "QQ用户"
user_cardname: str = None
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
if fetched_member_info:
user_nickname = fetched_member_info.get("nickname")
user_cardname = fetched_member_info.get("card")
lifted_user_info: UserInfo = UserInfo(
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
)
return Seg(
type="notify",
data={
"sub_type": "lift_ban",
"lifted_user_info": lifted_user_info.to_dict(),
},
)
async def auto_lift_detect(self) -> None:
while True:
if len(self.banned_list) == 0:
await asyncio.sleep(5)
continue
for ban_record in self.banned_list:
if ban_record.user_id == 0 or ban_record.lift_time == -1:
continue
if ban_record.lift_time <= int(time.time()):
# 触发自然解除禁言
logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除")
self.lifted_list.append(ban_record)
self.banned_list.remove(ban_record)
await asyncio.sleep(5)
async def send_notice(self) -> None:
"""
发送通知消息到Napcat
"""
while True:
if not unsuccessful_notice_queue.empty():
to_be_send: MessageBase = await unsuccessful_notice_queue.get()
try:
send_status = await message_send_instance.message_send(to_be_send)
if send_status:
unsuccessful_notice_queue.task_done()
else:
await unsuccessful_notice_queue.put(to_be_send)
except Exception as e:
logger.error(f"发送通知消息失败: {str(e)}")
await unsuccessful_notice_queue.put(to_be_send)
await asyncio.sleep(1)
continue
to_be_send: MessageBase = await notice_queue.get()
try:
send_status = await message_send_instance.message_send(to_be_send)
if send_status:
notice_queue.task_done()
else:
await unsuccessful_notice_queue.put(to_be_send)
except Exception as e:
logger.error(f"发送通知消息失败: {str(e)}")
await unsuccessful_notice_queue.put(to_be_send)
await asyncio.sleep(1)
notice_handler = NoticeHandler()

View File

@@ -0,0 +1,250 @@
qq_face: dict = {
"0": "[表情:惊讶]",
"1": "[表情:撇嘴]",
"2": "[表情:色]",
"3": "[表情:发呆]",
"4": "[表情:得意]",
"5": "[表情:流泪]",
"6": "[表情:害羞]",
"7": "[表情:闭嘴]",
"8": "[表情:睡]",
"9": "[表情:大哭]",
"10": "[表情:尴尬]",
"11": "[表情:发怒]",
"12": "[表情:调皮]",
"13": "[表情:呲牙]",
"14": "[表情:微笑]",
"15": "[表情:难过]",
"16": "[表情:酷]",
"18": "[表情:抓狂]",
"19": "[表情:吐]",
"20": "[表情:偷笑]",
"21": "[表情:可爱]",
"22": "[表情:白眼]",
"23": "[表情:傲慢]",
"24": "[表情:饥饿]",
"25": "[表情:困]",
"26": "[表情:惊恐]",
"27": "[表情:流汗]",
"28": "[表情:憨笑]",
"29": "[表情:悠闲]",
"30": "[表情:奋斗]",
"31": "[表情:咒骂]",
"32": "[表情:疑问]",
"33": "[表情: 嘘]",
"34": "[表情:晕]",
"35": "[表情:折磨]",
"36": "[表情:衰]",
"37": "[表情:骷髅]",
"38": "[表情:敲打]",
"39": "[表情:再见]",
"41": "[表情:发抖]",
"42": "[表情:爱情]",
"43": "[表情:跳跳]",
"46": "[表情:猪头]",
"49": "[表情:拥抱]",
"53": "[表情:蛋糕]",
"56": "[表情:刀]",
"59": "[表情:便便]",
"60": "[表情:咖啡]",
"63": "[表情:玫瑰]",
"64": "[表情:凋谢]",
"66": "[表情:爱心]",
"67": "[表情:心碎]",
"74": "[表情:太阳]",
"75": "[表情:月亮]",
"76": "[表情:赞]",
"77": "[表情:踩]",
"78": "[表情:握手]",
"79": "[表情:胜利]",
"85": "[表情:飞吻]",
"86": "[表情:怄火]",
"89": "[表情:西瓜]",
"96": "[表情:冷汗]",
"97": "[表情:擦汗]",
"98": "[表情:抠鼻]",
"99": "[表情:鼓掌]",
"100": "[表情:糗大了]",
"101": "[表情:坏笑]",
"102": "[表情:左哼哼]",
"103": "[表情:右哼哼]",
"104": "[表情:哈欠]",
"105": "[表情:鄙视]",
"106": "[表情:委屈]",
"107": "[表情:快哭了]",
"108": "[表情:阴险]",
"109": "[表情:左亲亲]",
"110": "[表情:吓]",
"111": "[表情:可怜]",
"112": "[表情:菜刀]",
"114": "[表情:篮球]",
"116": "[表情:示爱]",
"118": "[表情:抱拳]",
"119": "[表情:勾引]",
"120": "[表情:拳头]",
"121": "[表情:差劲]",
"123": "[表情NO]",
"124": "[表情OK]",
"125": "[表情:转圈]",
"129": "[表情:挥手]",
"137": "[表情:鞭炮]",
"144": "[表情:喝彩]",
"146": "[表情:爆筋]",
"147": "[表情:棒棒糖]",
"169": "[表情:手枪]",
"171": "[表情:茶]",
"172": "[表情:眨眼睛]",
"173": "[表情:泪奔]",
"174": "[表情:无奈]",
"175": "[表情:卖萌]",
"176": "[表情:小纠结]",
"177": "[表情:喷血]",
"178": "[表情:斜眼笑]",
"179": "[表情doge]",
"181": "[表情:戳一戳]",
"182": "[表情:笑哭]",
"183": "[表情:我最美]",
"185": "[表情:羊驼]",
"187": "[表情:幽灵]",
"201": "[表情:点赞]",
"212": "[表情:托腮]",
"262": "[表情:脑阔疼]",
"263": "[表情:沧桑]",
"264": "[表情:捂脸]",
"265": "[表情:辣眼睛]",
"266": "[表情:哦哟]",
"267": "[表情:头秃]",
"268": "[表情:问号脸]",
"269": "[表情:暗中观察]",
"270": "[表情emm]",
"271": "[表情:吃 瓜]",
"272": "[表情:呵呵哒]",
"273": "[表情:我酸了]",
"277": "[表情:汪汪]",
"281": "[表情:无眼笑]",
"282": "[表情:敬礼]",
"283": "[表情:狂笑]",
"284": "[表情:面无表情]",
"285": "[表情:摸鱼]",
"286": "[表情:魔鬼笑]",
"287": "[表情:哦]",
"289": "[表情:睁眼]",
"293": "[表情:摸锦鲤]",
"294": "[表情:期待]",
"295": "[表情:拿到红包]",
"297": "[表情:拜谢]",
"298": "[表情:元宝]",
"299": "[表情:牛啊]",
"300": "[表情:胖三斤]",
"302": "[表情:左拜年]",
"303": "[表情:右拜年]",
"305": "[表情:右亲亲]",
"306": "[表情:牛气冲天]",
"307": "[表情:喵喵]",
"311": "[表情打call]",
"312": "[表情:变形]",
"314": "[表情:仔细分析]",
"317": "[表情:菜汪]",
"318": "[表情:崇拜]",
"319": "[表情: 比心]",
"320": "[表情:庆祝]",
"323": "[表情:嫌弃]",
"324": "[表情:吃糖]",
"325": "[表情:惊吓]",
"326": "[表情:生气]",
"332": "[表情:举牌牌]",
"333": "[表情:烟花]",
"334": "[表情:虎虎生威]",
"336": "[表情:豹富]",
"337": "[表情:花朵脸]",
"338": "[表情:我想开了]",
"339": "[表情:舔屏]",
"341": "[表情:打招呼]",
"342": "[表情酸Q]",
"343": "[表情:我方了]",
"344": "[表情:大怨种]",
"345": "[表情:红包多多]",
"346": "[表情:你真棒棒]",
"347": "[表情:大展宏兔]",
"349": "[表情:坚强]",
"350": "[表情:贴贴]",
"351": "[表情:敲敲]",
"352": "[表情:咦]",
"353": "[表情:拜托]",
"354": "[表情:尊嘟假嘟]",
"355": "[表情:耶]",
"356": "[表情666]",
"357": "[表情:裂开]",
"392": "[表情:龙年 快乐]",
"393": "[表情:新年中龙]",
"394": "[表情:新年大龙]",
"395": "[表情:略略略]",
"😊": "[表情:嘿嘿]",
"😌": "[表情:羞涩]",
"😚": "[ 表情:亲亲]",
"😓": "[表情:汗]",
"😰": "[表情:紧张]",
"😝": "[表情:吐舌]",
"😁": "[表情:呲牙]",
"😜": "[表情:淘气]",
"": "[表情:可爱]",
"😍": "[表情:花痴]",
"😔": "[表情:失落]",
"😄": "[表情:高兴]",
"😏": "[表情:哼哼]",
"😒": "[表情:不屑]",
"😳": "[表情:瞪眼]",
"😘": "[表情:飞吻]",
"😭": "[表情:大哭]",
"😱": "[表情:害怕]",
"😂": "[表情:激动]",
"💪": "[表情:肌肉]",
"👊": "[表情:拳头]",
"👍": "[表情 :厉害]",
"👏": "[表情:鼓掌]",
"👎": "[表情:鄙视]",
"🙏": "[表情:合十]",
"👌": "[表情:好的]",
"👆": "[表情:向上]",
"👀": "[表情:眼睛]",
"🍜": "[表情:拉面]",
"🍧": "[表情:刨冰]",
"🍞": "[表情:面包]",
"🍺": "[表情:啤酒]",
"🍻": "[表情:干杯]",
"": "[表情:咖啡]",
"🍎": "[表情:苹果]",
"🍓": "[表情:草莓]",
"🍉": "[表情:西瓜]",
"🚬": "[表情:吸烟]",
"🌹": "[表情:玫瑰]",
"🎉": "[表情:庆祝]",
"💝": "[表情:礼物]",
"💣": "[表情:炸弹]",
"": "[表情:闪光]",
"💨": "[表情:吹气]",
"💦": "[表情:水]",
"🔥": "[表情:火]",
"💤": "[表情:睡觉]",
"💩": "[表情:便便]",
"💉": "[表情:打针]",
"📫": "[表情:邮箱]",
"🐎": "[表情:骑马]",
"👧": "[表情:女孩]",
"👦": "[表情:男孩]",
"🐵": "[表情:猴]",
"🐷": "[表情:猪]",
"🐮": "[表情:牛]",
"🐔": "[表情:公鸡]",
"🐸": "[表情:青蛙]",
"👻": "[表情:幽灵]",
"🐛": "[表情:虫]",
"🐶": "[表情:狗]",
"🐳": "[表情:鲸鱼]",
"👢": "[表情:靴子]",
"": "[表情:晴天]",
"": "[表情:问号]",
"🔫": "[表情:手枪]",
"💓": "[表情:爱 心]",
"🏪": "[表情:便利店]",
}

View File

@@ -0,0 +1,61 @@
import asyncio
import time
from typing import Dict
from src.common.logger import get_logger
from src.plugin_system.apis import config_api
logger = get_logger("napcat_adapter")
response_dict: Dict = {}
response_time_dict: Dict = {}
plugin_config = None
def set_plugin_config(config: dict):
"""设置插件配置"""
global plugin_config
plugin_config = config
async def get_response(request_id: str, timeout: int = 10) -> dict:
response = await asyncio.wait_for(_get_response(request_id), timeout)
_ = response_time_dict.pop(request_id)
logger.info(f"响应信息id: {request_id} 已从响应字典中取出")
return response
async def _get_response(request_id: str) -> dict:
"""
内部使用的获取响应函数,主要用于在需要时获取响应
"""
while request_id not in response_dict:
await asyncio.sleep(0.2)
return response_dict.pop(request_id)
async def put_response(response: dict):
echo_id = response.get("echo")
now_time = time.time()
response_dict[echo_id] = response
response_time_dict[echo_id] = now_time
logger.info(f"响应信息id: {echo_id} 已存入响应字典")
async def check_timeout_response() -> None:
while True:
cleaned_message_count: int = 0
now_time = time.time()
# 获取心跳间隔配置
heartbeat_interval = 30 # 默认值
if plugin_config:
heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30)
for echo_id, response_time in list(response_time_dict.items()):
if now_time - response_time > heartbeat_interval:
cleaned_message_count += 1
response_dict.pop(echo_id)
response_time_dict.pop(echo_id)
logger.warning(f"响应消息 {echo_id} 超时,已删除")
logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
await asyncio.sleep(heartbeat_interval)

View File

@@ -0,0 +1,678 @@
import json
import time
import random
import websockets as Server
import uuid
import asyncio
from maim_message import (
UserInfo,
GroupInfo,
Seg,
BaseMessageInfo,
MessageBase,
)
from typing import Dict, Any, Tuple, Optional
from src.plugin_system.apis import config_api
from . import CommandType
from .response_pool import get_response
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from .utils import get_image_format, convert_image_to_gif
from .recv_handler.message_sending import message_send_instance
from .websocket_manager import websocket_manager
class SendHandler:
def __init__(self):
self.server_connection: Optional[Server.ServerConnection] = None
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
"""设置Napcat连接"""
self.server_connection = server_connection
def get_server_connection(self) -> Optional[Server.ServerConnection]:
"""获取当前的服务器连接"""
# 优先使用直接设置的连接,否则从 websocket_manager 获取
if self.server_connection:
return self.server_connection
return websocket_manager.get_connection()
async def handle_message(self, raw_message_base_dict: dict) -> None:
raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict)
message_segment: Seg = raw_message_base.message_segment
logger.info("接收到来自MaiBot的消息处理中")
if message_segment.type == "command":
logger.info("处理命令")
return await self.send_command(raw_message_base)
elif message_segment.type == "adapter_command":
logger.info("处理适配器命令")
return await self.handle_adapter_command(raw_message_base)
else:
logger.info("处理普通消息")
return await self.send_normal_message(raw_message_base)
async def send_normal_message(self, raw_message_base: MessageBase) -> None:
"""
处理普通消息发送
"""
logger.info("处理普通信息中")
message_info: BaseMessageInfo = raw_message_base.message_info
message_segment: Seg = raw_message_base.message_segment
group_info: Optional[GroupInfo] = message_info.group_info
user_info: Optional[UserInfo] = message_info.user_info
target_id: Optional[int] = None
action: Optional[str] = None
id_name: Optional[str] = None
processed_message: list = []
try:
if user_info:
processed_message = await self.handle_seg_recursive(message_segment, user_info)
except Exception as e:
logger.error(f"处理消息时发生错误: {e}")
return
if not processed_message:
logger.critical("现在暂时不支持解析此回复!")
return None
if group_info and user_info:
logger.debug("发送群聊消息")
target_id = int(group_info.group_id) if group_info.group_id else None
action = "send_group_msg"
id_name = "group_id"
elif user_info:
logger.debug("发送私聊消息")
target_id = int(user_info.user_id) if user_info.user_id else None
action = "send_private_msg"
id_name = "user_id"
else:
logger.error("无法识别的消息类型")
return
logger.info("尝试发送到napcat")
response = await self.send_message_to_napcat(
action,
{
id_name: target_id,
"message": processed_message,
},
)
if response.get("status") == "ok":
logger.info("消息发送成功")
qq_message_id = response.get("data", {}).get("message_id")
await self.message_sent_back(raw_message_base, qq_message_id)
else:
logger.warning(f"消息发送失败napcat返回{str(response)}")
async def send_command(self, raw_message_base: MessageBase) -> None:
"""
处理命令类
"""
logger.info("处理命令中")
message_info: BaseMessageInfo = raw_message_base.message_info
message_segment: Seg = raw_message_base.message_segment
group_info: Optional[GroupInfo] = message_info.group_info
seg_data: Dict[str, Any] = message_segment.data if isinstance(message_segment.data, dict) else {}
command_name: Optional[str] = seg_data.get("name")
try:
args = seg_data.get("args", {})
if not isinstance(args, dict):
args = {}
match command_name:
case CommandType.GROUP_BAN.name:
command, args_dict = self.handle_ban_command(args, group_info)
case CommandType.GROUP_WHOLE_BAN.name:
command, args_dict = self.handle_whole_ban_command(args, group_info)
case CommandType.GROUP_KICK.name:
command, args_dict = self.handle_kick_command(args, group_info)
case CommandType.SEND_POKE.name:
command, args_dict = self.handle_poke_command(args, group_info)
case CommandType.DELETE_MSG.name:
command, args_dict = self.delete_msg_command(args)
case CommandType.AI_VOICE_SEND.name:
command, args_dict = self.handle_ai_voice_send_command(args, group_info)
case CommandType.SET_EMOJI_LIKE.name:
command, args_dict = self.handle_set_emoji_like_command(args)
case CommandType.SEND_AT_MESSAGE.name:
command, args_dict = self.handle_at_message_command(args, group_info)
case CommandType.SEND_LIKE.name:
command, args_dict = self.handle_send_like_command(args)
case _:
logger.error(f"未知命令: {command_name}")
return
except Exception as e:
logger.error(f"处理命令时发生错误: {e}")
return None
if not command or not args_dict:
logger.error("命令或参数缺失")
return None
response = await self.send_message_to_napcat(command, args_dict)
if response.get("status") == "ok":
logger.info(f"命令 {command_name} 执行成功")
else:
logger.warning(f"命令 {command_name} 执行失败napcat返回{str(response)}")
async def handle_adapter_command(self, raw_message_base: MessageBase) -> None:
"""
处理适配器命令类 - 用于直接向Napcat发送命令并返回结果
"""
logger.info("处理适配器命令中")
message_info: BaseMessageInfo = raw_message_base.message_info
message_segment: Seg = raw_message_base.message_segment
seg_data: Dict[str, Any] = message_segment.data if isinstance(message_segment.data, dict) else {}
try:
action = seg_data.get("action")
params = seg_data.get("params", {})
request_id = seg_data.get("request_id")
if not action:
logger.error("适配器命令缺少action参数")
await self.send_adapter_command_response(
raw_message_base, {"status": "error", "message": "缺少action参数"}, request_id
)
return
logger.info(f"执行适配器命令: {action}")
# 直接向Napcat发送命令并获取响应
response_task = asyncio.create_task(self.send_message_to_napcat(action, params))
response = await response_task
# 发送响应回MaiBot
await self.send_adapter_command_response(raw_message_base, response, request_id)
if response.get("status") == "ok":
logger.info(f"适配器命令 {action} 执行成功")
else:
logger.warning(f"适配器命令 {action} 执行失败napcat返回{str(response)}")
except Exception as e:
logger.error(f"处理适配器命令时发生错误: {e}")
error_response = {"status": "error", "message": str(e)}
await self.send_adapter_command_response(raw_message_base, error_response, seg_data.get("request_id"))
def get_level(self, seg_data: Seg) -> int:
if seg_data.type == "seglist":
return 1 + max(self.get_level(seg) for seg in seg_data.data)
else:
return 1
async def handle_seg_recursive(self, seg_data: Seg, user_info: UserInfo) -> list:
payload: list = []
if seg_data.type == "seglist":
# level = self.get_level(seg_data) # 给以后可能的多层嵌套做准备,此处不使用
if not seg_data.data:
return []
for seg in seg_data.data:
payload = await self.process_message_by_type(seg, payload, user_info)
else:
payload = await self.process_message_by_type(seg_data, payload, user_info)
return payload
async def process_message_by_type(self, seg: Seg, payload: list, user_info: UserInfo) -> list:
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
new_payload = payload
if seg.type == "reply":
target_id = seg.data
if target_id == "notice":
return payload
new_payload = self.build_payload(
payload,
await self.handle_reply_message(target_id if isinstance(target_id, str) else "", user_info),
True,
)
elif seg.type == "text":
text = seg.data
if not text:
return payload
new_payload = self.build_payload(
payload,
self.handle_text_message(text if isinstance(text, str) else ""),
False,
)
elif seg.type == "face":
logger.warning("MaiBot 发送了qq原生表情暂时不支持")
elif seg.type == "image":
image = seg.data
new_payload = self.build_payload(payload, self.handle_image_message(image), False)
elif seg.type == "emoji":
emoji = seg.data
new_payload = self.build_payload(payload, self.handle_emoji_message(emoji), False)
elif seg.type == "voice":
voice = seg.data
new_payload = self.build_payload(payload, self.handle_voice_message(voice), False)
elif seg.type == "voiceurl":
voice_url = seg.data
new_payload = self.build_payload(payload, self.handle_voiceurl_message(voice_url), False)
elif seg.type == "music":
song_id = seg.data
new_payload = self.build_payload(payload, self.handle_music_message(song_id), False)
elif seg.type == "videourl":
video_url = seg.data
new_payload = self.build_payload(payload, self.handle_videourl_message(video_url), False)
elif seg.type == "file":
file_path = seg.data
new_payload = self.build_payload(payload, self.handle_file_message(file_path), False)
return new_payload
def build_payload(self, payload: list, addon: dict | list, is_reply: bool = False) -> list:
# sourcery skip: for-append-to-extend, merge-list-append, simplify-generator
"""构建发送的消息体"""
if is_reply:
temp_list = []
if isinstance(addon, list):
temp_list.extend(addon)
else:
temp_list.append(addon)
for i in payload:
if i.get("type") == "reply":
logger.debug("检测到多个回复,使用最新的回复")
continue
temp_list.append(i)
return temp_list
else:
if isinstance(addon, list):
payload.extend(addon)
else:
payload.append(addon)
return payload
async def handle_reply_message(self, id: str, user_info: UserInfo) -> dict | list:
"""处理回复消息"""
reply_seg = {"type": "reply", "data": {"id": id}}
# 检查是否启用引用艾特功能
if not config_api.get_plugin_config(self.plugin_config, "features.enable_reply_at", False):
return reply_seg
try:
# 尝试通过 message_id 获取消息详情
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})
replied_user_id = None
if msg_info_response and msg_info_response.get("status") == "ok":
sender_info = msg_info_response.get("data", {}).get("sender")
if sender_info:
replied_user_id = sender_info.get("user_id")
# 如果没有获取到被回复者的ID则直接返回不进行@
if not replied_user_id:
logger.warning(f"无法获取消息 {id} 的发送者信息,跳过 @")
return reply_seg
# 根据概率决定是否艾特用户
if random.random() < config_api.get_plugin_config(self.plugin_config, "features.reply_at_rate", 0.5):
at_seg = {"type": "at", "data": {"qq": str(replied_user_id)}}
# 在艾特后面添加一个空格
text_seg = {"type": "text", "data": {"text": " "}}
return [reply_seg, at_seg, text_seg]
except Exception as e:
logger.error(f"处理引用回复并尝试@时出错: {e}")
# 出现异常时,只发送普通的回复,避免程序崩溃
return reply_seg
return reply_seg
def handle_text_message(self, message: str) -> dict:
"""处理文本消息"""
return {"type": "text", "data": {"text": message}}
def handle_image_message(self, encoded_image: str) -> dict:
"""处理图片消息"""
return {
"type": "image",
"data": {
"file": f"base64://{encoded_image}",
"subtype": 0,
},
} # base64 编码的图片
def handle_emoji_message(self, encoded_emoji: str) -> dict:
"""处理表情消息"""
encoded_image = encoded_emoji
image_format = get_image_format(encoded_emoji)
if image_format != "gif":
encoded_image = convert_image_to_gif(encoded_emoji)
return {
"type": "image",
"data": {
"file": f"base64://{encoded_image}",
"subtype": 1,
"summary": "[动画表情]",
},
}
def handle_voice_message(self, encoded_voice: str) -> dict:
"""处理语音消息"""
use_tts = False
if self.plugin_config:
use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False)
if not use_tts:
logger.warning("未启用语音消息处理")
return {}
if not encoded_voice:
return {}
return {
"type": "record",
"data": {"file": f"base64://{encoded_voice}"},
}
def handle_voiceurl_message(self, voice_url: str) -> dict:
"""处理语音链接消息"""
return {
"type": "record",
"data": {"file": voice_url},
}
def handle_music_message(self, song_id: str) -> dict:
"""处理音乐消息"""
return {
"type": "music",
"data": {"type": "163", "id": song_id},
}
def handle_videourl_message(self, video_url: str) -> dict:
"""处理视频链接消息"""
return {
"type": "video",
"data": {"file": video_url},
}
def handle_file_message(self, file_path: str) -> dict:
"""处理文件消息"""
return {
"type": "file",
"data": {"file": f"file://{file_path}"},
}
def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""处理删除消息命令"""
return "delete_msg", {"message_id": args["message_id"]}
def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理封禁命令
Args:
args (Dict[str, Any]): 参数字典
group_info (GroupInfo): 群聊信息(对应目标群聊)
Returns:
Tuple[CommandType, Dict[str, Any]]
"""
duration: int = int(args["duration"])
user_id: int = int(args["qq_id"])
group_id: int = int(group_info.group_id)
if duration < 0:
raise ValueError("封禁时间必须大于等于0")
if not user_id or not group_id:
raise ValueError("封禁命令缺少必要参数")
if duration > 2592000:
raise ValueError("封禁时间不能超过30天")
return (
CommandType.GROUP_BAN.value,
{
"group_id": group_id,
"user_id": user_id,
"duration": duration,
},
)
def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理全体禁言命令
Args:
args (Dict[str, Any]): 参数字典
group_info (GroupInfo): 群聊信息(对应目标群聊)
Returns:
Tuple[CommandType, Dict[str, Any]]
"""
enable = args["enable"]
assert isinstance(enable, bool), "enable参数必须是布尔值"
group_id: int = int(group_info.group_id)
if group_id <= 0:
raise ValueError("群组ID无效")
return (
CommandType.GROUP_WHOLE_BAN.value,
{
"group_id": group_id,
"enable": enable,
},
)
def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理群成员踢出命令
Args:
args (Dict[str, Any]): 参数字典
group_info (GroupInfo): 群聊信息(对应目标群聊)
Returns:
Tuple[CommandType, Dict[str, Any]]
"""
user_id: int = int(args["qq_id"])
group_id: int = int(group_info.group_id)
if group_id <= 0:
raise ValueError("群组ID无效")
if user_id <= 0:
raise ValueError("用户ID无效")
return (
CommandType.GROUP_KICK.value,
{
"group_id": group_id,
"user_id": user_id,
"reject_add_request": False, # 不拒绝加群请求
},
)
def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理戳一戳命令
Args:
args (Dict[str, Any]): 参数字典
group_info (GroupInfo): 群聊信息(对应目标群聊)
Returns:
Tuple[CommandType, Dict[str, Any]]
"""
user_id: int = int(args["qq_id"])
if group_info is None:
group_id = None
else:
group_id: int = int(group_info.group_id)
if group_id <= 0:
raise ValueError("群组ID无效")
if user_id <= 0:
raise ValueError("用户ID无效")
return (
CommandType.SEND_POKE.value,
{
"group_id": group_id,
"user_id": user_id,
},
)
def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""处理设置表情回应命令
Args:
args (Dict[str, Any]): 参数字典
Returns:
Tuple[CommandType, Dict[str, Any]]
"""
try:
message_id = int(args["message_id"])
emoji_id = int(args["emoji_id"])
set_like = str(args["set"])
except:
raise ValueError("缺少必需参数: message_id 或 emoji_id")
return (
CommandType.SET_EMOJI_LIKE.value,
{"message_id": message_id, "emoji_id": emoji_id, "set": set_like},
)
def handle_send_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""
处理发送点赞命令的逻辑。
Args:
args (Dict[str, Any]): 参数字典
Returns:
Tuple[CommandType, Dict[str, Any]]
"""
try:
user_id: int = int(args["qq_id"])
times: int = int(args["times"])
except (KeyError, ValueError):
raise ValueError("缺少必需参数: qq_id 或 times")
return (
CommandType.SEND_LIKE.value,
{"user_id": user_id, "times": times},
)
def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""
处理AI语音发送命令的逻辑。
并返回 NapCat 兼容的 (action, params) 元组。
"""
if not group_info or not group_info.group_id:
raise ValueError("AI语音发送命令必须在群聊上下文中使用")
if not args:
raise ValueError("AI语音发送命令缺少参数")
group_id: int = int(group_info.group_id)
character_id = args.get("character")
text_content = args.get("text")
if not character_id or not text_content:
raise ValueError(f"AI语音发送命令参数不完整: character='{character_id}', text='{text_content}'")
return (
CommandType.AI_VOICE_SEND.value,
{
"group_id": group_id,
"text": text_content,
"character": character_id,
},
)
async def send_message_to_napcat(self, action: str, params: dict) -> dict:
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
# 获取当前连接
connection = self.get_server_connection()
if not connection:
logger.error("没有可用的 Napcat 连接")
return {"status": "error", "message": "no connection"}
try:
await connection.send(payload)
response = await get_response(request_uuid)
except TimeoutError:
logger.error("发送消息超时,未收到响应")
return {"status": "error", "message": "timeout"}
except Exception as e:
logger.error(f"发送消息失败: {e}")
return {"status": "error", "message": str(e)}
return response
async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None:
# 修改 additional_config添加 echo 字段
if message_base.message_info.additional_config is None:
message_base.message_info.additional_config = {}
message_base.message_info.additional_config["echo"] = True
# 获取原始的 mmc_message_id
mmc_message_id = message_base.message_info.message_id
# 修改 message_segment 为 notify 类型
message_base.message_segment = Seg(
type="notify", data={"sub_type": "echo", "echo": mmc_message_id, "actual_id": qq_message_id}
)
await message_send_instance.message_send(message_base)
logger.debug("已回送消息ID")
return
async def send_adapter_command_response(
self, original_message: MessageBase, response_data: dict, request_id: str
) -> None:
"""
发送适配器命令响应回MaiBot
Args:
original_message: 原始消息
response_data: 响应数据
request_id: 请求ID
"""
try:
# 修改 additional_config添加 echo 字段
if original_message.message_info.additional_config is None:
original_message.message_info.additional_config = {}
original_message.message_info.additional_config["echo"] = True
# 修改 message_segment 为 adapter_response 类型
original_message.message_segment = Seg(
type="adapter_response",
data={"request_id": request_id, "response": response_data, "timestamp": int(time.time() * 1000)},
)
await message_send_instance.message_send(original_message)
logger.debug(f"已发送适配器命令响应request_id: {request_id}")
except Exception as e:
logger.error(f"发送适配器命令响应时出错: {e}")
def handle_at_message_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理艾特并发送消息命令
Args:
args (Dict[str, Any]): 参数字典, 包含 qq_id 和 text
group_info (GroupInfo): 群聊信息
Returns:
Tuple[str, Dict[str, Any]]: (action, params)
"""
at_user_id = args.get("qq_id")
text = args.get("text")
if not at_user_id or not text:
raise ValueError("艾特消息命令缺少 qq_id 或 text 参数")
if not group_info:
raise ValueError("艾特消息命令必须在群聊上下文中使用")
message_payload = [
{"type": "at", "data": {"qq": str(at_user_id)}},
{"type": "text", "data": {"text": " " + str(text)}},
]
return (
"send_group_msg",
{
"group_id": group_info.group_id,
"message": message_payload,
},
)
send_handler = SendHandler()

View File

@@ -0,0 +1,312 @@
import websockets as Server
import json
import base64
import uuid
import urllib3
import ssl
import io
from .database import BanUser, db_manager
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from .response_pool import get_response
from PIL import Image
from typing import Union, List, Tuple, Optional
class SSLAdapter(urllib3.PoolManager):
def __init__(self, *args, **kwargs):
context = ssl.create_default_context()
context.set_ciphers("DEFAULT@SECLEVEL=1")
context.minimum_version = ssl.TLSVersion.TLSv1_2
kwargs["ssl_context"] = context
super().__init__(*args, **kwargs)
async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
"""
获取群相关信息
返回值需要处理可能为空的情况
"""
logger.debug("获取群聊信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
try:
await websocket.send(payload)
socket_response: dict = await get_response(request_uuid)
except TimeoutError:
logger.error(f"获取群信息超时,群号: {group_id}")
return None
except Exception as e:
logger.error(f"获取群信息失败: {e}")
return None
logger.debug(socket_response)
return socket_response.get("data")
async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
"""
获取群详细信息
返回值需要处理可能为空的情况
"""
logger.debug("获取群详细信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid})
try:
await websocket.send(payload)
socket_response: dict = await get_response(request_uuid)
except TimeoutError:
logger.error(f"获取群详细信息超时,群号: {group_id}")
return None
except Exception as e:
logger.error(f"获取群详细信息失败: {e}")
return None
logger.debug(socket_response)
return socket_response.get("data")
async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None:
"""
获取群成员信息
返回值需要处理可能为空的情况
"""
logger.debug("获取群成员信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps(
{
"action": "get_group_member_info",
"params": {"group_id": group_id, "user_id": user_id, "no_cache": True},
"echo": request_uuid,
}
)
try:
await websocket.send(payload)
socket_response: dict = await get_response(request_uuid)
except TimeoutError:
logger.error(f"获取成员信息超时,群号: {group_id}, 用户ID: {user_id}")
return None
except Exception as e:
logger.error(f"获取成员信息失败: {e}")
return None
logger.debug(socket_response)
return socket_response.get("data")
async def get_image_base64(url: str) -> str:
# sourcery skip: raise-specific-error
"""获取图片/表情包的Base64"""
logger.debug(f"下载图片: {url}")
http = SSLAdapter()
try:
response = http.request("GET", url, timeout=10)
if response.status != 200:
raise Exception(f"HTTP Error: {response.status}")
image_bytes = response.data
return base64.b64encode(image_bytes).decode("utf-8")
except Exception as e:
logger.error(f"图片下载失败: {str(e)}")
raise
def convert_image_to_gif(image_base64: str) -> str:
# sourcery skip: extract-method
"""
将Base64编码的图片转换为GIF格式
Parameters:
image_base64: str: Base64编码的图片数据
Returns:
str: Base64编码的GIF图片数据
"""
logger.debug("转换图片为GIF格式")
try:
image_bytes = base64.b64decode(image_base64)
image = Image.open(io.BytesIO(image_bytes))
output_buffer = io.BytesIO()
image.save(output_buffer, format="GIF")
output_buffer.seek(0)
return base64.b64encode(output_buffer.read()).decode("utf-8")
except Exception as e:
logger.error(f"图片转换为GIF失败: {str(e)}")
return image_base64
async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
"""
获取自身信息
Parameters:
websocket: WebSocket连接对象
Returns:
data: dict: 返回的自身信息
"""
logger.debug("获取自身信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
try:
await websocket.send(payload)
response: dict = await get_response(request_uuid)
except TimeoutError:
logger.error("获取自身信息超时")
return None
except Exception as e:
logger.error(f"获取自身信息失败: {e}")
return None
logger.debug(response)
return response.get("data")
def get_image_format(raw_data: str) -> str:
"""
从Base64编码的数据中确定图片的格式。
Parameters:
raw_data: str: Base64编码的图片数据。
Returns:
format: str: 图片的格式(例如 'jpeg', 'png', 'gif')。
"""
image_bytes = base64.b64decode(raw_data)
return Image.open(io.BytesIO(image_bytes)).format.lower()
async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict | None:
"""
获取陌生人信息
Parameters:
websocket: WebSocket连接对象
user_id: 用户ID
Returns:
dict: 返回的陌生人信息
"""
logger.debug("获取陌生人信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid})
try:
await websocket.send(payload)
response: dict = await get_response(request_uuid)
except TimeoutError:
logger.error(f"获取陌生人信息超时用户ID: {user_id}")
return None
except Exception as e:
logger.error(f"获取陌生人信息失败: {e}")
return None
logger.debug(response)
return response.get("data")
async def get_message_detail(websocket: Server.ServerConnection, message_id: Union[str, int]) -> dict | None:
"""
获取消息详情,可能为空
Parameters:
websocket: WebSocket连接对象
message_id: 消息ID
Returns:
dict: 返回的消息详情
"""
logger.debug("获取消息详情中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid})
try:
await websocket.send(payload)
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
except TimeoutError:
logger.error(f"获取消息详情超时消息ID: {message_id}")
return None
except Exception as e:
logger.error(f"获取消息详情失败: {e}")
return None
logger.debug(response)
return response.get("data")
async def get_record_detail(
websocket: Server.ServerConnection, file: str, file_id: Optional[str] = None
) -> dict | None:
"""
获取语音消息内容
Parameters:
websocket: WebSocket连接对象
file: 文件名
file_id: 文件ID
Returns:
dict: 返回的语音消息详情
"""
logger.debug("获取语音消息详情中")
request_uuid = str(uuid.uuid4())
payload = json.dumps(
{
"action": "get_record",
"params": {"file": file, "file_id": file_id, "out_format": "wav"},
"echo": request_uuid,
}
)
try:
await websocket.send(payload)
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
except TimeoutError:
logger.error(f"获取语音消息详情超时,文件: {file}, 文件ID: {file_id}")
return None
except Exception as e:
logger.error(f"获取语音消息详情失败: {e}")
return None
logger.debug(f"{str(response)[:200]}...") # 防止语音的超长base64编码导致日志过长
return response.get("data")
async def read_ban_list(
websocket: Server.ServerConnection,
) -> Tuple[List[BanUser], List[BanUser]]:
"""
从根目录下的data文件夹中的文件读取禁言列表。
同时自动更新已经失效禁言
Returns:
Tuple[
一个仍在禁言中的用户的BanUser列表,
一个已经自然解除禁言的用户的BanUser列表,
一个仍在全体禁言中的群的BanUser列表,
一个已经自然解除全体禁言的群的BanUser列表,
]
"""
try:
ban_list = db_manager.get_ban_records()
lifted_list: List[BanUser] = []
logger.info("已经读取禁言列表")
for ban_record in ban_list:
if ban_record.user_id == 0:
fetched_group_info = await get_group_info(websocket, ban_record.group_id)
if fetched_group_info is None:
logger.warning(f"无法获取群信息,群号: {ban_record.group_id},默认禁言解除")
lifted_list.append(ban_record)
ban_list.remove(ban_record)
continue
group_all_shut: int = fetched_group_info.get("group_all_shut")
if group_all_shut == 0:
lifted_list.append(ban_record)
ban_list.remove(ban_record)
continue
else:
fetched_member_info = await get_member_info(websocket, ban_record.group_id, ban_record.user_id)
if fetched_member_info is None:
logger.warning(
f"无法获取群成员信息用户ID: {ban_record.user_id}, 群号: {ban_record.group_id},默认禁言解除"
)
lifted_list.append(ban_record)
ban_list.remove(ban_record)
continue
lift_ban_time: int = fetched_member_info.get("shut_up_timestamp")
if lift_ban_time == 0:
lifted_list.append(ban_record)
ban_list.remove(ban_record)
else:
ban_record.lift_time = lift_ban_time
db_manager.update_ban_record(ban_list)
return ban_list, lifted_list
except Exception as e:
logger.error(f"读取禁言列表失败: {e}")
return [], []
def save_ban_record(list: List[BanUser]):
return db_manager.update_ban_record(list)

View File

@@ -0,0 +1,177 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
视频下载和处理模块
用于从QQ消息中下载视频并转发给Bot进行分析
"""
import aiohttp
import asyncio
from pathlib import Path
from typing import Optional, Dict, Any
from src.common.logger import get_logger
logger = get_logger("video_handler")
class VideoDownloader:
def __init__(self, max_size_mb: int = 100, download_timeout: int = 60):
self.max_size_mb = max_size_mb
self.download_timeout = download_timeout
self.supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm", ".m4v"}
def is_video_url(self, url: str) -> bool:
"""检查URL是否为视频文件"""
try:
# QQ视频URL可能没有扩展名所以先检查Content-Type
# 对于QQ视频我们先假设是视频稍后通过Content-Type验证
# 检查URL中是否包含视频相关的关键字
video_keywords = ["video", "mp4", "avi", "mov", "mkv", "flv", "wmv", "webm", "m4v"]
url_lower = url.lower()
# 如果URL包含视频关键字认为是视频
if any(keyword in url_lower for keyword in video_keywords):
return True
# 检查文件扩展名(传统方法)
path = Path(url.split("?")[0]) # 移除查询参数
if path.suffix.lower() in self.supported_formats:
return True
# 对于QQ等特殊平台URL可能没有扩展名
# 我们允许这些URL通过稍后通过HTTP头Content-Type验证
qq_domains = ["qpic.cn", "gtimg.cn", "qq.com", "tencent.com"]
if any(domain in url_lower for domain in qq_domains):
return True
return False
except:
# 如果解析失败,默认允许尝试下载(稍后验证)
return True
def check_file_size(self, content_length: Optional[str]) -> bool:
"""检查文件大小是否在允许范围内"""
if content_length is None:
return True # 无法获取大小时允许下载
try:
size_bytes = int(content_length)
size_mb = size_bytes / (1024 * 1024)
return size_mb <= self.max_size_mb
except:
return True
async def download_video(self, url: str, filename: Optional[str] = None) -> Dict[str, Any]:
"""
下载视频文件
Args:
url: 视频URL
filename: 可选的文件名
Returns:
dict: 下载结果包含success、data、filename、error等字段
"""
try:
logger.info(f"开始下载视频: {url}")
# 检查URL格式
if not self.is_video_url(url):
logger.warning(f"URL格式检查失败: {url}")
return {"success": False, "error": "不支持的视频格式", "url": url}
async with aiohttp.ClientSession() as session:
# 先发送HEAD请求检查文件大小
try:
async with session.head(url, timeout=aiohttp.ClientTimeout(total=10)) as response:
if response.status != 200:
logger.warning(f"HEAD请求失败状态码: {response.status}")
else:
content_length = response.headers.get("Content-Length")
if not self.check_file_size(content_length):
return {
"success": False,
"error": f"视频文件过大,超过{self.max_size_mb}MB限制",
"url": url,
}
except Exception as e:
logger.warning(f"HEAD请求失败: {e},继续尝试下载")
# 下载文件
async with session.get(url, timeout=aiohttp.ClientTimeout(total=self.download_timeout)) as response:
if response.status != 200:
return {"success": False, "error": f"下载失败HTTP状态码: {response.status}", "url": url}
# 检查Content-Type是否为视频
content_type = response.headers.get("Content-Type", "").lower()
if content_type:
# 检查是否为视频类型
video_mime_types = [
"video/",
"application/octet-stream",
"application/x-msvideo",
"video/x-msvideo",
]
is_video_content = any(mime in content_type for mime in video_mime_types)
if not is_video_content:
logger.warning(f"Content-Type不是视频格式: {content_type}")
# 如果不是明确的视频类型但可能是QQ的特殊格式继续尝试
if "text/" in content_type or "application/json" in content_type:
return {
"success": False,
"error": f"URL返回的不是视频内容Content-Type: {content_type}",
"url": url,
}
# 再次检查Content-Length
content_length = response.headers.get("Content-Length")
if not self.check_file_size(content_length):
return {"success": False, "error": f"视频文件过大,超过{self.max_size_mb}MB限制", "url": url}
# 读取文件内容
video_data = await response.read()
# 检查实际文件大小
actual_size_mb = len(video_data) / (1024 * 1024)
if actual_size_mb > self.max_size_mb:
return {
"success": False,
"error": f"视频文件过大,实际大小: {actual_size_mb:.2f}MB",
"url": url,
}
# 确定文件名
if filename is None:
filename = Path(url.split("?")[0]).name
if not filename or "." not in filename:
filename = "video.mp4"
logger.info(f"视频下载成功: {filename}, 大小: {actual_size_mb:.2f}MB")
return {
"success": True,
"data": video_data,
"filename": filename,
"size_mb": actual_size_mb,
"url": url,
}
except asyncio.TimeoutError:
return {"success": False, "error": "下载超时", "url": url}
except Exception as e:
logger.error(f"下载视频时出错: {e}")
return {"success": False, "error": str(e), "url": url}
# 全局实例
_video_downloader = None
def get_video_downloader(max_size_mb: int = 100, download_timeout: int = 60) -> VideoDownloader:
"""获取视频下载器实例"""
global _video_downloader
if _video_downloader is None:
_video_downloader = VideoDownloader(max_size_mb, download_timeout)
return _video_downloader

View File

@@ -0,0 +1,161 @@
import asyncio
import websockets as Server
from typing import Optional, Callable, Any
from src.common.logger import get_logger
from src.plugin_system.apis import config_api
logger = get_logger("napcat_adapter")
class WebSocketManager:
"""WebSocket 连接管理器,支持正向和反向连接"""
def __init__(self):
self.connection: Optional[Server.ServerConnection] = None
self.server: Optional[Server.WebSocketServer] = None
self.is_running = False
self.reconnect_interval = 5 # 重连间隔(秒)
self.max_reconnect_attempts = 10 # 最大重连次数
self.plugin_config = None
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None:
"""根据配置启动 WebSocket 连接"""
self.plugin_config = plugin_config
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
if mode == "reverse":
await self._start_reverse_connection(message_handler)
elif mode == "forward":
await self._start_forward_connection(message_handler)
else:
raise ValueError(f"不支持的连接模式: {mode}")
async def _start_reverse_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
"""启动反向连接(作为服务器)"""
host = config_api.get_plugin_config(self.plugin_config, "napcat_server.host")
port = config_api.get_plugin_config(self.plugin_config, "napcat_server.port")
logger.info(f"正在启动反向连接模式,监听地址: ws://{host}:{port}")
async def handle_client(websocket, path=None):
self.connection = websocket
logger.info(f"Napcat 客户端已连接: {websocket.remote_address}")
try:
await message_handler(websocket)
except Exception as e:
logger.error(f"处理客户端连接时出错: {e}")
finally:
self.connection = None
logger.info("Napcat 客户端已断开连接")
self.server = await Server.serve(handle_client, host, port, max_size=2**26)
self.is_running = True
logger.info(f"反向连接服务器已启动,监听地址: ws://{host}:{port}")
# 保持服务器运行
await self.server.serve_forever()
async def _start_forward_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
"""启动正向连接(作为客户端)"""
url = self._get_forward_url()
logger.info(f"正在启动正向连接模式,目标地址: {url}")
reconnect_count = 0
while reconnect_count < self.max_reconnect_attempts:
try:
logger.info(f"尝试连接到 Napcat 服务器: {url}")
# 准备连接参数
connect_kwargs = {"max_size": 2**26}
# 如果配置了访问令牌,添加到请求头
access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token")
if access_token:
connect_kwargs["additional_headers"] = {
"Authorization": f"Bearer {access_token}"
}
logger.info("已添加访问令牌到连接请求头")
async with Server.connect(url, **connect_kwargs) as websocket:
self.connection = websocket
self.is_running = True
reconnect_count = 0 # 重置重连计数
logger.info(f"成功连接到 Napcat 服务器: {url}")
try:
await message_handler(websocket)
except Server.exceptions.ConnectionClosed:
logger.warning("与 Napcat 服务器的连接已断开")
except Exception as e:
logger.error(f"处理正向连接时出错: {e}")
finally:
self.connection = None
self.is_running = False
except (
Server.exceptions.ConnectionClosed,
Server.exceptions.InvalidMessage,
OSError,
ConnectionRefusedError,
) as e:
reconnect_count += 1
logger.warning(f"连接失败 ({reconnect_count}/{self.max_reconnect_attempts}): {e}")
if reconnect_count < self.max_reconnect_attempts:
logger.info(f"将在 {self.reconnect_interval} 秒后重试连接...")
await asyncio.sleep(self.reconnect_interval)
else:
logger.error("已达到最大重连次数,停止重连")
raise
except Exception as e:
logger.error(f"正向连接时发生未知错误: {e}")
raise
def _get_forward_url(self) -> str:
"""获取正向连接的 URL"""
# 如果配置了完整的 URL直接使用
url = config_api.get_plugin_config(self.plugin_config, "napcat_server.url")
if url:
return url
# 否则根据 host 和 port 构建 URL
host = config_api.get_plugin_config(self.plugin_config, "napcat_server.host")
port = config_api.get_plugin_config(self.plugin_config, "napcat_server.port")
return f"ws://{host}:{port}"
async def stop_connection(self) -> None:
"""停止 WebSocket 连接"""
self.is_running = False
if self.connection:
try:
await self.connection.close()
logger.info("WebSocket 连接已关闭")
except Exception as e:
logger.error(f"关闭 WebSocket 连接时出错: {e}")
finally:
self.connection = None
if self.server:
try:
self.server.close()
await self.server.wait_closed()
logger.info("WebSocket 服务器已关闭")
except Exception as e:
logger.error(f"关闭 WebSocket 服务器时出错: {e}")
finally:
self.server = None
def get_connection(self) -> Optional[Server.ServerConnection]:
"""获取当前的 WebSocket 连接"""
return self.connection
def is_connected(self) -> bool:
"""检查是否已连接"""
return self.connection is not None and self.is_running
# 全局 WebSocket 管理器实例
websocket_manager = WebSocketManager()

View File

@@ -0,0 +1,43 @@
# 权限配置文件
# 此文件用于管理群聊和私聊的黑白名单设置,以及聊天相关功能
# 支持热重载,修改后会自动生效
# 群聊权限设置
group_list_type = "whitelist" # 群聊列表类型whitelist白名单或 blacklist黑名单
group_list = [] # 群聊ID列表
# 当 group_list_type 为 whitelist 时,只有列表中的群聊可以使用机器人
# 当 group_list_type 为 blacklist 时,列表中的群聊无法使用机器人
# 示例group_list = [123456789, 987654321]
# 私聊权限设置
private_list_type = "whitelist" # 私聊列表类型whitelist白名单或 blacklist黑名单
private_list = [] # 用户ID列表
# 当 private_list_type 为 whitelist 时,只有列表中的用户可以私聊机器人
# 当 private_list_type 为 blacklist 时,列表中的用户无法私聊机器人
# 示例private_list = [123456789, 987654321]
# 全局禁止设置
ban_user_id = [] # 全局禁止用户ID列表这些用户无法在任何地方使用机器人
ban_qq_bot = false # 是否屏蔽QQ官方机器人消息
# 聊天功能设置
enable_poke = true # 是否启用戳一戳功能
ignore_non_self_poke = false # 是否无视不是针对自己的戳一戳
poke_debounce_seconds = 3 # 戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略
enable_reply_at = true # 是否启用引用回复时艾特用户的功能
reply_at_rate = 0.5 # 引用回复时艾特用户的几率 (0.0 ~ 1.0)
# 视频处理设置
enable_video_analysis = true # 是否启用视频识别功能
max_video_size_mb = 100 # 视频文件最大大小限制MB
download_timeout = 60 # 视频下载超时时间(秒)
supported_formats = ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"] # 支持的视频格式
# 消息缓冲设置
enable_message_buffer = true # 是否启用消息缓冲合并功能
message_buffer_enable_group = true # 是否启用群聊消息缓冲合并
message_buffer_enable_private = true # 是否启用私聊消息缓冲合并
message_buffer_interval = 3.0 # 消息合并间隔时间(秒),在此时间内的连续消息将被合并
message_buffer_initial_delay = 0.5 # 消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并
message_buffer_max_components = 50 # 单个会话最大缓冲消息组件数量,超过此数量将强制合并
message_buffer_block_prefixes = ["/"] # 消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲

View File

@@ -0,0 +1,29 @@
[inner]
version = "0.2.1" # 版本号
# 请勿修改版本号,除非你知道自己在做什么
[nickname] # 现在没用
nickname = ""
[napcat_server] # Napcat连接的ws服务设置
mode = "reverse" # 连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)
host = "localhost" # 主机地址
port = 8095 # 端口号
url = "" # 正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)
access_token = "" # WebSocket 连接的访问令牌,用于身份验证(可选)
heartbeat_interval = 30 # 心跳间隔时间(按秒计)
[maibot_server] # 连接麦麦的ws服务设置
host = "localhost" # 麦麦在.env文件中设置的主机地址即HOST字段
port = 8000 # 麦麦在.env文件中设置的端口即PORT字段
[voice] # 发送语音设置
use_tts = false # 是否使用tts语音请确保你配置了tts并有对应的adapter
[slicing] # WebSocket消息切片设置
max_frame_size = 64 # WebSocket帧的最大大小单位为字节默认64KB
delay_ms = 10 # 切片发送间隔时间,单位为毫秒
[debug]
level = "INFO" # 日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL

View File

@@ -0,0 +1,89 @@
# TODO List:
- [x] logger使用主程序的
- [ ] 使用插件系统的config系统
- [x] 接收从napcat传递的所有信息
- [ ] <del>优化架构,各模块解耦,暴露关键方法用于提供接口</del>
- [ ] <del>单独一个模块负责与主程序通信</del>
- [ ] 使用event系统完善接口api
---
Event分为两种一种是对外输出的event由napcat插件自主触发并传递参数另一种是接收外界输入的event由外部插件触发并向napcat传递参数
## 例如,
### 对外输出的event
napcat_on_received_text -> (message_seg: Seg) 接受到qq的文字消息,会向handler传递一个Seg
napcat_on_received_face -> (message_seg: Seg) 接受到qq的表情消息,会向handler传递一个Seg
napcat_on_received_reply -> (message_seg: Seg) 接受到qq的回复消息,会向handler传递一个Seg
napcat_on_received_image -> (message_seg: Seg) 接受到qq的图片消息,会向handler传递一个Seg
napcat_on_received_image -> (message_seg: Seg) 接受到qq的图片消息,会向handler传递一个Seg
napcat_on_received_record -> (message_seg: Seg) 接受到qq的语音消息,会向handler传递一个Seg
napcat_on_received_rps -> (message_seg: Seg) 接受到qq的猜拳魔法表情,会向handler传递一个Seg
napcat_on_received_friend_invitation -> (user_id: str) 接受到qq的好友请求,会向handler传递一个user_id
...
此类event不接受外部插件的触发只能由napcat插件统一触发。
外部插件需要编写handler并订阅此类事件。
```python
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.base.base_event import HandlerResult
class MyEventHandler(BaseEventHandler):
handler_name = "my_handler"
handler_description = "我的自定义事件处理器"
weight = 10 # 权重,越大越先执行
intercept_message = False # 是否拦截消息
init_subscribe = ["napcat_on_received_text"] # 初始订阅的事件
async def execute(self, params: dict) -> HandlerResult:
"""处理事件"""
try:
message = params.get("message_seg")
print(f"收到消息: {message.data}")
# 业务逻辑处理
# ...
return HandlerResult(
success=True,
continue_process=True, # 是否继续让其他处理器处理
message="处理成功",
handler_name=self.handler_name
)
except Exception as e:
return HandlerResult(
success=False,
continue_process=True,
message=f"处理失败: {str(e)}",
handler_name=self.handler_name
)
```
### 接收外界输入的event
napcat_kick_group <- (user_id, group_id) 踢出某个群组中的某个用户
napcat_mute_user <- (user_id, group_id, time) 禁言某个群组中的某个用户
napcat_unmute_user <- (user_id, group_id) 取消禁言某个群组中的某个用户
napcat_mute_group <- (user_id, group_id) 禁言某个群组
napcat_unmute_group <- (user_id, group_id) 取消禁言某个群组
napcat_add_friend <- (user_id) 向某个用户发出好友请求
napcat_accept_friend <- (user_id) 接收某个用户的好友请求
napcat_reject_friend <- (user_id) 拒绝某个用户的好友请求
...
此类事件只由外部插件触发并传递参数由napcat完成请求任务
外部插件需要触发此类的event并传递正确的参数
```python
from src.plugin_system.core.event_manager import event_manager
# 触发事件
await event_manager.trigger_event("napcat_accept_friend", user_id = 1234123)
```