Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
14
changes.md
14
changes.md
@@ -45,10 +45,19 @@
|
|||||||
10. 修正了`main.py`中的错误输出。
|
10. 修正了`main.py`中的错误输出。
|
||||||
11. 修正了`command`所编译的`Pattern`注册时的错误输出。
|
11. 修正了`command`所编译的`Pattern`注册时的错误输出。
|
||||||
12. `events_manager`有了task相关逻辑了。
|
12. `events_manager`有了task相关逻辑了。
|
||||||
|
13. 现在有了插件卸载和重载功能了,也就是热插拔。
|
||||||
|
14. 实现了组件的全局启用和禁用功能。
|
||||||
|
- 通过`enable_component`和`disable_component`方法来启用或禁用组件。
|
||||||
|
- 不过这个操作不会保存到配置文件~
|
||||||
|
15. 实现了组件的局部禁用,也就是针对某一个聊天禁用的功能。
|
||||||
|
- 通过`disable_specific_chat_action`,`enable_specific_chat_action`,`disable_specific_chat_command`,`enable_specific_chat_command`,`disable_specific_chat_event_handler`,`enable_specific_chat_event_handler`来操作
|
||||||
|
- 同样不保存到配置文件~
|
||||||
|
|
||||||
### TODO
|
### TODO
|
||||||
把这个看起来就很别扭的config获取方式改一下
|
把这个看起来就很别扭的config获取方式改一下
|
||||||
|
|
||||||
|
来个API管理这些启用禁用!
|
||||||
|
|
||||||
|
|
||||||
# 吐槽
|
# 吐槽
|
||||||
```python
|
```python
|
||||||
@@ -64,4 +73,7 @@ else:
|
|||||||
plugin_path = Path(plugin_file)
|
plugin_path = Path(plugin_file)
|
||||||
module_name = ".".join(plugin_path.parent.parts)
|
module_name = ".".join(plugin_path.parent.parts)
|
||||||
```
|
```
|
||||||
这两个区别很大的。
|
这两个区别很大的。
|
||||||
|
|
||||||
|
### 执笔BGM
|
||||||
|
塞壬唱片!
|
||||||
@@ -51,6 +51,8 @@ NO_ACTION = {
|
|||||||
"action_prompt": "",
|
"action_prompt": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
IS_MAI4U = False
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
|
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
|
||||||
@@ -256,31 +258,29 @@ class HeartFChatting:
|
|||||||
)
|
)
|
||||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||||
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
||||||
|
|
||||||
async def send_typing(self):
|
async def send_typing(self):
|
||||||
group_info = GroupInfo(platform = "amaidesu_default",group_id = 114514,group_name = "内心")
|
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
platform = "amaidesu_default",
|
platform="amaidesu_default",
|
||||||
user_info = None,
|
user_info=None,
|
||||||
group_info = group_info
|
group_info=group_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
await send_api.custom_to_stream(
|
await send_api.custom_to_stream(
|
||||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
||||||
)
|
)
|
||||||
|
|
||||||
async def stop_typing(self):
|
async def stop_typing(self):
|
||||||
group_info = GroupInfo(platform = "amaidesu_default",group_id = 114514,group_name = "内心")
|
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
platform = "amaidesu_default",
|
platform="amaidesu_default",
|
||||||
user_info = None,
|
user_info=None,
|
||||||
group_info = group_info
|
group_info=group_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
await send_api.custom_to_stream(
|
await send_api.custom_to_stream(
|
||||||
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
||||||
)
|
)
|
||||||
@@ -373,7 +373,6 @@ class HeartFChatting:
|
|||||||
if ENABLE_THINKING:
|
if ENABLE_THINKING:
|
||||||
await self.stop_typing()
|
await self.stop_typing()
|
||||||
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
|
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
|
||||||
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -505,10 +504,9 @@ class HeartFChatting:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
interested_rate = (message_data.get("interest_value") or 0.0) * self.willing_amplifier
|
interested_rate = (message_data.get("interest_value") or 0.0) * self.willing_amplifier
|
||||||
|
|
||||||
self.willing_manager.setup(message_data, self.chat_stream)
|
self.willing_manager.setup(message_data, self.chat_stream)
|
||||||
|
|
||||||
|
|
||||||
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", ""))
|
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", ""))
|
||||||
|
|
||||||
talk_frequency = -1.00
|
talk_frequency = -1.00
|
||||||
@@ -518,7 +516,7 @@ class HeartFChatting:
|
|||||||
if additional_config and "maimcore_reply_probability_gain" in additional_config:
|
if additional_config and "maimcore_reply_probability_gain" in additional_config:
|
||||||
reply_probability += additional_config["maimcore_reply_probability_gain"]
|
reply_probability += additional_config["maimcore_reply_probability_gain"]
|
||||||
reply_probability = min(max(reply_probability, 0), 1) # 确保概率在 0-1 之间
|
reply_probability = min(max(reply_probability, 0), 1) # 确保概率在 0-1 之间
|
||||||
|
|
||||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
|
talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
|
||||||
reply_probability = talk_frequency * reply_probability
|
reply_probability = talk_frequency * reply_probability
|
||||||
|
|
||||||
@@ -528,9 +526,9 @@ class HeartFChatting:
|
|||||||
|
|
||||||
# 打印消息信息
|
# 打印消息信息
|
||||||
mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊"
|
mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊"
|
||||||
|
|
||||||
# logger.info(f"[{mes_name}] 当前聊天频率: {talk_frequency:.2f},兴趣值: {interested_rate:.2f},回复概率: {reply_probability * 100:.1f}%")
|
# logger.info(f"[{mes_name}] 当前聊天频率: {talk_frequency:.2f},兴趣值: {interested_rate:.2f},回复概率: {reply_probability * 100:.1f}%")
|
||||||
|
|
||||||
if reply_probability > 0.05:
|
if reply_probability > 0.05:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{mes_name}]"
|
f"[{mes_name}]"
|
||||||
@@ -546,7 +544,6 @@ class HeartFChatting:
|
|||||||
# 意愿管理器:注销当前message信息 (无论是否回复,只要处理过就删除)
|
# 意愿管理器:注销当前message信息 (无论是否回复,只要处理过就删除)
|
||||||
self.willing_manager.delete(message_data.get("message_id", ""))
|
self.willing_manager.delete(message_data.get("message_id", ""))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def _generate_response(
|
async def _generate_response(
|
||||||
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
|
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
|
||||||
@@ -571,7 +568,7 @@ class HeartFChatting:
|
|||||||
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
|
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _send_response(self, reply_set, reply_to, thinking_start_time,message_data):
|
async def _send_response(self, reply_set, reply_to, thinking_start_time, message_data):
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
new_message_count = message_api.count_new_messages(
|
new_message_count = message_api.count_new_messages(
|
||||||
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
|
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
|
||||||
@@ -593,13 +590,27 @@ class HeartFChatting:
|
|||||||
if not first_replied:
|
if not first_replied:
|
||||||
if need_reply:
|
if need_reply:
|
||||||
await send_api.text_to_stream(
|
await send_api.text_to_stream(
|
||||||
text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
|
text=data,
|
||||||
|
stream_id=self.chat_stream.stream_id,
|
||||||
|
reply_to=reply_to,
|
||||||
|
reply_to_platform_id=reply_to_platform_id,
|
||||||
|
typing=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=False)
|
await send_api.text_to_stream(
|
||||||
|
text=data,
|
||||||
|
stream_id=self.chat_stream.stream_id,
|
||||||
|
reply_to_platform_id=reply_to_platform_id,
|
||||||
|
typing=False,
|
||||||
|
)
|
||||||
first_replied = True
|
first_replied = True
|
||||||
else:
|
else:
|
||||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=True)
|
await send_api.text_to_stream(
|
||||||
|
text=data,
|
||||||
|
stream_id=self.chat_stream.stream_id,
|
||||||
|
reply_to_platform_id=reply_to_platform_id,
|
||||||
|
typing=True,
|
||||||
|
)
|
||||||
reply_text += data
|
reply_text += data
|
||||||
|
|
||||||
return reply_text
|
return reply_text
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
|||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.plugin_system.core import component_registry, events_manager # 导入新插件系统
|
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||||
from src.plugin_system.base import BaseCommand, EventType
|
from src.plugin_system.base import BaseCommand, EventType
|
||||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||||
|
|
||||||
@@ -91,8 +91,20 @@ class ChatBot:
|
|||||||
# 使用新的组件注册中心查找命令
|
# 使用新的组件注册中心查找命令
|
||||||
command_result = component_registry.find_command_by_text(text)
|
command_result = component_registry.find_command_by_text(text)
|
||||||
if command_result:
|
if command_result:
|
||||||
|
command_class, matched_groups, command_info = command_result
|
||||||
|
intercept_message = command_info.intercept_message
|
||||||
|
plugin_name = command_info.plugin_name
|
||||||
|
command_name = command_info.name
|
||||||
|
if (
|
||||||
|
message.chat_stream
|
||||||
|
and message.chat_stream.stream_id
|
||||||
|
and command_name
|
||||||
|
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
|
||||||
|
):
|
||||||
|
logger.info("用户禁用的命令,跳过处理")
|
||||||
|
return False, None, True
|
||||||
|
|
||||||
message.is_command = True
|
message.is_command = True
|
||||||
command_class, matched_groups, intercept_message, plugin_name = command_result
|
|
||||||
|
|
||||||
# 获取插件配置
|
# 获取插件配置
|
||||||
plugin_config = component_registry.get_plugin_config(plugin_name)
|
plugin_config = component_registry.get_plugin_config(plugin_name)
|
||||||
@@ -134,13 +146,12 @@ class ChatBot:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理命令时出错: {e}")
|
logger.error(f"处理命令时出错: {e}")
|
||||||
return False, None, True # 出错时继续处理消息
|
return False, None, True # 出错时继续处理消息
|
||||||
|
|
||||||
async def hanle_notice_message(self, message: MessageRecv):
|
async def hanle_notice_message(self, message: MessageRecv):
|
||||||
if message.message_info.message_id == "notice":
|
if message.message_info.message_id == "notice":
|
||||||
logger.info("收到notice消息,暂时不支持处理")
|
logger.info("收到notice消息,暂时不支持处理")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||||
message = MessageRecvS4U(message_data)
|
message = MessageRecvS4U(message_data)
|
||||||
group_info = message.message_info.group_info
|
group_info = message.message_info.group_info
|
||||||
@@ -162,7 +173,6 @@ class ChatBot:
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||||
"""处理转化后的统一格式消息
|
"""处理转化后的统一格式消息
|
||||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||||
@@ -178,8 +188,6 @@ class ChatBot:
|
|||||||
- 性能计时
|
- 性能计时
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
|
||||||
# 确保所有任务已启动
|
# 确保所有任务已启动
|
||||||
await self._ensure_started()
|
await self._ensure_started()
|
||||||
|
|
||||||
@@ -200,11 +208,10 @@ class ChatBot:
|
|||||||
# print(message_data)
|
# print(message_data)
|
||||||
# logger.debug(str(message_data))
|
# logger.debug(str(message_data))
|
||||||
message = MessageRecv(message_data)
|
message = MessageRecv(message_data)
|
||||||
|
|
||||||
if await self.hanle_notice_message(message):
|
if await self.hanle_notice_message(message):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
group_info = message.message_info.group_info
|
group_info = message.message_info.group_info
|
||||||
user_info = message.message_info.user_info
|
user_info = message.message_info.user_info
|
||||||
if message.message_info.additional_config:
|
if message.message_info.additional_config:
|
||||||
@@ -228,11 +235,10 @@ class ChatBot:
|
|||||||
|
|
||||||
# 处理消息内容,生成纯文本
|
# 处理消息内容,生成纯文本
|
||||||
await message.process()
|
await message.process()
|
||||||
|
|
||||||
# if await self.check_ban_content(message):
|
# if await self.check_ban_content(message):
|
||||||
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
|
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
|
||||||
# return
|
# return
|
||||||
|
|
||||||
|
|
||||||
# 过滤检查
|
# 过滤检查
|
||||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||||
|
|||||||
@@ -163,20 +163,25 @@ class ChatManager:
|
|||||||
"""注册消息到聊天流"""
|
"""注册消息到聊天流"""
|
||||||
stream_id = self._generate_stream_id(
|
stream_id = self._generate_stream_id(
|
||||||
message.message_info.platform, # type: ignore
|
message.message_info.platform, # type: ignore
|
||||||
message.message_info.user_info, # type: ignore
|
message.message_info.user_info,
|
||||||
message.message_info.group_info,
|
message.message_info.group_info,
|
||||||
)
|
)
|
||||||
self.last_messages[stream_id] = message
|
self.last_messages[stream_id] = message
|
||||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
def _generate_stream_id(
|
||||||
|
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
|
||||||
|
) -> str:
|
||||||
"""生成聊天流唯一ID"""
|
"""生成聊天流唯一ID"""
|
||||||
|
if not user_info and not group_info:
|
||||||
|
raise ValueError("用户信息或群组信息必须提供")
|
||||||
|
|
||||||
if group_info:
|
if group_info:
|
||||||
# 组合关键信息
|
# 组合关键信息
|
||||||
components = [platform, str(group_info.group_id)]
|
components = [platform, str(group_info.group_id)]
|
||||||
else:
|
else:
|
||||||
components = [platform, str(user_info.user_id), "private"]
|
components = [platform, str(user_info.user_id), "private"] # type: ignore
|
||||||
|
|
||||||
# 使用MD5生成唯一ID
|
# 使用MD5生成唯一ID
|
||||||
key = "_".join(components)
|
key = "_".join(components)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Optional, Type
|
from typing import Dict, Optional, Type
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
from src.plugin_system.base.base_action import BaseAction
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -22,53 +22,14 @@ class ActionManager:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化动作管理器"""
|
"""初始化动作管理器"""
|
||||||
# 所有注册的动作集合
|
|
||||||
self._registered_actions: Dict[str, ActionInfo] = {}
|
|
||||||
# 当前正在使用的动作集合,默认加载默认动作
|
# 当前正在使用的动作集合,默认加载默认动作
|
||||||
self._using_actions: Dict[str, ActionInfo] = {}
|
self._using_actions: Dict[str, ActionInfo] = {}
|
||||||
|
|
||||||
# 加载插件动作
|
|
||||||
self._load_plugin_actions()
|
|
||||||
|
|
||||||
# 初始化时将默认动作加载到使用中的动作
|
# 初始化时将默认动作加载到使用中的动作
|
||||||
self._using_actions = component_registry.get_default_actions()
|
self._using_actions = component_registry.get_default_actions()
|
||||||
|
|
||||||
def _load_plugin_actions(self) -> None:
|
# === 执行Action方法 ===
|
||||||
"""
|
|
||||||
加载所有插件系统中的动作
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 从新插件系统获取Action组件
|
|
||||||
self._load_plugin_system_actions()
|
|
||||||
logger.debug("从插件系统加载Action组件成功")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"加载插件动作失败: {e}")
|
|
||||||
|
|
||||||
def _load_plugin_system_actions(self) -> None:
|
|
||||||
"""从插件系统的component_registry加载Action组件"""
|
|
||||||
try:
|
|
||||||
# 获取所有Action组件
|
|
||||||
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore
|
|
||||||
|
|
||||||
for action_name, action_info in action_components.items():
|
|
||||||
if action_name in self._registered_actions:
|
|
||||||
logger.debug(f"Action组件 {action_name} 已存在,跳过")
|
|
||||||
continue
|
|
||||||
|
|
||||||
self._registered_actions[action_name] = action_info
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"加载了 {len(action_components)} 个Action动作")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"从插件系统加载Action组件失败: {e}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
def create_action(
|
def create_action(
|
||||||
self,
|
self,
|
||||||
@@ -139,36 +100,11 @@ class ActionManager:
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_registered_actions(self) -> Dict[str, ActionInfo]:
|
|
||||||
"""获取所有已注册的动作集"""
|
|
||||||
return self._registered_actions.copy()
|
|
||||||
|
|
||||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||||
"""获取当前正在使用的动作集合"""
|
"""获取当前正在使用的动作集合"""
|
||||||
return self._using_actions.copy()
|
return self._using_actions.copy()
|
||||||
|
|
||||||
def add_action_to_using(self, action_name: str) -> bool:
|
# === Modify相关方法 ===
|
||||||
"""
|
|
||||||
添加已注册的动作到当前使用的动作集
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_name: 动作名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 添加是否成功
|
|
||||||
"""
|
|
||||||
if action_name not in self._registered_actions:
|
|
||||||
logger.warning(f"添加失败: 动作 {action_name} 未注册")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if action_name in self._using_actions:
|
|
||||||
logger.info(f"动作 {action_name} 已经在使用中")
|
|
||||||
return True
|
|
||||||
|
|
||||||
self._using_actions[action_name] = self._registered_actions[action_name]
|
|
||||||
logger.info(f"添加动作 {action_name} 到使用集")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def remove_action_from_using(self, action_name: str) -> bool:
|
def remove_action_from_using(self, action_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
从当前使用的动作集中移除指定动作
|
从当前使用的动作集中移除指定动作
|
||||||
@@ -187,79 +123,8 @@ class ActionManager:
|
|||||||
logger.debug(f"已从使用集中移除动作 {action_name}")
|
logger.debug(f"已从使用集中移除动作 {action_name}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
|
|
||||||
# """
|
|
||||||
# 添加新的动作到注册集
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# action_name: 动作名称
|
|
||||||
# description: 动作描述
|
|
||||||
# parameters: 动作参数定义,默认为空字典
|
|
||||||
# require: 动作依赖项,默认为空列表
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# bool: 添加是否成功
|
|
||||||
# """
|
|
||||||
# if action_name in self._registered_actions:
|
|
||||||
# return False
|
|
||||||
|
|
||||||
# if parameters is None:
|
|
||||||
# parameters = {}
|
|
||||||
# if require is None:
|
|
||||||
# require = []
|
|
||||||
|
|
||||||
# action_info = {"description": description, "parameters": parameters, "require": require}
|
|
||||||
|
|
||||||
# self._registered_actions[action_name] = action_info
|
|
||||||
# return True
|
|
||||||
|
|
||||||
def remove_action(self, action_name: str) -> bool:
|
|
||||||
"""从注册集移除指定动作"""
|
|
||||||
if action_name not in self._registered_actions:
|
|
||||||
return False
|
|
||||||
del self._registered_actions[action_name]
|
|
||||||
# 如果在使用集中也存在,一并移除
|
|
||||||
if action_name in self._using_actions:
|
|
||||||
del self._using_actions[action_name]
|
|
||||||
return True
|
|
||||||
|
|
||||||
def temporarily_remove_actions(self, actions_to_remove: List[str]) -> None:
|
|
||||||
"""临时移除使用集中的指定动作"""
|
|
||||||
for name in actions_to_remove:
|
|
||||||
self._using_actions.pop(name, None)
|
|
||||||
|
|
||||||
def restore_actions(self) -> None:
|
def restore_actions(self) -> None:
|
||||||
"""恢复到默认动作集"""
|
"""恢复到默认动作集"""
|
||||||
actions_to_restore = list(self._using_actions.keys())
|
actions_to_restore = list(self._using_actions.keys())
|
||||||
self._using_actions = component_registry.get_default_actions()
|
self._using_actions = component_registry.get_default_actions()
|
||||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||||
|
|
||||||
def add_system_action_if_needed(self, action_name: str) -> bool:
|
|
||||||
"""
|
|
||||||
根据需要添加系统动作到使用集
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_name: 动作名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否成功添加
|
|
||||||
"""
|
|
||||||
if action_name in self._registered_actions and action_name not in self._using_actions:
|
|
||||||
self._using_actions[action_name] = self._registered_actions[action_name]
|
|
||||||
logger.info(f"临时添加系统动作到使用集: {action_name}")
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_action(self, action_name: str) -> Optional[Type[BaseAction]]:
|
|
||||||
"""
|
|
||||||
获取指定动作的处理器类
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_name: 动作名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None
|
|
||||||
"""
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
return component_registry.get_component_class(action_name, ComponentType.ACTION) # type: ignore
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import random
|
|||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
from typing import List, Any, Dict, TYPE_CHECKING
|
from typing import List, Any, Dict, TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
@@ -11,6 +11,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageCo
|
|||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||||
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
||||||
|
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
@@ -47,7 +48,6 @@ class ActionModifier:
|
|||||||
|
|
||||||
async def modify_actions(
|
async def modify_actions(
|
||||||
self,
|
self,
|
||||||
history_loop=None,
|
|
||||||
message_content: str = "",
|
message_content: str = "",
|
||||||
): # sourcery skip: use-named-expression
|
): # sourcery skip: use-named-expression
|
||||||
"""
|
"""
|
||||||
@@ -61,8 +61,9 @@ class ActionModifier:
|
|||||||
"""
|
"""
|
||||||
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
||||||
|
|
||||||
removals_s1 = []
|
removals_s1: List[Tuple[str, str]] = []
|
||||||
removals_s2 = []
|
removals_s2: List[Tuple[str, str]] = []
|
||||||
|
removals_s3: List[Tuple[str, str]] = []
|
||||||
|
|
||||||
self.action_manager.restore_actions()
|
self.action_manager.restore_actions()
|
||||||
all_actions = self.action_manager.get_using_actions()
|
all_actions = self.action_manager.get_using_actions()
|
||||||
@@ -84,25 +85,28 @@ class ActionModifier:
|
|||||||
if message_content:
|
if message_content:
|
||||||
chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}"
|
chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}"
|
||||||
|
|
||||||
# === 第一阶段:传统观察处理 ===
|
# === 第一阶段:去除用户自行禁用的 ===
|
||||||
# if history_loop:
|
disabled_actions = global_announcement_manager.get_disabled_chat_actions(self.chat_id)
|
||||||
# removals_from_loop = await self.analyze_loop_actions(history_loop)
|
if disabled_actions:
|
||||||
# if removals_from_loop:
|
for disabled_action_name in disabled_actions:
|
||||||
# removals_s1.extend(removals_from_loop)
|
if disabled_action_name in all_actions:
|
||||||
|
removals_s1.append((disabled_action_name, "用户自行禁用"))
|
||||||
|
self.action_manager.remove_action_from_using(disabled_action_name)
|
||||||
|
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
||||||
|
|
||||||
# 检查动作的关联类型
|
# === 第二阶段:检查动作的关联类型 ===
|
||||||
chat_context = self.chat_stream.context
|
chat_context = self.chat_stream.context
|
||||||
type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context)
|
type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context)
|
||||||
|
|
||||||
if type_mismatched_actions:
|
if type_mismatched_actions:
|
||||||
removals_s1.extend(type_mismatched_actions)
|
removals_s2.extend(type_mismatched_actions)
|
||||||
|
|
||||||
# 应用第一阶段的移除
|
# 应用第二阶段的移除
|
||||||
for action_name, reason in removals_s1:
|
for action_name, reason in removals_s2:
|
||||||
self.action_manager.remove_action_from_using(action_name)
|
self.action_manager.remove_action_from_using(action_name)
|
||||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {action_name},原因: {reason}")
|
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
||||||
|
|
||||||
# === 第二阶段:激活类型判定 ===
|
# === 第三阶段:激活类型判定 ===
|
||||||
if chat_content is not None:
|
if chat_content is not None:
|
||||||
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||||
|
|
||||||
@@ -110,18 +114,18 @@ class ActionModifier:
|
|||||||
current_using_actions = self.action_manager.get_using_actions()
|
current_using_actions = self.action_manager.get_using_actions()
|
||||||
|
|
||||||
# 获取因激活类型判定而需要移除的动作
|
# 获取因激活类型判定而需要移除的动作
|
||||||
removals_s2 = await self._get_deactivated_actions_by_type(
|
removals_s3 = await self._get_deactivated_actions_by_type(
|
||||||
current_using_actions,
|
current_using_actions,
|
||||||
chat_content,
|
chat_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 应用第二阶段的移除
|
# 应用第三阶段的移除
|
||||||
for action_name, reason in removals_s2:
|
for action_name, reason in removals_s3:
|
||||||
self.action_manager.remove_action_from_using(action_name)
|
self.action_manager.remove_action_from_using(action_name)
|
||||||
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||||
|
|
||||||
# === 统一日志记录 ===
|
# === 统一日志记录 ===
|
||||||
all_removals = removals_s1 + removals_s2
|
all_removals = removals_s1 + removals_s2 + removals_s3
|
||||||
removals_summary: str = ""
|
removals_summary: str = ""
|
||||||
if all_removals:
|
if all_removals:
|
||||||
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
|
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
|
||||||
@@ -131,7 +135,7 @@ class ActionModifier:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||||
type_mismatched_actions = []
|
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||||
for action_name, action_info in all_actions.items():
|
for action_name, action_info in all_actions.items():
|
||||||
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||||
associated_types_str = ", ".join(action_info.associated_types)
|
associated_types_str = ", ".join(action_info.associated_types)
|
||||||
@@ -318,7 +322,7 @@ class ActionModifier:
|
|||||||
action_name: str,
|
action_name: str,
|
||||||
action_info: ActionInfo,
|
action_info: ActionInfo,
|
||||||
chat_content: str = "",
|
chat_content: str = "",
|
||||||
) -> bool:
|
) -> bool: # sourcery skip: move-assign-in-block, use-named-expression
|
||||||
"""
|
"""
|
||||||
使用LLM判定是否应该激活某个action
|
使用LLM判定是否应该激活某个action
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Dict, Any, Optional, Tuple
|
from typing import Dict, Any, Optional, Tuple, List
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
@@ -19,8 +19,8 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ComponentType
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
logger = get_logger("planner")
|
logger = get_logger("planner")
|
||||||
|
|
||||||
@@ -99,7 +99,7 @@ class ActionPlanner:
|
|||||||
|
|
||||||
async def plan(
|
async def plan(
|
||||||
self, mode: ChatMode = ChatMode.FOCUS
|
self, mode: ChatMode = ChatMode.FOCUS
|
||||||
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]: # sourcery skip: dict-comprehension
|
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||||
"""
|
"""
|
||||||
@@ -113,16 +113,17 @@ class ActionPlanner:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
is_group_chat = True
|
is_group_chat = True
|
||||||
|
|
||||||
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||||
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
||||||
|
|
||||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||||
|
|
||||||
# 获取完整的动作信息
|
# 获取完整的动作信息
|
||||||
all_registered_actions = self.action_manager.get_registered_actions()
|
all_registered_actions: List[ActionInfo] = list(
|
||||||
|
component_registry.get_components_by_type(ComponentType.ACTION).values() # type: ignore
|
||||||
for action_name in current_available_actions_dict.keys():
|
)
|
||||||
|
current_available_actions = {}
|
||||||
|
for action_name in current_available_actions_dict:
|
||||||
if action_name in all_registered_actions:
|
if action_name in all_registered_actions:
|
||||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||||
else:
|
else:
|
||||||
@@ -234,10 +235,13 @@ class ActionPlanner:
|
|||||||
"is_parallel": is_parallel,
|
"is_parallel": is_parallel,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return (
|
||||||
"action_result": action_result,
|
{
|
||||||
"action_prompt": prompt,
|
"action_result": action_result,
|
||||||
}, target_message
|
"action_prompt": prompt,
|
||||||
|
},
|
||||||
|
target_message,
|
||||||
|
)
|
||||||
|
|
||||||
async def build_planner_prompt(
|
async def build_planner_prompt(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -619,9 +619,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
|||||||
chat_target_info = None
|
chat_target_info = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
if chat_stream := get_chat_manager().get_stream(chat_id):
|
||||||
|
|
||||||
if chat_stream:
|
|
||||||
if chat_stream.group_info:
|
if chat_stream.group_info:
|
||||||
is_group_chat = True
|
is_group_chat = True
|
||||||
chat_target_info = None # Explicitly None for group chat
|
chat_target_info = None # Explicitly None for group chat
|
||||||
@@ -660,8 +658,6 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
|||||||
chat_target_info = target_info
|
chat_target_info = target_info
|
||||||
else:
|
else:
|
||||||
logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
|
logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
|
||||||
# Keep defaults: is_group_chat=False, chat_target_info=None
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
|
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
|
||||||
# Keep defaults on error
|
# Keep defaults on error
|
||||||
|
|||||||
@@ -173,12 +173,10 @@ class Individuality:
|
|||||||
personality = short_impression[0]
|
personality = short_impression[0]
|
||||||
identity = short_impression[1]
|
identity = short_impression[1]
|
||||||
prompt_personality = f"{personality},{identity}"
|
prompt_personality = f"{personality},{identity}"
|
||||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||||
|
|
||||||
return identity_block
|
|
||||||
|
|
||||||
def _get_config_hash(
|
def _get_config_hash(
|
||||||
self, bot_nickname: str, personality_core: str, personality_side: str, identity: list
|
self, bot_nickname: str, personality_core: str, personality_side: str, identity: str
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
"""获取personality和identity配置的哈希值
|
"""获取personality和identity配置的哈希值
|
||||||
|
|
||||||
@@ -197,7 +195,7 @@ class Individuality:
|
|||||||
|
|
||||||
# 身份配置哈希
|
# 身份配置哈希
|
||||||
identity_config = {
|
identity_config = {
|
||||||
"identity": sorted(identity),
|
"identity": identity,
|
||||||
"compress_identity": self.personality.compress_identity if self.personality else True,
|
"compress_identity": self.personality.compress_identity if self.personality else True,
|
||||||
}
|
}
|
||||||
identity_str = json.dumps(identity_config, sort_keys=True)
|
identity_str = json.dumps(identity_config, sort_keys=True)
|
||||||
@@ -206,7 +204,7 @@ class Individuality:
|
|||||||
return personality_hash, identity_hash
|
return personality_hash, identity_hash
|
||||||
|
|
||||||
async def _check_config_and_clear_if_changed(
|
async def _check_config_and_clear_if_changed(
|
||||||
self, bot_nickname: str, personality_core: str, personality_side: str, identity: list
|
self, bot_nickname: str, personality_core: str, personality_side: str, identity: str
|
||||||
) -> tuple[bool, bool]:
|
) -> tuple[bool, bool]:
|
||||||
"""检查配置是否发生变化,如果变化则清空相应缓存
|
"""检查配置是否发生变化,如果变化则清空相应缓存
|
||||||
|
|
||||||
@@ -321,7 +319,7 @@ class Individuality:
|
|||||||
|
|
||||||
return personality_result
|
return personality_result
|
||||||
|
|
||||||
async def _create_identity(self, identity: list) -> str:
|
async def _create_identity(self, identity: str) -> str:
|
||||||
"""使用LLM创建压缩版本的impression"""
|
"""使用LLM创建压缩版本的impression"""
|
||||||
logger.info("正在构建身份.........")
|
logger.info("正在构建身份.........")
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -10,7 +9,7 @@ class Personality:
|
|||||||
bot_nickname: str # 机器人昵称
|
bot_nickname: str # 机器人昵称
|
||||||
personality_core: str # 人格核心特点
|
personality_core: str # 人格核心特点
|
||||||
personality_side: str # 人格侧面描述
|
personality_side: str # 人格侧面描述
|
||||||
identity: List[str] # 身份细节描述
|
identity: Optional[str] # 身份细节描述
|
||||||
compress_personality: bool # 是否压缩人格
|
compress_personality: bool # 是否压缩人格
|
||||||
compress_identity: bool # 是否压缩身份
|
compress_identity: bool # 是否压缩身份
|
||||||
|
|
||||||
@@ -21,7 +20,7 @@ class Personality:
|
|||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self, personality_core: str = "", personality_side: str = "", identity: List[str] = None):
|
def __init__(self, personality_core: str = "", personality_side: str = "", identity: Optional[str] = None):
|
||||||
self.personality_core = personality_core
|
self.personality_core = personality_core
|
||||||
self.personality_side = personality_side
|
self.personality_side = personality_side
|
||||||
self.identity = identity
|
self.identity = identity
|
||||||
@@ -45,7 +44,7 @@ class Personality:
|
|||||||
bot_nickname: str,
|
bot_nickname: str,
|
||||||
personality_core: str,
|
personality_core: str,
|
||||||
personality_side: str,
|
personality_side: str,
|
||||||
identity: List[str] = None,
|
identity: Optional[str] = None,
|
||||||
compress_personality: bool = True,
|
compress_personality: bool = True,
|
||||||
compress_identity: bool = True,
|
compress_identity: bool = True,
|
||||||
) -> "Personality":
|
) -> "Personality":
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from .core import (
|
|||||||
component_registry,
|
component_registry,
|
||||||
dependency_manager,
|
dependency_manager,
|
||||||
events_manager,
|
events_manager,
|
||||||
|
global_announcement_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 导入工具模块
|
# 导入工具模块
|
||||||
@@ -67,6 +68,7 @@ __all__ = [
|
|||||||
"component_registry",
|
"component_registry",
|
||||||
"dependency_manager",
|
"dependency_manager",
|
||||||
"events_manager",
|
"events_manager",
|
||||||
|
"global_announcement_manager",
|
||||||
# 装饰器
|
# 装饰器
|
||||||
"register_plugin",
|
"register_plugin",
|
||||||
"ConfigField",
|
"ConfigField",
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ def register_plugin(cls):
|
|||||||
if "." in plugin_name:
|
if "." in plugin_name:
|
||||||
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||||
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||||
plugin_manager.plugin_classes[plugin_name] = cls
|
|
||||||
splitted_name = cls.__module__.split(".")
|
splitted_name = cls.__module__.split(".")
|
||||||
root_path = Path(__file__)
|
root_path = Path(__file__)
|
||||||
|
|
||||||
@@ -40,6 +39,7 @@ def register_plugin(cls):
|
|||||||
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
|
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
plugin_manager.plugin_classes[plugin_name] = cls
|
||||||
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
|
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
|
||||||
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
|
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
|
||||||
|
|
||||||
|
|||||||
@@ -65,21 +65,28 @@ class BaseAction(ABC):
|
|||||||
self.thinking_id = thinking_id
|
self.thinking_id = thinking_id
|
||||||
self.log_prefix = log_prefix
|
self.log_prefix = log_prefix
|
||||||
|
|
||||||
# 保存插件配置
|
|
||||||
self.plugin_config = plugin_config or {}
|
self.plugin_config = plugin_config or {}
|
||||||
|
"""对应的插件配置"""
|
||||||
|
|
||||||
# 设置动作基本信息实例属性
|
# 设置动作基本信息实例属性
|
||||||
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
||||||
|
"""Action的名字"""
|
||||||
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
||||||
|
"""Action的描述"""
|
||||||
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
||||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||||||
|
|
||||||
# 设置激活类型实例属性(从类属性复制,提供默认值)
|
# 设置激活类型实例属性(从类属性复制,提供默认值)
|
||||||
self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS)
|
self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS)
|
||||||
|
"""FOCUS模式下的激活类型"""
|
||||||
self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS)
|
self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||||
|
"""NORMAL模式下的激活类型"""
|
||||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||||||
|
"""当激活类型为RANDOM时的概率"""
|
||||||
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
|
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
|
||||||
|
"""协助LLM进行判断的Prompt"""
|
||||||
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
||||||
|
"""激活类型为KEYWORD时的KEYWORDS列表"""
|
||||||
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
||||||
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
|
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
|
||||||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
||||||
|
|||||||
@@ -21,13 +21,18 @@ class BaseCommand(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
command_name: str = ""
|
command_name: str = ""
|
||||||
|
"""Command组件的名称"""
|
||||||
command_description: str = ""
|
command_description: str = ""
|
||||||
|
"""Command组件的描述"""
|
||||||
|
|
||||||
# 默认命令设置(子类可以覆盖)
|
# 默认命令设置(子类可以覆盖)
|
||||||
command_pattern: str = ""
|
command_pattern: str = ""
|
||||||
|
"""命令匹配的正则表达式"""
|
||||||
command_help: str = ""
|
command_help: str = ""
|
||||||
|
"""命令帮助信息"""
|
||||||
command_examples: List[str] = []
|
command_examples: List[str] = []
|
||||||
intercept_message: bool = True # 默认拦截消息,不继续处理
|
intercept_message: bool = True
|
||||||
|
"""是否拦截信息,默认拦截,不进行后续处理"""
|
||||||
|
|
||||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||||
"""初始化Command组件
|
"""初始化Command组件
|
||||||
|
|||||||
@@ -13,16 +13,23 @@ class BaseEventHandler(ABC):
|
|||||||
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
||||||
"""
|
"""
|
||||||
|
|
||||||
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知
|
event_type: EventType = EventType.UNKNOWN
|
||||||
handler_name: str = "" # 处理器名称
|
"""事件类型,默认为未知"""
|
||||||
|
handler_name: str = ""
|
||||||
|
"""处理器名称"""
|
||||||
handler_description: str = ""
|
handler_description: str = ""
|
||||||
weight: int = 0 # 权重,数值越大优先级越高
|
"""处理器描述"""
|
||||||
intercept_message: bool = False # 是否拦截消息,默认为否
|
weight: int = 0
|
||||||
|
"""处理器权重,越大权重越高"""
|
||||||
|
intercept_message: bool = False
|
||||||
|
"""是否拦截消息,默认为否"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.log_prefix = "[EventHandler]"
|
self.log_prefix = "[EventHandler]"
|
||||||
self.plugin_name = "" # 对应插件名
|
self.plugin_name = ""
|
||||||
self.plugin_config: Optional[Dict] = None # 插件配置字典
|
"""对应插件名"""
|
||||||
|
self.plugin_config: Optional[Dict] = None
|
||||||
|
"""插件配置字典"""
|
||||||
if self.event_type == EventType.UNKNOWN:
|
if self.event_type == EventType.UNKNOWN:
|
||||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
raise NotImplementedError("事件处理器必须指定 event_type")
|
||||||
|
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
|||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.plugin_system.core.events_manager import events_manager
|
||||||
|
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plugin_manager",
|
"plugin_manager",
|
||||||
"component_registry",
|
"component_registry",
|
||||||
"dependency_manager",
|
"dependency_manager",
|
||||||
"events_manager",
|
"events_manager",
|
||||||
|
"global_announcement_manager",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class ComponentRegistry:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 组件注册表
|
# 组件注册表
|
||||||
self._components: Dict[str, ComponentInfo] = {} # 命名空间式组件名 -> 组件信息
|
self._components: Dict[str, ComponentInfo] = {} # 命名空间式组件名 -> 组件信息
|
||||||
# 类型 -> 命名空间式名称 -> 组件信息
|
# 类型 -> 组件原名称 -> 组件信息
|
||||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
||||||
# 命名空间式组件名 -> 组件类
|
# 命名空间式组件名 -> 组件类
|
||||||
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
|
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
|
||||||
@@ -110,11 +110,17 @@ class ComponentRegistry:
|
|||||||
# 根据组件类型进行特定注册(使用原始名称)
|
# 根据组件类型进行特定注册(使用原始名称)
|
||||||
match component_type:
|
match component_type:
|
||||||
case ComponentType.ACTION:
|
case ComponentType.ACTION:
|
||||||
ret = self._register_action_component(component_info, component_class) # type: ignore
|
assert isinstance(component_info, ActionInfo)
|
||||||
|
assert issubclass(component_class, BaseAction)
|
||||||
|
ret = self._register_action_component(component_info, component_class)
|
||||||
case ComponentType.COMMAND:
|
case ComponentType.COMMAND:
|
||||||
ret = self._register_command_component(component_info, component_class) # type: ignore
|
assert isinstance(component_info, CommandInfo)
|
||||||
|
assert issubclass(component_class, BaseCommand)
|
||||||
|
ret = self._register_command_component(component_info, component_class)
|
||||||
case ComponentType.EVENT_HANDLER:
|
case ComponentType.EVENT_HANDLER:
|
||||||
ret = self._register_event_handler_component(component_info, component_class) # type: ignore
|
assert isinstance(component_info, EventHandlerInfo)
|
||||||
|
assert issubclass(component_class, BaseEventHandler)
|
||||||
|
ret = self._register_event_handler_component(component_info, component_class)
|
||||||
case _:
|
case _:
|
||||||
logger.warning(f"未知组件类型: {component_type}")
|
logger.warning(f"未知组件类型: {component_type}")
|
||||||
|
|
||||||
@@ -160,7 +166,9 @@ class ComponentRegistry:
|
|||||||
if pattern not in self._command_patterns:
|
if pattern not in self._command_patterns:
|
||||||
self._command_patterns[pattern] = command_name
|
self._command_patterns[pattern] = command_name
|
||||||
else:
|
else:
|
||||||
logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令")
|
logger.warning(
|
||||||
|
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -176,6 +184,10 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
self._event_handler_registry[handler_name] = handler_class
|
self._event_handler_registry[handler_name] = handler_class
|
||||||
|
|
||||||
|
if not handler_info.enabled:
|
||||||
|
logger.warning(f"EventHandler组件 {handler_name} 未启用")
|
||||||
|
return True # 未启用,但是也是注册成功
|
||||||
|
|
||||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||||
|
|
||||||
if events_manager.register_event_subscriber(handler_info, handler_class):
|
if events_manager.register_event_subscriber(handler_info, handler_class):
|
||||||
@@ -185,6 +197,98 @@ class ComponentRegistry:
|
|||||||
logger.error(f"注册事件处理器 {handler_name} 失败")
|
logger.error(f"注册事件处理器 {handler_name} 失败")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# === 组件移除相关 ===
|
||||||
|
|
||||||
|
async def remove_component(self, component_name: str, component_type: ComponentType):
|
||||||
|
target_component_class = self.get_component_class(component_name, component_type)
|
||||||
|
if not target_component_class:
|
||||||
|
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
||||||
|
return
|
||||||
|
match component_type:
|
||||||
|
case ComponentType.ACTION:
|
||||||
|
self._action_registry.pop(component_name, None)
|
||||||
|
self._default_actions.pop(component_name, None)
|
||||||
|
case ComponentType.COMMAND:
|
||||||
|
self._command_registry.pop(component_name, None)
|
||||||
|
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
self._command_patterns.pop(key, None)
|
||||||
|
case ComponentType.EVENT_HANDLER:
|
||||||
|
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||||
|
|
||||||
|
self._event_handler_registry.pop(component_name, None)
|
||||||
|
self._enabled_event_handlers.pop(component_name, None)
|
||||||
|
await events_manager.unregister_event_subscriber(component_name)
|
||||||
|
self._components.pop(component_name, None)
|
||||||
|
self._components_by_type[component_type].pop(component_name, None)
|
||||||
|
self._components_classes.pop(component_name, None)
|
||||||
|
logger.info(f"组件 {component_name} 已移除")
|
||||||
|
|
||||||
|
# === 组件全局启用/禁用方法 ===
|
||||||
|
|
||||||
|
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||||
|
"""全局的启用某个组件
|
||||||
|
Parameters:
|
||||||
|
component_name: 组件名称
|
||||||
|
component_type: 组件类型
|
||||||
|
Returns:
|
||||||
|
bool: 启用成功返回True,失败返回False
|
||||||
|
"""
|
||||||
|
target_component_class = self.get_component_class(component_name, component_type)
|
||||||
|
target_component_info = self.get_component_info(component_name, component_type)
|
||||||
|
if not target_component_class or not target_component_info:
|
||||||
|
logger.warning(f"组件 {component_name} 未注册,无法启用")
|
||||||
|
return False
|
||||||
|
target_component_info.enabled = True
|
||||||
|
match component_type:
|
||||||
|
case ComponentType.ACTION:
|
||||||
|
assert isinstance(target_component_info, ActionInfo)
|
||||||
|
self._default_actions[component_name] = target_component_info
|
||||||
|
case ComponentType.COMMAND:
|
||||||
|
assert isinstance(target_component_info, CommandInfo)
|
||||||
|
pattern = target_component_info.command_pattern
|
||||||
|
self._command_patterns[re.compile(pattern)] = component_name
|
||||||
|
case ComponentType.EVENT_HANDLER:
|
||||||
|
assert isinstance(target_component_info, EventHandlerInfo)
|
||||||
|
assert issubclass(target_component_class, BaseEventHandler)
|
||||||
|
self._enabled_event_handlers[component_name] = target_component_class
|
||||||
|
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||||
|
|
||||||
|
events_manager.register_event_subscriber(target_component_info, target_component_class)
|
||||||
|
self._components[component_name].enabled = True
|
||||||
|
self._components_by_type[component_type][component_name].enabled = True
|
||||||
|
logger.info(f"组件 {component_name} 已启用")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def disable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||||
|
"""全局的禁用某个组件
|
||||||
|
Parameters:
|
||||||
|
component_name: 组件名称
|
||||||
|
component_type: 组件类型
|
||||||
|
Returns:
|
||||||
|
bool: 禁用成功返回True,失败返回False
|
||||||
|
"""
|
||||||
|
target_component_class = self.get_component_class(component_name, component_type)
|
||||||
|
target_component_info = self.get_component_info(component_name, component_type)
|
||||||
|
if not target_component_class or not target_component_info:
|
||||||
|
logger.warning(f"组件 {component_name} 未注册,无法禁用")
|
||||||
|
return False
|
||||||
|
target_component_info.enabled = False
|
||||||
|
match component_type:
|
||||||
|
case ComponentType.ACTION:
|
||||||
|
self._default_actions.pop(component_name, None)
|
||||||
|
case ComponentType.COMMAND:
|
||||||
|
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
|
||||||
|
case ComponentType.EVENT_HANDLER:
|
||||||
|
self._enabled_event_handlers.pop(component_name, None)
|
||||||
|
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||||
|
|
||||||
|
await events_manager.unregister_event_subscriber(component_name)
|
||||||
|
self._components[component_name].enabled = False
|
||||||
|
self._components_by_type[component_type][component_name].enabled = False
|
||||||
|
logger.info(f"组件 {component_name} 已禁用")
|
||||||
|
return True
|
||||||
|
|
||||||
# === 组件查询方法 ===
|
# === 组件查询方法 ===
|
||||||
def get_component_info(
|
def get_component_info(
|
||||||
self, component_name: str, component_type: Optional[ComponentType] = None
|
self, component_name: str, component_type: Optional[ComponentType] = None
|
||||||
@@ -287,7 +391,7 @@ class ComponentRegistry:
|
|||||||
# === Action特定查询方法 ===
|
# === Action特定查询方法 ===
|
||||||
|
|
||||||
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
|
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
|
||||||
"""获取Action注册表(用于兼容现有系统)"""
|
"""获取Action注册表"""
|
||||||
return self._action_registry.copy()
|
return self._action_registry.copy()
|
||||||
|
|
||||||
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
||||||
@@ -314,7 +418,7 @@ class ComponentRegistry:
|
|||||||
"""获取Command模式注册表"""
|
"""获取Command模式注册表"""
|
||||||
return self._command_patterns.copy()
|
return self._command_patterns.copy()
|
||||||
|
|
||||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]:
|
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
|
||||||
# sourcery skip: use-named-expression, use-next
|
# sourcery skip: use-named-expression, use-next
|
||||||
"""根据文本查找匹配的命令
|
"""根据文本查找匹配的命令
|
||||||
|
|
||||||
@@ -335,8 +439,7 @@ class ComponentRegistry:
|
|||||||
return (
|
return (
|
||||||
self._command_registry[command_name],
|
self._command_registry[command_name],
|
||||||
candidates[0].match(text).groupdict(), # type: ignore
|
candidates[0].match(text).groupdict(), # type: ignore
|
||||||
command_info.intercept_message,
|
command_info,
|
||||||
command_info.plugin_name,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# === 事件处理器特定查询方法 ===
|
# === 事件处理器特定查询方法 ===
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from src.chat.message_receive.message import MessageRecv
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
|
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
|
||||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||||
|
from .global_announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
logger = get_logger("events_manager")
|
logger = get_logger("events_manager")
|
||||||
|
|
||||||
@@ -28,18 +29,16 @@ class EventsManager:
|
|||||||
bool: 是否注册成功
|
bool: 是否注册成功
|
||||||
"""
|
"""
|
||||||
handler_name = handler_info.name
|
handler_name = handler_info.name
|
||||||
plugin_name = getattr(handler_info, "plugin_name", "unknown")
|
|
||||||
|
|
||||||
namespace_name = f"{plugin_name}.{handler_name}"
|
if handler_name in self._handler_mapping:
|
||||||
if namespace_name in self._handler_mapping:
|
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
||||||
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not issubclass(handler_class, BaseEventHandler):
|
if not issubclass(handler_class, BaseEventHandler):
|
||||||
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self._handler_mapping[namespace_name] = handler_class
|
self._handler_mapping[handler_name] = handler_class
|
||||||
return self._insert_event_handler(handler_class, handler_info)
|
return self._insert_event_handler(handler_class, handler_info)
|
||||||
|
|
||||||
async def handle_mai_events(
|
async def handle_mai_events(
|
||||||
@@ -55,6 +54,10 @@ class EventsManager:
|
|||||||
continue_flag = True
|
continue_flag = True
|
||||||
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
||||||
for handler in self._events_subscribers.get(event_type, []):
|
for handler in self._events_subscribers.get(event_type, []):
|
||||||
|
if message.chat_stream and message.chat_stream.stream_id:
|
||||||
|
stream_id = message.chat_stream.stream_id
|
||||||
|
if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id):
|
||||||
|
continue
|
||||||
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
|
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
|
||||||
if handler.intercept_message:
|
if handler.intercept_message:
|
||||||
try:
|
try:
|
||||||
@@ -71,7 +74,7 @@ class EventsManager:
|
|||||||
try:
|
try:
|
||||||
handler_task = asyncio.create_task(handler.execute(transformed_message))
|
handler_task = asyncio.create_task(handler.execute(transformed_message))
|
||||||
handler_task.add_done_callback(self._task_done_callback)
|
handler_task.add_done_callback(self._task_done_callback)
|
||||||
handler_task.set_name(f"EventHandler-{handler.handler_name}-{event_type.name}")
|
handler_task.set_name(f"{handler.plugin_name}-{handler.handler_name}")
|
||||||
self._handler_tasks[handler.handler_name].append(handler_task)
|
self._handler_tasks[handler.handler_name].append(handler_task)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
|
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
|
||||||
@@ -91,7 +94,7 @@ class EventsManager:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||||
"""从事件类型列表中移除事件处理器"""
|
"""从事件类型列表中移除事件处理器"""
|
||||||
display_handler_name = handler_class.handler_name or handler_class.__name__
|
display_handler_name = handler_class.handler_name or handler_class.__name__
|
||||||
if handler_class.event_type == EventType.UNKNOWN:
|
if handler_class.event_type == EventType.UNKNOWN:
|
||||||
@@ -190,5 +193,20 @@ class EventsManager:
|
|||||||
finally:
|
finally:
|
||||||
del self._handler_tasks[handler_name]
|
del self._handler_tasks[handler_name]
|
||||||
|
|
||||||
|
async def unregister_event_subscriber(self, handler_name: str) -> bool:
|
||||||
|
"""取消注册事件处理器"""
|
||||||
|
if handler_name not in self._handler_mapping:
|
||||||
|
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self.cancel_handler_tasks(handler_name)
|
||||||
|
|
||||||
|
handler_class = self._handler_mapping.pop(handler_name)
|
||||||
|
if not self._remove_event_handler_instance(handler_class):
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
events_manager = EventsManager()
|
events_manager = EventsManager()
|
||||||
|
|||||||
90
src/plugin_system/core/global_announcement_manager.py
Normal file
90
src/plugin_system/core/global_announcement_manager.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("global_announcement_manager")
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalAnnouncementManager:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# 用户禁用的动作,chat_id -> [action_name]
|
||||||
|
self._user_disabled_actions: Dict[str, List[str]] = {}
|
||||||
|
# 用户禁用的命令,chat_id -> [command_name]
|
||||||
|
self._user_disabled_commands: Dict[str, List[str]] = {}
|
||||||
|
# 用户禁用的事件处理器,chat_id -> [handler_name]
|
||||||
|
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
|
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||||
|
"""禁用特定聊天的某个动作"""
|
||||||
|
if chat_id not in self._user_disabled_actions:
|
||||||
|
self._user_disabled_actions[chat_id] = []
|
||||||
|
if action_name in self._user_disabled_actions[chat_id]:
|
||||||
|
logger.warning(f"动作 {action_name} 已经被禁用")
|
||||||
|
return False
|
||||||
|
self._user_disabled_actions[chat_id].append(action_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def enable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||||
|
"""启用特定聊天的某个动作"""
|
||||||
|
if chat_id in self._user_disabled_actions:
|
||||||
|
try:
|
||||||
|
self._user_disabled_actions[chat_id].remove(action_name)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def disable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||||
|
"""禁用特定聊天的某个命令"""
|
||||||
|
if chat_id not in self._user_disabled_commands:
|
||||||
|
self._user_disabled_commands[chat_id] = []
|
||||||
|
if command_name in self._user_disabled_commands[chat_id]:
|
||||||
|
logger.warning(f"命令 {command_name} 已经被禁用")
|
||||||
|
return False
|
||||||
|
self._user_disabled_commands[chat_id].append(command_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def enable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||||
|
"""启用特定聊天的某个命令"""
|
||||||
|
if chat_id in self._user_disabled_commands:
|
||||||
|
try:
|
||||||
|
self._user_disabled_commands[chat_id].remove(command_name)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def disable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||||
|
"""禁用特定聊天的某个事件处理器"""
|
||||||
|
if chat_id not in self._user_disabled_event_handlers:
|
||||||
|
self._user_disabled_event_handlers[chat_id] = []
|
||||||
|
if handler_name in self._user_disabled_event_handlers[chat_id]:
|
||||||
|
logger.warning(f"事件处理器 {handler_name} 已经被禁用")
|
||||||
|
return False
|
||||||
|
self._user_disabled_event_handlers[chat_id].append(handler_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def enable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||||
|
"""启用特定聊天的某个事件处理器"""
|
||||||
|
if chat_id in self._user_disabled_event_handlers:
|
||||||
|
try:
|
||||||
|
self._user_disabled_event_handlers[chat_id].remove(handler_name)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
|
||||||
|
"""获取特定聊天禁用的所有动作"""
|
||||||
|
return self._user_disabled_actions.get(chat_id, []).copy()
|
||||||
|
|
||||||
|
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
|
||||||
|
"""获取特定聊天禁用的所有命令"""
|
||||||
|
return self._user_disabled_commands.get(chat_id, []).copy()
|
||||||
|
|
||||||
|
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||||
|
"""获取特定聊天禁用的所有事件处理器"""
|
||||||
|
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||||
|
|
||||||
|
|
||||||
|
global_announcement_manager = GlobalAnnouncementManager()
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import inspect
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple, Type, Any
|
from typing import Dict, List, Optional, Tuple, Type, Any
|
||||||
@@ -8,11 +7,11 @@ from pathlib import Path
|
|||||||
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
|
||||||
from src.plugin_system.base.plugin_base import PluginBase
|
from src.plugin_system.base.plugin_base import PluginBase
|
||||||
from src.plugin_system.base.component_types import ComponentType, PluginInfo, PythonDependency
|
from src.plugin_system.base.component_types import ComponentType, PluginInfo, PythonDependency
|
||||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
from src.plugin_system.utils.manifest_utils import VersionComparator
|
||||||
|
from .component_registry import component_registry
|
||||||
|
from .dependency_manager import dependency_manager
|
||||||
|
|
||||||
logger = get_logger("plugin_manager")
|
logger = get_logger("plugin_manager")
|
||||||
|
|
||||||
@@ -36,19 +35,7 @@ class PluginManager:
|
|||||||
self._ensure_plugin_directories()
|
self._ensure_plugin_directories()
|
||||||
logger.info("插件管理器初始化完成")
|
logger.info("插件管理器初始化完成")
|
||||||
|
|
||||||
def _ensure_plugin_directories(self) -> None:
|
# === 插件目录管理 ===
|
||||||
"""确保所有插件根目录存在,如果不存在则创建"""
|
|
||||||
default_directories = ["src/plugins/built_in", "plugins"]
|
|
||||||
|
|
||||||
for directory in default_directories:
|
|
||||||
if not os.path.exists(directory):
|
|
||||||
os.makedirs(directory, exist_ok=True)
|
|
||||||
logger.info(f"创建插件根目录: {directory}")
|
|
||||||
if directory not in self.plugin_directories:
|
|
||||||
self.plugin_directories.append(directory)
|
|
||||||
logger.debug(f"已添加插件根目录: {directory}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"根目录不可重复加载: {directory}")
|
|
||||||
|
|
||||||
def add_plugin_directory(self, directory: str) -> bool:
|
def add_plugin_directory(self, directory: str) -> bool:
|
||||||
"""添加插件目录"""
|
"""添加插件目录"""
|
||||||
@@ -63,6 +50,8 @@ class PluginManager:
|
|||||||
logger.warning(f"插件目录不存在: {directory}")
|
logger.warning(f"插件目录不存在: {directory}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# === 插件加载管理 ===
|
||||||
|
|
||||||
def load_all_plugins(self) -> Tuple[int, int]:
|
def load_all_plugins(self) -> Tuple[int, int]:
|
||||||
"""加载所有插件
|
"""加载所有插件
|
||||||
|
|
||||||
@@ -86,7 +75,7 @@ class PluginManager:
|
|||||||
total_failed_registration = 0
|
total_failed_registration = 0
|
||||||
|
|
||||||
for plugin_name in self.plugin_classes.keys():
|
for plugin_name in self.plugin_classes.keys():
|
||||||
load_status, count = self.load_registered_plugin_classes(plugin_name)
|
load_status, count = self._load_registered_plugin_classes(plugin_name)
|
||||||
if load_status:
|
if load_status:
|
||||||
total_registered += 1
|
total_registered += 1
|
||||||
else:
|
else:
|
||||||
@@ -96,90 +85,32 @@ class PluginManager:
|
|||||||
|
|
||||||
return total_registered, total_failed_registration
|
return total_registered, total_failed_registration
|
||||||
|
|
||||||
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
|
async def remove_registered_plugin(self, plugin_name: str) -> None:
|
||||||
# sourcery skip: extract-duplicate-method, extract-method
|
|
||||||
"""
|
"""
|
||||||
加载已经注册的插件类
|
禁用插件模块
|
||||||
"""
|
"""
|
||||||
plugin_class = self.plugin_classes.get(plugin_name)
|
if not plugin_name:
|
||||||
if not plugin_class:
|
raise ValueError("插件名称不能为空")
|
||||||
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
|
if plugin_name not in self.loaded_plugins:
|
||||||
return False, 1
|
logger.warning(f"插件 {plugin_name} 未加载")
|
||||||
try:
|
return
|
||||||
# 使用记录的插件目录路径
|
plugin_instance = self.loaded_plugins[plugin_name]
|
||||||
plugin_dir = self.plugin_paths.get(plugin_name)
|
plugin_info = plugin_instance.plugin_info
|
||||||
|
for component in plugin_info.components:
|
||||||
|
await component_registry.remove_component(component.name, component.component_type)
|
||||||
|
del self.loaded_plugins[plugin_name]
|
||||||
|
|
||||||
# 如果没有记录,直接返回失败
|
async def reload_registered_plugin_module(self, plugin_name: str) -> None:
|
||||||
if not plugin_dir:
|
|
||||||
return False, 1
|
|
||||||
|
|
||||||
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败)
|
|
||||||
if not plugin_instance:
|
|
||||||
logger.error(f"插件 {plugin_name} 实例化失败")
|
|
||||||
return False, 1
|
|
||||||
# 检查插件是否启用
|
|
||||||
if not plugin_instance.enable_plugin:
|
|
||||||
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
|
|
||||||
return False, 0
|
|
||||||
|
|
||||||
# 检查版本兼容性
|
|
||||||
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
|
|
||||||
plugin_name, plugin_instance.manifest_data
|
|
||||||
)
|
|
||||||
if not is_compatible:
|
|
||||||
self.failed_plugins[plugin_name] = compatibility_error
|
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
|
|
||||||
return False, 1
|
|
||||||
if plugin_instance.register_plugin():
|
|
||||||
self.loaded_plugins[plugin_name] = plugin_instance
|
|
||||||
self._show_plugin_components(plugin_name)
|
|
||||||
return True, 1
|
|
||||||
else:
|
|
||||||
self.failed_plugins[plugin_name] = "插件注册失败"
|
|
||||||
logger.error(f"❌ 插件注册失败: {plugin_name}")
|
|
||||||
return False, 1
|
|
||||||
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
# manifest文件缺失
|
|
||||||
error_msg = f"缺少manifest文件: {str(e)}"
|
|
||||||
self.failed_plugins[plugin_name] = error_msg
|
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
|
||||||
return False, 1
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
# manifest文件格式错误或验证失败
|
|
||||||
traceback.print_exc()
|
|
||||||
error_msg = f"manifest验证失败: {str(e)}"
|
|
||||||
self.failed_plugins[plugin_name] = error_msg
|
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
|
||||||
return False, 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# 其他错误
|
|
||||||
error_msg = f"未知错误: {str(e)}"
|
|
||||||
self.failed_plugins[plugin_name] = error_msg
|
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
|
||||||
logger.debug("详细错误信息: ", exc_info=True)
|
|
||||||
return False, 1
|
|
||||||
|
|
||||||
def unload_registered_plugin_module(self, plugin_name: str) -> None:
|
|
||||||
"""
|
|
||||||
卸载插件模块
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reload_registered_plugin_module(self, plugin_name: str) -> None:
|
|
||||||
"""
|
"""
|
||||||
重载插件模块
|
重载插件模块
|
||||||
"""
|
"""
|
||||||
self.unload_registered_plugin_module(plugin_name)
|
await self.remove_registered_plugin(plugin_name)
|
||||||
self.load_registered_plugin_classes(plugin_name)
|
self._load_registered_plugin_classes(plugin_name)
|
||||||
|
|
||||||
def rescan_plugin_directory(self) -> None:
|
def rescan_plugin_directory(self) -> None:
|
||||||
"""
|
"""
|
||||||
重新扫描插件根目录
|
重新扫描插件根目录
|
||||||
"""
|
"""
|
||||||
# --------------------------------------- NEED REFACTORING ---------------------------------------
|
|
||||||
for directory in self.plugin_directories:
|
for directory in self.plugin_directories:
|
||||||
if os.path.exists(directory):
|
if os.path.exists(directory):
|
||||||
logger.debug(f"重新扫描插件根目录: {directory}")
|
logger.debug(f"重新扫描插件根目录: {directory}")
|
||||||
@@ -195,30 +126,6 @@ class PluginManager:
|
|||||||
"""获取所有启用的插件信息"""
|
"""获取所有启用的插件信息"""
|
||||||
return list(component_registry.get_enabled_plugins().values())
|
return list(component_registry.get_enabled_plugins().values())
|
||||||
|
|
||||||
# def enable_plugin(self, plugin_name: str) -> bool:
|
|
||||||
# # -------------------------------- NEED REFACTORING --------------------------------
|
|
||||||
# """启用插件"""
|
|
||||||
# if plugin_info := component_registry.get_plugin_info(plugin_name):
|
|
||||||
# plugin_info.enabled = True
|
|
||||||
# # 启用插件的所有组件
|
|
||||||
# for component in plugin_info.components:
|
|
||||||
# component_registry.enable_component(component.name)
|
|
||||||
# logger.debug(f"已启用插件: {plugin_name}")
|
|
||||||
# return True
|
|
||||||
# return False
|
|
||||||
|
|
||||||
# def disable_plugin(self, plugin_name: str) -> bool:
|
|
||||||
# # -------------------------------- NEED REFACTORING --------------------------------
|
|
||||||
# """禁用插件"""
|
|
||||||
# if plugin_info := component_registry.get_plugin_info(plugin_name):
|
|
||||||
# plugin_info.enabled = False
|
|
||||||
# # 禁用插件的所有组件
|
|
||||||
# for component in plugin_info.components:
|
|
||||||
# component_registry.disable_component(component.name)
|
|
||||||
# logger.debug(f"已禁用插件: {plugin_name}")
|
|
||||||
# return True
|
|
||||||
# return False
|
|
||||||
|
|
||||||
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
|
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
|
||||||
"""获取插件实例
|
"""获取插件实例
|
||||||
|
|
||||||
@@ -230,25 +137,6 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
return self.loaded_plugins.get(plugin_name)
|
return self.loaded_plugins.get(plugin_name)
|
||||||
|
|
||||||
def get_plugin_stats(self) -> Dict[str, Any]:
|
|
||||||
"""获取插件统计信息"""
|
|
||||||
all_plugins = component_registry.get_all_plugins()
|
|
||||||
enabled_plugins = component_registry.get_enabled_plugins()
|
|
||||||
|
|
||||||
action_components = component_registry.get_components_by_type(ComponentType.ACTION)
|
|
||||||
command_components = component_registry.get_components_by_type(ComponentType.COMMAND)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_plugins": len(all_plugins),
|
|
||||||
"enabled_plugins": len(enabled_plugins),
|
|
||||||
"failed_plugins": len(self.failed_plugins),
|
|
||||||
"total_components": len(action_components) + len(command_components),
|
|
||||||
"action_components": len(action_components),
|
|
||||||
"command_components": len(command_components),
|
|
||||||
"loaded_plugin_files": len(self.loaded_plugins),
|
|
||||||
"failed_plugin_details": self.failed_plugins.copy(),
|
|
||||||
}
|
|
||||||
|
|
||||||
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]:
|
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]:
|
||||||
"""检查所有插件的Python依赖包
|
"""检查所有插件的Python依赖包
|
||||||
|
|
||||||
@@ -347,6 +235,24 @@ class PluginManager:
|
|||||||
|
|
||||||
return dependency_manager.generate_requirements_file(all_dependencies, output_path)
|
return dependency_manager.generate_requirements_file(all_dependencies, output_path)
|
||||||
|
|
||||||
|
# === 私有方法 ===
|
||||||
|
# == 目录管理 ==
|
||||||
|
def _ensure_plugin_directories(self) -> None:
|
||||||
|
"""确保所有插件根目录存在,如果不存在则创建"""
|
||||||
|
default_directories = ["src/plugins/built_in", "plugins"]
|
||||||
|
|
||||||
|
for directory in default_directories:
|
||||||
|
if not os.path.exists(directory):
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
logger.info(f"创建插件根目录: {directory}")
|
||||||
|
if directory not in self.plugin_directories:
|
||||||
|
self.plugin_directories.append(directory)
|
||||||
|
logger.debug(f"已添加插件根目录: {directory}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"根目录不可重复加载: {directory}")
|
||||||
|
|
||||||
|
# == 插件加载 ==
|
||||||
|
|
||||||
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
|
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
|
||||||
"""从指定目录加载插件模块"""
|
"""从指定目录加载插件模块"""
|
||||||
loaded_count = 0
|
loaded_count = 0
|
||||||
@@ -372,18 +278,6 @@ class PluginManager:
|
|||||||
|
|
||||||
return loaded_count, failed_count
|
return loaded_count, failed_count
|
||||||
|
|
||||||
def _find_plugin_directory(self, plugin_class: Type[PluginBase]) -> Optional[str]:
|
|
||||||
"""查找插件类对应的目录路径"""
|
|
||||||
try:
|
|
||||||
# module = getmodule(plugin_class)
|
|
||||||
# if module and hasattr(module, "__file__") and module.__file__:
|
|
||||||
# return os.path.dirname(module.__file__)
|
|
||||||
file_path = inspect.getfile(plugin_class)
|
|
||||||
return os.path.dirname(file_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"通过inspect获取插件目录失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _load_plugin_module_file(self, plugin_file: str) -> bool:
|
def _load_plugin_module_file(self, plugin_file: str) -> bool:
|
||||||
# sourcery skip: extract-method
|
# sourcery skip: extract-method
|
||||||
"""加载单个插件模块文件
|
"""加载单个插件模块文件
|
||||||
@@ -416,6 +310,74 @@ class PluginManager:
|
|||||||
self.failed_plugins[module_name] = error_msg
|
self.failed_plugins[module_name] = error_msg
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
|
||||||
|
# sourcery skip: extract-duplicate-method, extract-method
|
||||||
|
"""
|
||||||
|
加载已经注册的插件类
|
||||||
|
"""
|
||||||
|
plugin_class = self.plugin_classes.get(plugin_name)
|
||||||
|
if not plugin_class:
|
||||||
|
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
|
||||||
|
return False, 1
|
||||||
|
try:
|
||||||
|
# 使用记录的插件目录路径
|
||||||
|
plugin_dir = self.plugin_paths.get(plugin_name)
|
||||||
|
|
||||||
|
# 如果没有记录,直接返回失败
|
||||||
|
if not plugin_dir:
|
||||||
|
return False, 1
|
||||||
|
|
||||||
|
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败)
|
||||||
|
if not plugin_instance:
|
||||||
|
logger.error(f"插件 {plugin_name} 实例化失败")
|
||||||
|
return False, 1
|
||||||
|
# 检查插件是否启用
|
||||||
|
if not plugin_instance.enable_plugin:
|
||||||
|
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
|
||||||
|
return False, 0
|
||||||
|
|
||||||
|
# 检查版本兼容性
|
||||||
|
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
|
||||||
|
plugin_name, plugin_instance.manifest_data
|
||||||
|
)
|
||||||
|
if not is_compatible:
|
||||||
|
self.failed_plugins[plugin_name] = compatibility_error
|
||||||
|
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
|
||||||
|
return False, 1
|
||||||
|
if plugin_instance.register_plugin():
|
||||||
|
self.loaded_plugins[plugin_name] = plugin_instance
|
||||||
|
self._show_plugin_components(plugin_name)
|
||||||
|
return True, 1
|
||||||
|
else:
|
||||||
|
self.failed_plugins[plugin_name] = "插件注册失败"
|
||||||
|
logger.error(f"❌ 插件注册失败: {plugin_name}")
|
||||||
|
return False, 1
|
||||||
|
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
# manifest文件缺失
|
||||||
|
error_msg = f"缺少manifest文件: {str(e)}"
|
||||||
|
self.failed_plugins[plugin_name] = error_msg
|
||||||
|
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||||
|
return False, 1
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
# manifest文件格式错误或验证失败
|
||||||
|
traceback.print_exc()
|
||||||
|
error_msg = f"manifest验证失败: {str(e)}"
|
||||||
|
self.failed_plugins[plugin_name] = error_msg
|
||||||
|
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||||
|
return False, 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 其他错误
|
||||||
|
error_msg = f"未知错误: {str(e)}"
|
||||||
|
self.failed_plugins[plugin_name] = error_msg
|
||||||
|
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||||
|
logger.debug("详细错误信息: ", exc_info=True)
|
||||||
|
return False, 1
|
||||||
|
|
||||||
|
# == 兼容性检查 ==
|
||||||
|
|
||||||
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
||||||
"""检查插件版本兼容性
|
"""检查插件版本兼容性
|
||||||
|
|
||||||
@@ -451,6 +413,8 @@ class PluginManager:
|
|||||||
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
|
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
|
||||||
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
|
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
|
||||||
|
|
||||||
|
# == 显示统计与插件信息 ==
|
||||||
|
|
||||||
def _show_stats(self, total_registered: int, total_failed_registration: int):
|
def _show_stats(self, total_registered: int, total_failed_registration: int):
|
||||||
# sourcery skip: low-code-quality
|
# sourcery skip: low-code-quality
|
||||||
# 获取组件统计信息
|
# 获取组件统计信息
|
||||||
@@ -493,9 +457,15 @@ class PluginManager:
|
|||||||
|
|
||||||
# 组件列表
|
# 组件列表
|
||||||
if plugin_info.components:
|
if plugin_info.components:
|
||||||
action_components = [c for c in plugin_info.components if c.component_type == ComponentType.ACTION]
|
action_components = [
|
||||||
command_components = [c for c in plugin_info.components if c.component_type == ComponentType.COMMAND]
|
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
|
||||||
event_handler_components = [c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER]
|
]
|
||||||
|
command_components = [
|
||||||
|
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
|
||||||
|
]
|
||||||
|
event_handler_components = [
|
||||||
|
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
|
||||||
|
]
|
||||||
|
|
||||||
if action_components:
|
if action_components:
|
||||||
action_names = [c.name for c in action_components]
|
action_names = [c.name for c in action_components]
|
||||||
@@ -504,7 +474,7 @@ class PluginManager:
|
|||||||
if command_components:
|
if command_components:
|
||||||
command_names = [c.name for c in command_components]
|
command_names = [c.name for c in command_components]
|
||||||
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
||||||
|
|
||||||
if event_handler_components:
|
if event_handler_components:
|
||||||
event_handler_names = [c.name for c in event_handler_components]
|
event_handler_names = [c.name for c in event_handler_components]
|
||||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
||||||
|
|||||||
Reference in New Issue
Block a user