This commit is contained in:
SengokuCola
2025-07-23 02:26:57 +08:00
20 changed files with 515 additions and 411 deletions

View File

@@ -45,10 +45,19 @@
10. 修正了`main.py`中的错误输出。 10. 修正了`main.py`中的错误输出。
11. 修正了`command`所编译的`Pattern`注册时的错误输出。 11. 修正了`command`所编译的`Pattern`注册时的错误输出。
12. `events_manager`有了task相关逻辑了。 12. `events_manager`有了task相关逻辑了。
13. 现在有了插件卸载和重载功能了,也就是热插拔。
14. 实现了组件的全局启用和禁用功能。
- 通过`enable_component`和`disable_component`方法来启用或禁用组件。
- 不过这个操作不会保存到配置文件~
15. 实现了组件的局部禁用,也就是针对某一个聊天禁用的功能。
- 通过`disable_specific_chat_action``enable_specific_chat_action``disable_specific_chat_command``enable_specific_chat_command``disable_specific_chat_event_handler``enable_specific_chat_event_handler`来操作
- 同样不保存到配置文件~
### TODO ### TODO
把这个看起来就很别扭的config获取方式改一下 把这个看起来就很别扭的config获取方式改一下
来个API管理这些启用禁用
# 吐槽 # 吐槽
```python ```python
@@ -65,3 +74,6 @@ plugin_path = Path(plugin_file)
module_name = ".".join(plugin_path.parent.parts) module_name = ".".join(plugin_path.parent.parts)
``` ```
这两个区别很大的。 这两个区别很大的。
### 执笔BGM
塞壬唱片!

View File

@@ -51,6 +51,8 @@ NO_ACTION = {
"action_prompt": "", "action_prompt": "",
} }
IS_MAI4U = False
install(extra_lines=3) install(extra_lines=3)
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行 # 注释:原来的动作修改超时常量已移除,因为改为顺序执行
@@ -258,29 +260,27 @@ class HeartFChatting:
return f"{person_name}:{message_data.get('processed_plain_text')}" return f"{person_name}:{message_data.get('processed_plain_text')}"
async def send_typing(self): async def send_typing(self):
group_info = GroupInfo(platform = "amaidesu_default",group_id = 114514,group_name = "内心") group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
chat = await get_chat_manager().get_or_create_stream( chat = await get_chat_manager().get_or_create_stream(
platform = "amaidesu_default", platform="amaidesu_default",
user_info = None, user_info=None,
group_info = group_info group_info=group_info,
) )
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
) )
async def stop_typing(self): async def stop_typing(self):
group_info = GroupInfo(platform = "amaidesu_default",group_id = 114514,group_name = "内心") group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
chat = await get_chat_manager().get_or_create_stream( chat = await get_chat_manager().get_or_create_stream(
platform = "amaidesu_default", platform="amaidesu_default",
user_info = None, user_info=None,
group_info = group_info group_info=group_info,
) )
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
) )
@@ -374,7 +374,6 @@ class HeartFChatting:
await self.stop_typing() await self.stop_typing()
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text) await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
return True return True
else: else:
@@ -508,7 +507,6 @@ class HeartFChatting:
self.willing_manager.setup(message_data, self.chat_stream) self.willing_manager.setup(message_data, self.chat_stream)
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", "")) reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", ""))
talk_frequency = -1.00 talk_frequency = -1.00
@@ -547,7 +545,6 @@ class HeartFChatting:
self.willing_manager.delete(message_data.get("message_id", "")) self.willing_manager.delete(message_data.get("message_id", ""))
return False return False
async def _generate_response( async def _generate_response(
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
) -> Optional[list]: ) -> Optional[list]:
@@ -571,7 +568,7 @@ class HeartFChatting:
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}") logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
return None return None
async def _send_response(self, reply_set, reply_to, thinking_start_time,message_data): async def _send_response(self, reply_set, reply_to, thinking_start_time, message_data):
current_time = time.time() current_time = time.time()
new_message_count = message_api.count_new_messages( new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
@@ -593,13 +590,27 @@ class HeartFChatting:
if not first_replied: if not first_replied:
if need_reply: if need_reply:
await send_api.text_to_stream( await send_api.text_to_stream(
text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False text=data,
stream_id=self.chat_stream.stream_id,
reply_to=reply_to,
reply_to_platform_id=reply_to_platform_id,
typing=False,
) )
else: else:
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=False) await send_api.text_to_stream(
text=data,
stream_id=self.chat_stream.stream_id,
reply_to_platform_id=reply_to_platform_id,
typing=False,
)
first_replied = True first_replied = True
else: else:
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=True) await send_api.text_to_stream(
text=data,
stream_id=self.chat_stream.stream_id,
reply_to_platform_id=reply_to_platform_id,
typing=True,
)
reply_text += data reply_text += data
return reply_text return reply_text

View File

@@ -13,7 +13,7 @@ from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, events_manager # 导入新插件系统 from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
@@ -91,8 +91,20 @@ class ChatBot:
# 使用新的组件注册中心查找命令 # 使用新的组件注册中心查找命令
command_result = component_registry.find_command_by_text(text) command_result = component_registry.find_command_by_text(text)
if command_result: if command_result:
command_class, matched_groups, command_info = command_result
intercept_message = command_info.intercept_message
plugin_name = command_info.plugin_name
command_name = command_info.name
if (
message.chat_stream
and message.chat_stream.stream_id
and command_name
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
):
logger.info("用户禁用的命令,跳过处理")
return False, None, True
message.is_command = True message.is_command = True
command_class, matched_groups, intercept_message, plugin_name = command_result
# 获取插件配置 # 获取插件配置
plugin_config = component_registry.get_plugin_config(plugin_name) plugin_config = component_registry.get_plugin_config(plugin_name)
@@ -140,7 +152,6 @@ class ChatBot:
logger.info("收到notice消息暂时不支持处理") logger.info("收到notice消息暂时不支持处理")
return True return True
async def do_s4u(self, message_data: Dict[str, Any]): async def do_s4u(self, message_data: Dict[str, Any]):
message = MessageRecvS4U(message_data) message = MessageRecvS4U(message_data)
group_info = message.message_info.group_info group_info = message.message_info.group_info
@@ -162,7 +173,6 @@ class ChatBot:
return return
async def message_process(self, message_data: Dict[str, Any]) -> None: async def message_process(self, message_data: Dict[str, Any]) -> None:
"""处理转化后的统一格式消息 """处理转化后的统一格式消息
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中 这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
@@ -178,8 +188,6 @@ class ChatBot:
- 性能计时 - 性能计时
""" """
try: try:
# 确保所有任务已启动 # 确保所有任务已启动
await self._ensure_started() await self._ensure_started()
@@ -204,7 +212,6 @@ class ChatBot:
if await self.hanle_notice_message(message): if await self.hanle_notice_message(message):
return return
group_info = message.message_info.group_info group_info = message.message_info.group_info
user_info = message.message_info.user_info user_info = message.message_info.user_info
if message.message_info.additional_config: if message.message_info.additional_config:
@@ -233,7 +240,6 @@ class ChatBot:
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}") # logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
# return # return
# 过滤检查 # 过滤检查
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, # type: ignore message.raw_message, # type: ignore

