增加了组件的全局启用和禁用功能
This commit is contained in:
@@ -46,6 +46,9 @@
|
||||
11. 修正了`command`所编译的`Pattern`注册时的错误输出。
|
||||
12. `events_manager`有了task相关逻辑了。
|
||||
13. 现在有了插件卸载和重载功能了,也就是热插拔。
|
||||
14. 实现了组件的全局启用和禁用功能。
|
||||
- 通过`enable_component`和`disable_component`方法来启用或禁用组件。
|
||||
- 不过这个操作不会保存到配置文件~
|
||||
|
||||
### TODO
|
||||
把这个看起来就很别扭的config获取方式改一下
|
||||
|
||||
@@ -21,7 +21,7 @@ from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api
|
||||
from src.chat.willing.willing_manager import get_willing_manager
|
||||
from src.chat.mai_thinking.mai_think import mai_thinking_manager
|
||||
from maim_message.message_base import GroupInfo,UserInfo
|
||||
from maim_message.message_base import GroupInfo
|
||||
|
||||
ENABLE_THINKING = False
|
||||
|
||||
@@ -259,29 +259,27 @@ class HeartFChatting:
|
||||
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
||||
|
||||
async def send_typing(self):
|
||||
group_info = GroupInfo(platform = "amaidesu_default",group_id = 114514,group_name = "内心")
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform = "amaidesu_default",
|
||||
user_info = None,
|
||||
group_info = group_info
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default",
|
||||
user_info=None, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
|
||||
async def stop_typing(self):
|
||||
group_info = GroupInfo(platform = "amaidesu_default",group_id = 114514,group_name = "内心")
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform = "amaidesu_default",
|
||||
user_info = None,
|
||||
group_info = group_info
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default",
|
||||
user_info=None, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
@@ -364,16 +362,13 @@ class HeartFChatting:
|
||||
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定的回复内容: {content}")
|
||||
|
||||
# 发送回复 (不再需要传入 chat)
|
||||
reply_text = await self._send_response(response_set, reply_to_str, loop_start_time,message_data)
|
||||
reply_text = await self._send_response(response_set, reply_to_str, loop_start_time, message_data)
|
||||
|
||||
await self.stop_typing()
|
||||
|
||||
|
||||
|
||||
if ENABLE_THINKING:
|
||||
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
else:
|
||||
@@ -507,7 +502,6 @@ class HeartFChatting:
|
||||
|
||||
self.willing_manager.setup(message_data, self.chat_stream)
|
||||
|
||||
|
||||
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", ""))
|
||||
|
||||
talk_frequency = -1.00
|
||||
@@ -546,7 +540,6 @@ class HeartFChatting:
|
||||
self.willing_manager.delete(message_data.get("message_id", ""))
|
||||
return False
|
||||
|
||||
|
||||
async def _generate_response(
|
||||
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
|
||||
) -> Optional[list]:
|
||||
@@ -570,7 +563,7 @@ class HeartFChatting:
|
||||
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
async def _send_response(self, reply_set, reply_to, thinking_start_time,message_data):
|
||||
async def _send_response(self, reply_set, reply_to, thinking_start_time, message_data):
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
|
||||
@@ -592,13 +585,27 @@ class HeartFChatting:
|
||||
if not first_replied:
|
||||
if need_reply:
|
||||
await send_api.text_to_stream(
|
||||
text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_to=reply_to,
|
||||
reply_to_platform_id=reply_to_platform_id,
|
||||
typing=False,
|
||||
)
|
||||
else:
|
||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=False)
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_to_platform_id=reply_to_platform_id,
|
||||
typing=False,
|
||||
)
|
||||
first_replied = True
|
||||
else:
|
||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=True)
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_to_platform_id=reply_to_platform_id,
|
||||
typing=True,
|
||||
)
|
||||
reply_text += data
|
||||
|
||||
return reply_text
|
||||
|
||||
@@ -163,20 +163,25 @@ class ChatManager:
|
||||
"""注册消息到聊天流"""
|
||||
stream_id = self._generate_stream_id(
|
||||
message.message_info.platform, # type: ignore
|
||||
message.message_info.user_info, # type: ignore
|
||||
message.message_info.user_info,
|
||||
message.message_info.group_info,
|
||||
)
|
||||
self.last_messages[stream_id] = message
|
||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||
|
||||
@staticmethod
|
||||
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
||||
def _generate_stream_id(
|
||||
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
|
||||
) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
if not user_info and not group_info:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
|
||||
if group_info:
|
||||
# 组合关键信息
|
||||
components = [platform, str(group_info.group_id)]
|
||||
else:
|
||||
components = [platform, str(user_info.user_id), "private"]
|
||||
components = [platform, str(user_info.user_id), "private"] # type: ignore
|
||||
|
||||
# 使用MD5生成唯一ID
|
||||
key = "_".join(components)
|
||||
|
||||
@@ -110,11 +110,17 @@ class ComponentRegistry:
|
||||
# 根据组件类型进行特定注册(使用原始名称)
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
ret = self._register_action_component(component_info, component_class) # type: ignore
|
||||
assert isinstance(component_info, ActionInfo)
|
||||
assert issubclass(component_class, BaseAction)
|
||||
ret = self._register_action_component(component_info, component_class)
|
||||
case ComponentType.COMMAND:
|
||||
ret = self._register_command_component(component_info, component_class) # type: ignore
|
||||
assert isinstance(component_info, CommandInfo)
|
||||
assert issubclass(component_class, BaseCommand)
|
||||
ret = self._register_command_component(component_info, component_class)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
ret = self._register_event_handler_component(component_info, component_class) # type: ignore
|
||||
assert isinstance(component_info, EventHandlerInfo)
|
||||
assert issubclass(component_class, BaseEventHandler)
|
||||
ret = self._register_event_handler_component(component_info, component_class)
|
||||
case _:
|
||||
logger.warning(f"未知组件类型: {component_type}")
|
||||
|
||||
@@ -218,6 +224,71 @@ class ComponentRegistry:
|
||||
self._components_classes.pop(component_name, None)
|
||||
logger.info(f"组件 {component_name} 已移除")
|
||||
|
||||
# === 组件全局启用/禁用方法 ===
|
||||
|
||||
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||
"""全局的启用某个组件
|
||||
Parameters:
|
||||
component_name: 组件名称
|
||||
component_type: 组件类型
|
||||
Returns:
|
||||
bool: 启用成功返回True,失败返回False
|
||||
"""
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
target_component_info = self.get_component_info(component_name, component_type)
|
||||
if not target_component_class or not target_component_info:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法启用")
|
||||
return False
|
||||
target_component_info.enabled = True
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
assert isinstance(target_component_info, ActionInfo)
|
||||
self._default_actions[component_name] = target_component_info
|
||||
case ComponentType.COMMAND:
|
||||
assert isinstance(target_component_info, CommandInfo)
|
||||
pattern = target_component_info.command_pattern
|
||||
self._command_patterns[re.compile(pattern)] = component_name
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
assert isinstance(target_component_info, EventHandlerInfo)
|
||||
assert issubclass(target_component_class, BaseEventHandler)
|
||||
self._enabled_event_handlers[component_name] = target_component_class
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
events_manager.register_event_subscriber(target_component_info, target_component_class)
|
||||
self._components[component_name].enabled = True
|
||||
self._components_by_type[component_type][component_name].enabled = True
|
||||
logger.info(f"组件 {component_name} 已启用")
|
||||
return True
|
||||
|
||||
async def disable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||
"""全局的禁用某个组件
|
||||
Parameters:
|
||||
component_name: 组件名称
|
||||
component_type: 组件类型
|
||||
Returns:
|
||||
bool: 禁用成功返回True,失败返回False
|
||||
"""
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
target_component_info = self.get_component_info(component_name, component_type)
|
||||
if not target_component_class or not target_component_info:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法禁用")
|
||||
return False
|
||||
target_component_info.enabled = False
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
self._default_actions.pop(component_name, None)
|
||||
case ComponentType.COMMAND:
|
||||
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
self._enabled_event_handlers.pop(component_name, None)
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
self._components[component_name].enabled = False
|
||||
self._components_by_type[component_type][component_name].enabled = False
|
||||
logger.info(f"组件 {component_name} 已禁用")
|
||||
return True
|
||||
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
self, component_name: str, component_type: Optional[ComponentType] = None
|
||||
|
||||
Reference in New Issue
Block a user