View File

@@ -163,20 +163,25 @@ class ChatManager:
"""注册消息到聊天流""" """注册消息到聊天流"""
stream_id = self._generate_stream_id( stream_id = self._generate_stream_id(
message.message_info.platform, # type: ignore message.message_info.platform, # type: ignore
message.message_info.user_info, # type: ignore message.message_info.user_info,
message.message_info.group_info, message.message_info.group_info,
) )
self.last_messages[stream_id] = message self.last_messages[stream_id] = message
# logger.debug(f"注册消息到聊天流: {stream_id}") # logger.debug(f"注册消息到聊天流: {stream_id}")
@staticmethod @staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: def _generate_stream_id(
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
) -> str:
"""生成聊天流唯一ID""" """生成聊天流唯一ID"""
if not user_info and not group_info:
raise ValueError("用户信息或群组信息必须提供")
if group_info: if group_info:
# 组合关键信息 # 组合关键信息
components = [platform, str(group_info.group_id)] components = [platform, str(group_info.group_id)]
else: else:
components = [platform, str(user_info.user_id), "private"] components = [platform, str(user_info.user_id), "private"] # type: ignore
# 使用MD5生成唯一ID # 使用MD5生成唯一ID
key = "_".join(components) key = "_".join(components)

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Type from typing import Dict, Optional, Type
from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.base_action import BaseAction
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -22,53 +22,14 @@ class ActionManager:
def __init__(self): def __init__(self):
"""初始化动作管理器""" """初始化动作管理器"""
# 所有注册的动作集合
self._registered_actions: Dict[str, ActionInfo] = {}
# 当前正在使用的动作集合,默认加载默认动作 # 当前正在使用的动作集合,默认加载默认动作
self._using_actions: Dict[str, ActionInfo] = {} self._using_actions: Dict[str, ActionInfo] = {}
# 加载插件动作
self._load_plugin_actions()
# 初始化时将默认动作加载到使用中的动作 # 初始化时将默认动作加载到使用中的动作
self._using_actions = component_registry.get_default_actions() self._using_actions = component_registry.get_default_actions()
def _load_plugin_actions(self) -> None: # === 执行Action方法 ===
"""
加载所有插件系统中的动作
"""
try:
# 从新插件系统获取Action组件
self._load_plugin_system_actions()
logger.debug("从插件系统加载Action组件成功")
except Exception as e:
logger.error(f"加载插件动作失败: {e}")
def _load_plugin_system_actions(self) -> None:
"""从插件系统的component_registry加载Action组件"""
try:
# 获取所有Action组件
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore
for action_name, action_info in action_components.items():
if action_name in self._registered_actions:
logger.debug(f"Action组件 {action_name} 已存在,跳过")
continue
self._registered_actions[action_name] = action_info
logger.debug(
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
)
logger.info(f"加载了 {len(action_components)} 个Action动作")
except Exception as e:
logger.error(f"从插件系统加载Action组件失败: {e}")
import traceback
logger.error(traceback.format_exc())
def create_action( def create_action(
self, self,
@@ -139,36 +100,11 @@ class ActionManager:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return None return None
def get_registered_actions(self) -> Dict[str, ActionInfo]:
"""获取所有已注册的动作集"""
return self._registered_actions.copy()
def get_using_actions(self) -> Dict[str, ActionInfo]: def get_using_actions(self) -> Dict[str, ActionInfo]:
"""获取当前正在使用的动作集合""" """获取当前正在使用的动作集合"""
return self._using_actions.copy() return self._using_actions.copy()
def add_action_to_using(self, action_name: str) -> bool: # === Modify相关方法 ===
"""
添加已注册的动作到当前使用的动作集
Args:
action_name: 动作名称
Returns:
bool: 添加是否成功
"""
if action_name not in self._registered_actions:
logger.warning(f"添加失败: 动作 {action_name} 未注册")
return False
if action_name in self._using_actions:
logger.info(f"动作 {action_name} 已经在使用中")
return True
self._using_actions[action_name] = self._registered_actions[action_name]
logger.info(f"添加动作 {action_name} 到使用集")
return True
def remove_action_from_using(self, action_name: str) -> bool: def remove_action_from_using(self, action_name: str) -> bool:
""" """
从当前使用的动作集中移除指定动作 从当前使用的动作集中移除指定动作
@@ -187,79 +123,8 @@ class ActionManager:
logger.debug(f"已从使用集中移除动作 {action_name}") logger.debug(f"已从使用集中移除动作 {action_name}")
return True return True
# def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
# """
# 添加新的动作到注册集
# Args:
# action_name: 动作名称
# description: 动作描述
# parameters: 动作参数定义,默认为空字典
# require: 动作依赖项,默认为空列表
# Returns:
# bool: 添加是否成功
# """
# if action_name in self._registered_actions:
# return False
# if parameters is None:
# parameters = {}
# if require is None:
# require = []
# action_info = {"description": description, "parameters": parameters, "require": require}
# self._registered_actions[action_name] = action_info
# return True
def remove_action(self, action_name: str) -> bool:
"""从注册集移除指定动作"""
if action_name not in self._registered_actions:
return False
del self._registered_actions[action_name]
# 如果在使用集中也存在,一并移除
if action_name in self._using_actions:
del self._using_actions[action_name]
return True
def temporarily_remove_actions(self, actions_to_remove: List[str]) -> None:
"""临时移除使用集中的指定动作"""
for name in actions_to_remove:
self._using_actions.pop(name, None)
def restore_actions(self) -> None: def restore_actions(self) -> None:
"""恢复到默认动作集""" """恢复到默认动作集"""
actions_to_restore = list(self._using_actions.keys()) actions_to_restore = list(self._using_actions.keys())
self._using_actions = component_registry.get_default_actions() self._using_actions = component_registry.get_default_actions()
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
def add_system_action_if_needed(self, action_name: str) -> bool:
"""
根据需要添加系统动作到使用集
Args:
action_name: 动作名称
Returns:
bool: 是否成功添加
"""
if action_name in self._registered_actions and action_name not in self._using_actions:
self._using_actions[action_name] = self._registered_actions[action_name]
logger.info(f"临时添加系统动作到使用集: {action_name}")
return True
return False
def get_action(self, action_name: str) -> Optional[Type[BaseAction]]:
"""
获取指定动作的处理器类
Args:
action_name: 动作名称
Returns:
Optional[Type[BaseAction]]: 动作处理器类如果不存在则返回None
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_component_class(action_name, ComponentType.ACTION) # type: ignore

View File

@@ -2,7 +2,7 @@ import random
import asyncio import asyncio
import hashlib import hashlib
import time import time
from typing import List, Any, Dict, TYPE_CHECKING from typing import List, Any, Dict, TYPE_CHECKING, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@@ -11,6 +11,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageCo
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
if TYPE_CHECKING: if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
@@ -47,7 +48,6 @@ class ActionModifier:
async def modify_actions( async def modify_actions(
self, self,
history_loop=None,
message_content: str = "", message_content: str = "",
): # sourcery skip: use-named-expression ): # sourcery skip: use-named-expression
""" """
@@ -61,8 +61,9 @@ class ActionModifier:
""" """
logger.debug(f"{self.log_prefix}开始完整动作修改流程") logger.debug(f"{self.log_prefix}开始完整动作修改流程")
removals_s1 = [] removals_s1: List[Tuple[str, str]] = []
removals_s2 = [] removals_s2: List[Tuple[str, str]] = []
removals_s3: List[Tuple[str, str]] = []
self.action_manager.restore_actions() self.action_manager.restore_actions()
all_actions = self.action_manager.get_using_actions() all_actions = self.action_manager.get_using_actions()
@@ -84,25 +85,28 @@ class ActionModifier:
if message_content: if message_content:
chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}" chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}"
# === 第一阶段:传统观察处理 === # === 第一阶段:去除用户自行禁用的 ===
# if history_loop: disabled_actions = global_announcement_manager.get_disabled_chat_actions(self.chat_id)
# removals_from_loop = await self.analyze_loop_actions(history_loop) if disabled_actions:
# if removals_from_loop: for disabled_action_name in disabled_actions:
# removals_s1.extend(removals_from_loop) if disabled_action_name in all_actions:
removals_s1.append((disabled_action_name, "用户自行禁用"))
self.action_manager.remove_action_from_using(disabled_action_name)
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
# 检查动作的关联类型 # === 第二阶段:检查动作的关联类型 ===
chat_context = self.chat_stream.context chat_context = self.chat_stream.context
type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context) type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context)
if type_mismatched_actions: if type_mismatched_actions:
removals_s1.extend(type_mismatched_actions) removals_s2.extend(type_mismatched_actions)
# 应用第阶段的移除 # 应用第阶段的移除
for action_name, reason in removals_s1: for action_name, reason in removals_s2:
self.action_manager.remove_action_from_using(action_name) self.action_manager.remove_action_from_using(action_name)
logger.debug(f"{self.log_prefix}阶段移除动作: {action_name},原因: {reason}") logger.debug(f"{self.log_prefix}阶段移除动作: {action_name},原因: {reason}")
# === 第阶段:激活类型判定 === # === 第阶段:激活类型判定 ===
if chat_content is not None: if chat_content is not None:
logger.debug(f"{self.log_prefix}开始激活类型判定阶段") logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
@@ -110,18 +114,18 @@ class ActionModifier:
current_using_actions = self.action_manager.get_using_actions() current_using_actions = self.action_manager.get_using_actions()
# 获取因激活类型判定而需要移除的动作 # 获取因激活类型判定而需要移除的动作
removals_s2 = await self._get_deactivated_actions_by_type( removals_s3 = await self._get_deactivated_actions_by_type(
current_using_actions, current_using_actions,
chat_content, chat_content,
) )
# 应用第阶段的移除 # 应用第阶段的移除
for action_name, reason in removals_s2: for action_name, reason in removals_s3:
self.action_manager.remove_action_from_using(action_name) self.action_manager.remove_action_from_using(action_name)
logger.debug(f"{self.log_prefix}阶段移除动作: {action_name},原因: {reason}") logger.debug(f"{self.log_prefix}阶段移除动作: {action_name},原因: {reason}")
# === 统一日志记录 === # === 统一日志记录 ===
all_removals = removals_s1 + removals_s2 all_removals = removals_s1 + removals_s2 + removals_s3
removals_summary: str = "" removals_summary: str = ""
if all_removals: if all_removals:
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals]) removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
@@ -131,7 +135,7 @@ class ActionModifier:
) )
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext): def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
type_mismatched_actions = [] type_mismatched_actions: List[Tuple[str, str]] = []
for action_name, action_info in all_actions.items(): for action_name, action_info in all_actions.items():
if action_info.associated_types and not chat_context.check_types(action_info.associated_types): if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
associated_types_str = ", ".join(action_info.associated_types) associated_types_str = ", ".join(action_info.associated_types)
@@ -318,7 +322,7 @@ class ActionModifier:
action_name: str, action_name: str,
action_info: ActionInfo, action_info: ActionInfo,
chat_content: str = "", chat_content: str = "",
) -> bool: ) -> bool: # sourcery skip: move-assign-in-block, use-named-expression
""" """
使用LLM判定是否应该激活某个action 使用LLM判定是否应该激活某个action

View File

@@ -1,7 +1,7 @@
import json import json
import time import time
import traceback import traceback
from typing import Dict, Any, Optional, Tuple from typing import Dict, Any, Optional, Tuple, List
from rich.traceback import install from rich.traceback import install
from datetime import datetime from datetime import datetime
from json_repair import repair_json from json_repair import repair_json
@@ -19,8 +19,8 @@ from src.chat.utils.chat_message_builder import (
from src.chat.utils.utils import get_chat_type_and_target_info from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.plugin_system.base.component_types import ActionInfo, ChatMode from src.plugin_system.base.component_types import ActionInfo, ChatMode, ComponentType
from src.plugin_system.core.component_registry import component_registry
logger = get_logger("planner") logger = get_logger("planner")
@@ -99,7 +99,7 @@ class ActionPlanner:
async def plan( async def plan(
self, mode: ChatMode = ChatMode.FOCUS self, mode: ChatMode = ChatMode.FOCUS
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]: # sourcery skip: dict-comprehension ) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]:
""" """
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
""" """
@@ -113,16 +113,17 @@ class ActionPlanner:
try: try:
is_group_chat = True is_group_chat = True
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id) is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}") logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
current_available_actions_dict = self.action_manager.get_using_actions() current_available_actions_dict = self.action_manager.get_using_actions()
# 获取完整的动作信息 # 获取完整的动作信息
all_registered_actions = self.action_manager.get_registered_actions() all_registered_actions: List[ActionInfo] = list(
component_registry.get_components_by_type(ComponentType.ACTION).values() # type: ignore
for action_name in current_available_actions_dict.keys(): )
current_available_actions = {}
for action_name in current_available_actions_dict:
if action_name in all_registered_actions: if action_name in all_registered_actions:
current_available_actions[action_name] = all_registered_actions[action_name] current_available_actions[action_name] = all_registered_actions[action_name]
else: else:
@@ -234,10 +235,13 @@ class ActionPlanner:
"is_parallel": is_parallel, "is_parallel": is_parallel,
} }
return { return (
{
"action_result": action_result, "action_result": action_result,
"action_prompt": prompt, "action_prompt": prompt,
}, target_message },
target_message,
)
async def build_planner_prompt( async def build_planner_prompt(
self, self,

View File

@@ -619,9 +619,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
chat_target_info = None chat_target_info = None
try: try:
chat_stream = get_chat_manager().get_stream(chat_id) if chat_stream := get_chat_manager().get_stream(chat_id):
if chat_stream:
if chat_stream.group_info: if chat_stream.group_info:
is_group_chat = True is_group_chat = True
chat_target_info = None # Explicitly None for group chat chat_target_info = None # Explicitly None for group chat
@@ -660,8 +658,6 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
chat_target_info = target_info chat_target_info = target_info
else: else:
logger.warning(f"无法获取 chat_stream for {chat_id} in utils") logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
# Keep defaults: is_group_chat=False, chat_target_info=None
except Exception as e: except Exception as e:
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True) logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
# Keep defaults on error # Keep defaults on error

View File

@@ -173,12 +173,10 @@ class Individuality:
personality = short_impression[0] personality = short_impression[0]
identity = short_impression[1] identity = short_impression[1]
prompt_personality = f"{personality}{identity}" prompt_personality = f"{personality}{identity}"
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
return identity_block
def _get_config_hash( def _get_config_hash(
self, bot_nickname: str, personality_core: str, personality_side: str, identity: list self, bot_nickname: str, personality_core: str, personality_side: str, identity: str
) -> tuple[str, str]: ) -> tuple[str, str]:
"""获取personality和identity配置的哈希值 """获取personality和identity配置的哈希值
@@ -197,7 +195,7 @@ class Individuality:
# 身份配置哈希 # 身份配置哈希
identity_config = { identity_config = {
"identity": sorted(identity), "identity": identity,
"compress_identity": self.personality.compress_identity if self.personality else True, "compress_identity": self.personality.compress_identity if self.personality else True,
} }
identity_str = json.dumps(identity_config, sort_keys=True) identity_str = json.dumps(identity_config, sort_keys=True)
@@ -206,7 +204,7 @@ class Individuality:
return personality_hash, identity_hash return personality_hash, identity_hash
async def _check_config_and_clear_if_changed( async def _check_config_and_clear_if_changed(
self, bot_nickname: str, personality_core: str, personality_side: str, identity: list self, bot_nickname: str, personality_core: str, personality_side: str, identity: str
) -> tuple[bool, bool]: ) -> tuple[bool, bool]:
"""检查配置是否发生变化,如果变化则清空相应缓存 """检查配置是否发生变化,如果变化则清空相应缓存
@@ -321,7 +319,7 @@ class Individuality:
return personality_result return personality_result
async def _create_identity(self, identity: list) -> str: async def _create_identity(self, identity: str) -> str:
"""使用LLM创建压缩版本的impression""" """使用LLM创建压缩版本的impression"""
logger.info("正在构建身份.........") logger.info("正在构建身份.........")

View File

@@ -1,6 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List from typing import Dict, Optional
@dataclass @dataclass
@@ -10,7 +9,7 @@ class Personality:
bot_nickname: str # 机器人昵称 bot_nickname: str # 机器人昵称
personality_core: str # 人格核心特点 personality_core: str # 人格核心特点
personality_side: str # 人格侧面描述 personality_side: str # 人格侧面描述
identity: List[str] # 身份细节描述 identity: Optional[str] # 身份细节描述
compress_personality: bool # 是否压缩人格 compress_personality: bool # 是否压缩人格
compress_identity: bool # 是否压缩身份 compress_identity: bool # 是否压缩身份
@@ -21,7 +20,7 @@ class Personality:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def __init__(self, personality_core: str = "", personality_side: str = "", identity: List[str] = None): def __init__(self, personality_core: str = "", personality_side: str = "", identity: Optional[str] = None):
self.personality_core = personality_core self.personality_core = personality_core
self.personality_side = personality_side self.personality_side = personality_side
self.identity = identity self.identity = identity
@@ -45,7 +44,7 @@ class Personality:
bot_nickname: str, bot_nickname: str,
personality_core: str, personality_core: str,
personality_side: str, personality_side: str,
identity: List[str] = None, identity: Optional[str] = None,
compress_personality: bool = True, compress_personality: bool = True,
compress_identity: bool = True, compress_identity: bool = True,
) -> "Personality": ) -> "Personality":

View File

@@ -28,6 +28,7 @@ from .core import (
component_registry, component_registry,
dependency_manager, dependency_manager,
events_manager, events_manager,
global_announcement_manager,
) )
# 导入工具模块 # 导入工具模块
@@ -67,6 +68,7 @@ __all__ = [
"component_registry", "component_registry",
"dependency_manager", "dependency_manager",
"events_manager", "events_manager",
"global_announcement_manager",
# 装饰器 # 装饰器
"register_plugin", "register_plugin",
"ConfigField", "ConfigField",

View File

@@ -28,7 +28,6 @@ def register_plugin(cls):
if "." in plugin_name: if "." in plugin_name:
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
plugin_manager.plugin_classes[plugin_name] = cls
splitted_name = cls.__module__.split(".") splitted_name = cls.__module__.split(".")
root_path = Path(__file__) root_path = Path(__file__)
@@ -40,6 +39,7 @@ def register_plugin(cls):
logger.error(f"注册 {plugin_name} 无法找到项目根目录") logger.error(f"注册 {plugin_name} 无法找到项目根目录")
return cls return cls
plugin_manager.plugin_classes[plugin_name] = cls
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve()) plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}") logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")

View File

@@ -65,21 +65,28 @@ class BaseAction(ABC):
self.thinking_id = thinking_id self.thinking_id = thinking_id
self.log_prefix = log_prefix self.log_prefix = log_prefix
# 保存插件配置
self.plugin_config = plugin_config or {} self.plugin_config = plugin_config or {}
"""对应的插件配置"""
# 设置动作基本信息实例属性 # 设置动作基本信息实例属性
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", "")) self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
"""Action的名字"""
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件") self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
"""Action的描述"""
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy() self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy() self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
# 设置激活类型实例属性(从类属性复制,提供默认值) # 设置激活类型实例属性(从类属性复制,提供默认值)
self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS) self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS)
"""FOCUS模式下的激活类型"""
self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS) self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS)
"""NORMAL模式下的激活类型"""
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0) self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
"""当激活类型为RANDOM时的概率"""
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "") self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
"""协助LLM进行判断的Prompt"""
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy() self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
"""激活类型为KEYWORD时的KEYWORDS列表"""
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False) self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL) self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True) self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)

View File

@@ -21,13 +21,18 @@ class BaseCommand(ABC):
""" """
command_name: str = "" command_name: str = ""
"""Command组件的名称"""
command_description: str = "" command_description: str = ""
"""Command组件的描述"""
# 默认命令设置(子类可以覆盖) # 默认命令设置(子类可以覆盖)
command_pattern: str = "" command_pattern: str = ""
"""命令匹配的正则表达式"""
command_help: str = "" command_help: str = ""
"""命令帮助信息"""
command_examples: List[str] = [] command_examples: List[str] = []
intercept_message: bool = True # 默认拦截消息,不继续处理 intercept_message: bool = True
"""是否拦截信息,默认拦截,不进行后续处理"""
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
"""初始化Command组件 """初始化Command组件

View File

@@ -13,16 +13,23 @@ class BaseEventHandler(ABC):
所有事件处理器都应该继承这个基类,提供事件处理的基本接口 所有事件处理器都应该继承这个基类,提供事件处理的基本接口
""" """
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知 event_type: EventType = EventType.UNKNOWN
handler_name: str = "" # 处理器名称 """事件类型,默认为未知"""
handler_name: str = ""
"""处理器名称"""
handler_description: str = "" handler_description: str = ""
weight: int = 0 # 权重,数值越大优先级越高 """处理器描述"""
intercept_message: bool = False # 是否拦截消息,默认为否 weight: int = 0
"""处理器权重,越大权重越高"""
intercept_message: bool = False
"""是否拦截消息,默认为否"""
def __init__(self): def __init__(self):
self.log_prefix = "[EventHandler]" self.log_prefix = "[EventHandler]"
self.plugin_name = "" # 对应插件名 self.plugin_name = ""
self.plugin_config: Optional[Dict] = None # 插件配置字典 """对应插件名"""
self.plugin_config: Optional[Dict] = None
"""插件配置字典"""
if self.event_type == EventType.UNKNOWN: if self.event_type == EventType.UNKNOWN:
raise NotImplementedError("事件处理器必须指定 event_type") raise NotImplementedError("事件处理器必须指定 event_type")

View File

@@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.dependency_manager import dependency_manager from src.plugin_system.core.dependency_manager import dependency_manager
from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
__all__ = [ __all__ = [
"plugin_manager", "plugin_manager",
"component_registry", "component_registry",
"dependency_manager", "dependency_manager",
"events_manager", "events_manager",
"global_announcement_manager",
] ]

View File

@@ -27,7 +27,7 @@ class ComponentRegistry:
def __init__(self): def __init__(self):
# 组件注册表 # 组件注册表
self._components: Dict[str, ComponentInfo] = {} # 命名空间式组件名 -> 组件信息 self._components: Dict[str, ComponentInfo] = {} # 命名空间式组件名 -> 组件信息
# 类型 -> 命名空间式名称 -> 组件信息 # 类型 -> 组件原名称 -> 组件信息
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType} self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
# 命名空间式组件名 -> 组件类 # 命名空间式组件名 -> 组件类
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {} self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
@@ -110,11 +110,17 @@ class ComponentRegistry:
# 根据组件类型进行特定注册(使用原始名称) # 根据组件类型进行特定注册(使用原始名称)
match component_type: match component_type:
case ComponentType.ACTION: case ComponentType.ACTION:
ret = self._register_action_component(component_info, component_class) # type: ignore assert isinstance(component_info, ActionInfo)
assert issubclass(component_class, BaseAction)
ret = self._register_action_component(component_info, component_class)
case ComponentType.COMMAND: case ComponentType.COMMAND:
ret = self._register_command_component(component_info, component_class) # type: ignore assert isinstance(component_info, CommandInfo)
assert issubclass(component_class, BaseCommand)
ret = self._register_command_component(component_info, component_class)
case ComponentType.EVENT_HANDLER: case ComponentType.EVENT_HANDLER:
ret = self._register_event_handler_component(component_info, component_class) # type: ignore assert isinstance(component_info, EventHandlerInfo)
assert issubclass(component_class, BaseEventHandler)
ret = self._register_event_handler_component(component_info, component_class)
case _: case _:
logger.warning(f"未知组件类型: {component_type}") logger.warning(f"未知组件类型: {component_type}")
@@ -160,7 +166,9 @@ class ComponentRegistry:
if pattern not in self._command_patterns: if pattern not in self._command_patterns:
self._command_patterns[pattern] = command_name self._command_patterns[pattern] = command_name
else: else:
logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令") logger.warning(
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
)
return True return True
@@ -176,6 +184,10 @@ class ComponentRegistry:
self._event_handler_registry[handler_name] = handler_class self._event_handler_registry[handler_name] = handler_class
if not handler_info.enabled:
logger.warning(f"EventHandler组件 {handler_name} 未启用")
return True # 未启用,但是也是注册成功
from .events_manager import events_manager # 延迟导入防止循环导入问题 from .events_manager import events_manager # 延迟导入防止循环导入问题
if events_manager.register_event_subscriber(handler_info, handler_class): if events_manager.register_event_subscriber(handler_info, handler_class):
@@ -185,6 +197,98 @@ class ComponentRegistry:
logger.error(f"注册事件处理器 {handler_name} 失败") logger.error(f"注册事件处理器 {handler_name} 失败")
return False return False
# === 组件移除相关 ===
async def remove_component(self, component_name: str, component_type: ComponentType):
target_component_class = self.get_component_class(component_name, component_type)
if not target_component_class:
logger.warning(f"组件 {component_name} 未注册,无法移除")
return
match component_type:
case ComponentType.ACTION:
self._action_registry.pop(component_name, None)
self._default_actions.pop(component_name, None)
case ComponentType.COMMAND:
self._command_registry.pop(component_name, None)
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
for key in keys_to_remove:
self._command_patterns.pop(key, None)
case ComponentType.EVENT_HANDLER:
from .events_manager import events_manager # 延迟导入防止循环导入问题
self._event_handler_registry.pop(component_name, None)
self._enabled_event_handlers.pop(component_name, None)
await events_manager.unregister_event_subscriber(component_name)
self._components.pop(component_name, None)
self._components_by_type[component_type].pop(component_name, None)
self._components_classes.pop(component_name, None)
logger.info(f"组件 {component_name} 已移除")
# === 组件全局启用/禁用方法 ===
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
"""全局的启用某个组件
Parameters:
component_name: 组件名称
component_type: 组件类型
Returns:
bool: 启用成功返回True失败返回False
"""
target_component_class = self.get_component_class(component_name, component_type)
target_component_info = self.get_component_info(component_name, component_type)
if not target_component_class or not target_component_info:
logger.warning(f"组件 {component_name} 未注册,无法启用")
return False
target_component_info.enabled = True
match component_type:
case ComponentType.ACTION:
assert isinstance(target_component_info, ActionInfo)
self._default_actions[component_name] = target_component_info
case ComponentType.COMMAND:
assert isinstance(target_component_info, CommandInfo)
pattern = target_component_info.command_pattern
self._command_patterns[re.compile(pattern)] = component_name
case ComponentType.EVENT_HANDLER:
assert isinstance(target_component_info, EventHandlerInfo)
assert issubclass(target_component_class, BaseEventHandler)
self._enabled_event_handlers[component_name] = target_component_class
from .events_manager import events_manager # 延迟导入防止循环导入问题
events_manager.register_event_subscriber(target_component_info, target_component_class)
self._components[component_name].enabled = True
self._components_by_type[component_type][component_name].enabled = True
logger.info(f"组件 {component_name} 已启用")
return True
async def disable_component(self, component_name: str, component_type: ComponentType) -> bool:
"""全局的禁用某个组件
Parameters:
component_name: 组件名称
component_type: 组件类型
Returns:
bool: 禁用成功返回True失败返回False
"""
target_component_class = self.get_component_class(component_name, component_type)
target_component_info = self.get_component_info(component_name, component_type)
if not target_component_class or not target_component_info:
logger.warning(f"组件 {component_name} 未注册,无法禁用")
return False
target_component_info.enabled = False
match component_type:
case ComponentType.ACTION:
self._default_actions.pop(component_name, None)
case ComponentType.COMMAND:
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
case ComponentType.EVENT_HANDLER:
self._enabled_event_handlers.pop(component_name, None)
from .events_manager import events_manager # 延迟导入防止循环导入问题
await events_manager.unregister_event_subscriber(component_name)
self._components[component_name].enabled = False
self._components_by_type[component_type][component_name].enabled = False
logger.info(f"组件 {component_name} 已禁用")
return True
# === 组件查询方法 === # === 组件查询方法 ===
def get_component_info( def get_component_info(
self, component_name: str, component_type: Optional[ComponentType] = None self, component_name: str, component_type: Optional[ComponentType] = None
@@ -287,7 +391,7 @@ class ComponentRegistry:
# === Action特定查询方法 === # === Action特定查询方法 ===
def get_action_registry(self) -> Dict[str, Type[BaseAction]]: def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
"""获取Action注册表(用于兼容现有系统)""" """获取Action注册表"""
return self._action_registry.copy() return self._action_registry.copy()
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]: def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
@@ -314,7 +418,7 @@ class ComponentRegistry:
"""获取Command模式注册表""" """获取Command模式注册表"""
return self._command_patterns.copy() return self._command_patterns.copy()
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]: def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
# sourcery skip: use-named-expression, use-next # sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令 """根据文本查找匹配的命令
@@ -335,8 +439,7 @@ class ComponentRegistry:
return ( return (
self._command_registry[command_name], self._command_registry[command_name],
candidates[0].match(text).groupdict(), # type: ignore candidates[0].match(text).groupdict(), # type: ignore
command_info.intercept_message, command_info,
command_info.plugin_name,
) )
# === 事件处理器特定查询方法 === # === 事件处理器特定查询方法 ===

View File

@@ -6,6 +6,7 @@ from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.base_events_handler import BaseEventHandler
from .global_announcement_manager import global_announcement_manager
logger = get_logger("events_manager") logger = get_logger("events_manager")
@@ -28,18 +29,16 @@ class EventsManager:
bool: 是否注册成功 bool: 是否注册成功
""" """
handler_name = handler_info.name handler_name = handler_info.name
plugin_name = getattr(handler_info, "plugin_name", "unknown")
namespace_name = f"{plugin_name}.{handler_name}" if handler_name in self._handler_mapping:
if namespace_name in self._handler_mapping: logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
return False return False
if not issubclass(handler_class, BaseEventHandler): if not issubclass(handler_class, BaseEventHandler):
logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类") logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类")
return False return False
self._handler_mapping[namespace_name] = handler_class self._handler_mapping[handler_name] = handler_class
return self._insert_event_handler(handler_class, handler_info) return self._insert_event_handler(handler_class, handler_info)
async def handle_mai_events( async def handle_mai_events(
@@ -55,6 +54,10 @@ class EventsManager:
continue_flag = True continue_flag = True
transformed_message = self._transform_event_message(message, llm_prompt, llm_response) transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
for handler in self._events_subscribers.get(event_type, []): for handler in self._events_subscribers.get(event_type, []):
if message.chat_stream and message.chat_stream.stream_id:
stream_id = message.chat_stream.stream_id
if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id):
continue
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {}) handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
if handler.intercept_message: if handler.intercept_message:
try: try:
@@ -71,7 +74,7 @@ class EventsManager:
try: try:
handler_task = asyncio.create_task(handler.execute(transformed_message)) handler_task = asyncio.create_task(handler.execute(transformed_message))
handler_task.add_done_callback(self._task_done_callback) handler_task.add_done_callback(self._task_done_callback)
handler_task.set_name(f"EventHandler-{handler.handler_name}-{event_type.name}") handler_task.set_name(f"{handler.plugin_name}-{handler.handler_name}")
self._handler_tasks[handler.handler_name].append(handler_task) self._handler_tasks[handler.handler_name].append(handler_task)
except Exception as e: except Exception as e:
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}") logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
@@ -91,7 +94,7 @@ class EventsManager:
return True return True
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool: def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool:
"""从事件类型列表中移除事件处理器""" """从事件类型列表中移除事件处理器"""
display_handler_name = handler_class.handler_name or handler_class.__name__ display_handler_name = handler_class.handler_name or handler_class.__name__
if handler_class.event_type == EventType.UNKNOWN: if handler_class.event_type == EventType.UNKNOWN:
@@ -190,5 +193,20 @@ class EventsManager:
finally: finally:
del self._handler_tasks[handler_name] del self._handler_tasks[handler_name]
async def unregister_event_subscriber(self, handler_name: str) -> bool:
"""取消注册事件处理器"""
if handler_name not in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
return False
await self.cancel_handler_tasks(handler_name)
handler_class = self._handler_mapping.pop(handler_name)
if not self._remove_event_handler_instance(handler_class):
return False
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
return True
events_manager = EventsManager() events_manager = EventsManager()

View File

@@ -0,0 +1,90 @@
from typing import List, Dict
from src.common.logger import get_logger
logger = get_logger("global_announcement_manager")
class GlobalAnnouncementManager:
def __init__(self) -> None:
# 用户禁用的动作chat_id -> [action_name]
self._user_disabled_actions: Dict[str, List[str]] = {}
# 用户禁用的命令chat_id -> [command_name]
self._user_disabled_commands: Dict[str, List[str]] = {}
# 用户禁用的事件处理器chat_id -> [handler_name]
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""禁用特定聊天的某个动作"""
if chat_id not in self._user_disabled_actions:
self._user_disabled_actions[chat_id] = []
if action_name in self._user_disabled_actions[chat_id]:
logger.warning(f"动作 {action_name} 已经被禁用")
return False
self._user_disabled_actions[chat_id].append(action_name)
return True
def enable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""启用特定聊天的某个动作"""
if chat_id in self._user_disabled_actions:
try:
self._user_disabled_actions[chat_id].remove(action_name)
return True
except ValueError:
return False
return False
def disable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
"""禁用特定聊天的某个命令"""
if chat_id not in self._user_disabled_commands:
self._user_disabled_commands[chat_id] = []
if command_name in self._user_disabled_commands[chat_id]:
logger.warning(f"命令 {command_name} 已经被禁用")
return False
self._user_disabled_commands[chat_id].append(command_name)
return True
def enable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
"""启用特定聊天的某个命令"""
if chat_id in self._user_disabled_commands:
try:
self._user_disabled_commands[chat_id].remove(command_name)
return True
except ValueError:
return False
return False
def disable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
"""禁用特定聊天的某个事件处理器"""
if chat_id not in self._user_disabled_event_handlers:
self._user_disabled_event_handlers[chat_id] = []
if handler_name in self._user_disabled_event_handlers[chat_id]:
logger.warning(f"事件处理器 {handler_name} 已经被禁用")
return False
self._user_disabled_event_handlers[chat_id].append(handler_name)
return True
def enable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
"""启用特定聊天的某个事件处理器"""
if chat_id in self._user_disabled_event_handlers:
try:
self._user_disabled_event_handlers[chat_id].remove(handler_name)
return True
except ValueError:
return False
return False
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有动作"""
return self._user_disabled_actions.get(chat_id, []).copy()
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有命令"""
return self._user_disabled_commands.get(chat_id, []).copy()
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有事件处理器"""
return self._user_disabled_event_handlers.get(chat_id, []).copy()
global_announcement_manager = GlobalAnnouncementManager()

View File

@@ -1,5 +1,4 @@
import os import os
import inspect
import traceback import traceback
from typing import Dict, List, Optional, Tuple, Type, Any from typing import Dict, List, Optional, Tuple, Type, Any
@@ -8,11 +7,11 @@ from pathlib import Path
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.dependency_manager import dependency_manager
from src.plugin_system.base.plugin_base import PluginBase from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.base.component_types import ComponentType, PluginInfo, PythonDependency from src.plugin_system.base.component_types import ComponentType, PluginInfo, PythonDependency
from src.plugin_system.utils.manifest_utils import VersionComparator from src.plugin_system.utils.manifest_utils import VersionComparator
from .component_registry import component_registry
from .dependency_manager import dependency_manager
logger = get_logger("plugin_manager") logger = get_logger("plugin_manager")
@@ -36,19 +35,7 @@ class PluginManager:
self._ensure_plugin_directories() self._ensure_plugin_directories()
logger.info("插件管理器初始化完成") logger.info("插件管理器初始化完成")
def _ensure_plugin_directories(self) -> None: # === 插件目录管理 ===
"""确保所有插件根目录存在,如果不存在则创建"""
default_directories = ["src/plugins/built_in", "plugins"]
for directory in default_directories:
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
logger.info(f"创建插件根目录: {directory}")
if directory not in self.plugin_directories:
self.plugin_directories.append(directory)
logger.debug(f"已添加插件根目录: {directory}")
else:
logger.warning(f"根目录不可重复加载: {directory}")
def add_plugin_directory(self, directory: str) -> bool: def add_plugin_directory(self, directory: str) -> bool:
"""添加插件目录""" """添加插件目录"""
@@ -63,6 +50,8 @@ class PluginManager:
logger.warning(f"插件目录不存在: {directory}") logger.warning(f"插件目录不存在: {directory}")
return False return False
# === 插件加载管理 ===
def load_all_plugins(self) -> Tuple[int, int]: def load_all_plugins(self) -> Tuple[int, int]:
"""加载所有插件 """加载所有插件
@@ -86,7 +75,7 @@ class PluginManager:
total_failed_registration = 0 total_failed_registration = 0
for plugin_name in self.plugin_classes.keys(): for plugin_name in self.plugin_classes.keys():
load_status, count = self.load_registered_plugin_classes(plugin_name) load_status, count = self._load_registered_plugin_classes(plugin_name)
if load_status: if load_status:
total_registered += 1 total_registered += 1
else: else:
@@ -96,90 +85,32 @@ class PluginManager:
return total_registered, total_failed_registration return total_registered, total_failed_registration
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]: async def remove_registered_plugin(self, plugin_name: str) -> None:
# sourcery skip: extract-duplicate-method, extract-method
""" """
加载已经注册的插件类 禁用插件模块
""" """
plugin_class = self.plugin_classes.get(plugin_name) if not plugin_name:
if not plugin_class: raise ValueError("插件名称不能为空")
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在") if plugin_name not in self.loaded_plugins:
return False, 1 logger.warning(f"插件 {plugin_name} 未加载")
try: return
# 使用记录的插件目录路径 plugin_instance = self.loaded_plugins[plugin_name]
plugin_dir = self.plugin_paths.get(plugin_name) plugin_info = plugin_instance.plugin_info
for component in plugin_info.components:
await component_registry.remove_component(component.name, component.component_type)
del self.loaded_plugins[plugin_name]
# 如果没有记录,直接返回失败 async def reload_registered_plugin_module(self, plugin_name: str) -> None:
if not plugin_dir:
return False, 1
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件可能因为缺少manifest而失败
if not plugin_instance:
logger.error(f"插件 {plugin_name} 实例化失败")
return False, 1
# 检查插件是否启用
if not plugin_instance.enable_plugin:
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
return False, 0
# 检查版本兼容性
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
plugin_name, plugin_instance.manifest_data
)
if not is_compatible:
self.failed_plugins[plugin_name] = compatibility_error
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
return False, 1
if plugin_instance.register_plugin():
self.loaded_plugins[plugin_name] = plugin_instance
self._show_plugin_components(plugin_name)
return True, 1
else:
self.failed_plugins[plugin_name] = "插件注册失败"
logger.error(f"❌ 插件注册失败: {plugin_name}")
return False, 1
except FileNotFoundError as e:
# manifest文件缺失
error_msg = f"缺少manifest文件: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except ValueError as e:
# manifest文件格式错误或验证失败
traceback.print_exc()
error_msg = f"manifest验证失败: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except Exception as e:
# 其他错误
error_msg = f"未知错误: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
logger.debug("详细错误信息: ", exc_info=True)
return False, 1
def unload_registered_plugin_module(self, plugin_name: str) -> None:
"""
卸载插件模块
"""
pass
def reload_registered_plugin_module(self, plugin_name: str) -> None:
""" """
重载插件模块 重载插件模块
""" """
self.unload_registered_plugin_module(plugin_name) await self.remove_registered_plugin(plugin_name)
self.load_registered_plugin_classes(plugin_name) self._load_registered_plugin_classes(plugin_name)
def rescan_plugin_directory(self) -> None: def rescan_plugin_directory(self) -> None:
""" """
重新扫描插件根目录 重新扫描插件根目录
""" """
# --------------------------------------- NEED REFACTORING ---------------------------------------
for directory in self.plugin_directories: for directory in self.plugin_directories:
if os.path.exists(directory): if os.path.exists(directory):
logger.debug(f"重新扫描插件根目录: {directory}") logger.debug(f"重新扫描插件根目录: {directory}")
@@ -195,30 +126,6 @@ class PluginManager:
"""获取所有启用的插件信息""" """获取所有启用的插件信息"""
return list(component_registry.get_enabled_plugins().values()) return list(component_registry.get_enabled_plugins().values())
# def enable_plugin(self, plugin_name: str) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# """启用插件"""
# if plugin_info := component_registry.get_plugin_info(plugin_name):
# plugin_info.enabled = True
# # 启用插件的所有组件
# for component in plugin_info.components:
# component_registry.enable_component(component.name)
# logger.debug(f"已启用插件: {plugin_name}")
# return True
# return False
# def disable_plugin(self, plugin_name: str) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# """禁用插件"""
# if plugin_info := component_registry.get_plugin_info(plugin_name):
# plugin_info.enabled = False
# # 禁用插件的所有组件
# for component in plugin_info.components:
# component_registry.disable_component(component.name)
# logger.debug(f"已禁用插件: {plugin_name}")
# return True
# return False
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]: def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
"""获取插件实例 """获取插件实例
@@ -230,25 +137,6 @@ class PluginManager:
""" """
return self.loaded_plugins.get(plugin_name) return self.loaded_plugins.get(plugin_name)
def get_plugin_stats(self) -> Dict[str, Any]:
"""获取插件统计信息"""
all_plugins = component_registry.get_all_plugins()
enabled_plugins = component_registry.get_enabled_plugins()
action_components = component_registry.get_components_by_type(ComponentType.ACTION)
command_components = component_registry.get_components_by_type(ComponentType.COMMAND)
return {
"total_plugins": len(all_plugins),
"enabled_plugins": len(enabled_plugins),
"failed_plugins": len(self.failed_plugins),
"total_components": len(action_components) + len(command_components),
"action_components": len(action_components),
"command_components": len(command_components),
"loaded_plugin_files": len(self.loaded_plugins),
"failed_plugin_details": self.failed_plugins.copy(),
}
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]: def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]:
"""检查所有插件的Python依赖包 """检查所有插件的Python依赖包
@@ -347,6 +235,24 @@ class PluginManager:
return dependency_manager.generate_requirements_file(all_dependencies, output_path) return dependency_manager.generate_requirements_file(all_dependencies, output_path)
# === 私有方法 ===
# == 目录管理 ==
def _ensure_plugin_directories(self) -> None:
"""确保所有插件根目录存在,如果不存在则创建"""
default_directories = ["src/plugins/built_in", "plugins"]
for directory in default_directories:
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
logger.info(f"创建插件根目录: {directory}")
if directory not in self.plugin_directories:
self.plugin_directories.append(directory)
logger.debug(f"已添加插件根目录: {directory}")
else:
logger.warning(f"根目录不可重复加载: {directory}")
# == 插件加载 ==
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]: def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
"""从指定目录加载插件模块""" """从指定目录加载插件模块"""
loaded_count = 0 loaded_count = 0
@@ -372,18 +278,6 @@ class PluginManager:
return loaded_count, failed_count return loaded_count, failed_count
def _find_plugin_directory(self, plugin_class: Type[PluginBase]) -> Optional[str]:
"""查找插件类对应的目录路径"""
try:
# module = getmodule(plugin_class)
# if module and hasattr(module, "__file__") and module.__file__:
# return os.path.dirname(module.__file__)
file_path = inspect.getfile(plugin_class)
return os.path.dirname(file_path)
except Exception as e:
logger.debug(f"通过inspect获取插件目录失败: {e}")
return None
def _load_plugin_module_file(self, plugin_file: str) -> bool: def _load_plugin_module_file(self, plugin_file: str) -> bool:
# sourcery skip: extract-method # sourcery skip: extract-method
"""加载单个插件模块文件 """加载单个插件模块文件
@@ -416,6 +310,74 @@ class PluginManager:
self.failed_plugins[module_name] = error_msg self.failed_plugins[module_name] = error_msg
return False return False
def _load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
# sourcery skip: extract-duplicate-method, extract-method
"""
加载已经注册的插件类
"""
plugin_class = self.plugin_classes.get(plugin_name)
if not plugin_class:
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
return False, 1
try:
# 使用记录的插件目录路径
plugin_dir = self.plugin_paths.get(plugin_name)
# 如果没有记录,直接返回失败
if not plugin_dir:
return False, 1
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件可能因为缺少manifest而失败
if not plugin_instance:
logger.error(f"插件 {plugin_name} 实例化失败")
return False, 1
# 检查插件是否启用
if not plugin_instance.enable_plugin:
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
return False, 0
# 检查版本兼容性
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
plugin_name, plugin_instance.manifest_data
)
if not is_compatible:
self.failed_plugins[plugin_name] = compatibility_error
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
return False, 1
if plugin_instance.register_plugin():
self.loaded_plugins[plugin_name] = plugin_instance
self._show_plugin_components(plugin_name)
return True, 1
else:
self.failed_plugins[plugin_name] = "插件注册失败"
logger.error(f"❌ 插件注册失败: {plugin_name}")
return False, 1
except FileNotFoundError as e:
# manifest文件缺失
error_msg = f"缺少manifest文件: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except ValueError as e:
# manifest文件格式错误或验证失败
traceback.print_exc()
error_msg = f"manifest验证失败: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except Exception as e:
# 其他错误
error_msg = f"未知错误: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
logger.debug("详细错误信息: ", exc_info=True)
return False, 1
# == 兼容性检查 ==
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
"""检查插件版本兼容性 """检查插件版本兼容性
@@ -451,6 +413,8 @@ class PluginManager:
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}") logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载 return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
# == 显示统计与插件信息 ==
def _show_stats(self, total_registered: int, total_failed_registration: int): def _show_stats(self, total_registered: int, total_failed_registration: int):
# sourcery skip: low-code-quality # sourcery skip: low-code-quality
# 获取组件统计信息 # 获取组件统计信息
@@ -493,9 +457,15 @@ class PluginManager:
# 组件列表 # 组件列表
if plugin_info.components: if plugin_info.components:
action_components = [c for c in plugin_info.components if c.component_type == ComponentType.ACTION] action_components = [
command_components = [c for c in plugin_info.components if c.component_type == ComponentType.COMMAND] c for c in plugin_info.components if c.component_type == ComponentType.ACTION
event_handler_components = [c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER] ]
command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
]
event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
]
if action_components: if action_components:
action_names = [c.name for c in action_components] action_names = [c.name for c in action_components]