From ffa88b5462854d26274b183a34af66c29ad581f3 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Fri, 18 Jul 2025 14:50:15 +0800 Subject: [PATCH] events manager and some typing fix --- src/chat/knowledge/knowledge_lib.py | 54 ++++--- src/plugin_system/base/base_event_plugin.py | 47 +++++- src/plugin_system/base/base_events_handler.py | 34 ++++ src/plugin_system/base/component_types.py | 63 +++++++- src/plugin_system/core/events_manager.py | 151 +++++++++++++++++- src/tools/not_using/lpmm_get_knowledge.py | 2 +- 6 files changed, 318 insertions(+), 33 deletions(-) create mode 100644 src/plugin_system/base/base_events_handler.py diff --git a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py index 180a16ca1..1e87d3824 100644 --- a/src/chat/knowledge/knowledge_lib.py +++ b/src/chat/knowledge/knowledge_lib.py @@ -33,6 +33,7 @@ RAG_PG_HASH_NAMESPACE = "rag-pg-hash" ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) DATA_PATH = os.path.join(ROOT_PATH, "data") + def _initialize_knowledge_local_storage(): """ 初始化知识库相关的本地存储配置 @@ -41,55 +42,58 @@ def _initialize_knowledge_local_storage(): # 定义所有需要初始化的配置项 default_configs = { # 路径配置 - 'root_path': ROOT_PATH, - 'data_path': f"{ROOT_PATH}/data", - + "root_path": ROOT_PATH, + "data_path": f"{ROOT_PATH}/data", # 实体和命名空间配置 - 'lpmm_invalid_entity': INVALID_ENTITY, - 'pg_namespace': PG_NAMESPACE, - 'ent_namespace': ENT_NAMESPACE, - 'rel_namespace': REL_NAMESPACE, - + "lpmm_invalid_entity": INVALID_ENTITY, + "pg_namespace": PG_NAMESPACE, + "ent_namespace": ENT_NAMESPACE, + "rel_namespace": REL_NAMESPACE, # RAG相关命名空间配置 - 'rag_graph_namespace': RAG_GRAPH_NAMESPACE, - 'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE, - 'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE + "rag_graph_namespace": RAG_GRAPH_NAMESPACE, + "rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE, + "rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE, } - + # 日志级别映射:重要配置用info,其他用debug - important_configs = {'root_path', 'data_path'} - + important_configs = {"root_path", "data_path"} + # 批量设置配置项 initialized_count = 0 for key, default_value in default_configs.items(): if local_storage[key] is None: local_storage[key] = default_value - + # 根据重要性选择日志级别 if key in important_configs: logger.info(f"设置{key}: {default_value}") else: logger.debug(f"设置{key}: {default_value}") - + initialized_count += 1 - + if initialized_count > 0: logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置") else: logger.debug("知识库本地存储配置已存在,跳过初始化") - + + # 初始化本地存储路径 +# sourcery skip: dict-comprehension _initialize_knowledge_local_storage() +qa_manager = None +inspire_manager = None + # 检查LPMM知识库是否启用 if bot_global_config.lpmm_knowledge.enable: logger.info("正在初始化Mai-LPMM") logger.info("创建LLM客户端") - llm_client_list = dict() + llm_client_list = {} for key in global_config["llm_providers"]: llm_client_list[key] = LLMClient( - global_config["llm_providers"][key]["base_url"], - global_config["llm_providers"][key]["api_key"], + global_config["llm_providers"][key]["base_url"], # type: ignore + global_config["llm_providers"][key]["api_key"], # type: ignore ) # 初始化Embedding库 @@ -98,7 +102,7 @@ if bot_global_config.lpmm_knowledge.enable: try: embed_manager.load_from_file() except Exception as e: - logger.warning("此消息不会影响正常使用:从文件加载Embedding库时,{}".format(e)) + logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}") # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") logger.info("Embedding库加载完成") # 初始化KG @@ -107,7 +111,7 @@ if bot_global_config.lpmm_knowledge.enable: try: kg_manager.load_from_file() except Exception as e: - logger.warning("此消息不会影响正常使用:从文件加载KG时,{}".format(e)) + logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}") # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") logger.info("KG加载完成") @@ -116,7 +120,7 @@ if bot_global_config.lpmm_knowledge.enable: # 数据比对:Embedding库与KG的段落hash集合 for pg_hash in kg_manager.stored_paragraph_hashes: - key = PG_NAMESPACE + "-" + pg_hash + key = f"{PG_NAMESPACE}-{pg_hash}" if key not in embed_manager.stored_pg_hashes: logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") @@ -134,5 +138,3 @@ if bot_global_config.lpmm_knowledge.enable: else: logger.info("LPMM知识库已禁用,跳过初始化") # 创建空的占位符对象,避免导入错误 - qa_manager = None - inspire_manager = None diff --git a/src/plugin_system/base/base_event_plugin.py b/src/plugin_system/base/base_event_plugin.py index 859d43f06..38ab116cc 100644 --- a/src/plugin_system/base/base_event_plugin.py +++ b/src/plugin_system/base/base_event_plugin.py @@ -1,8 +1,14 @@ from abc import abstractmethod +from typing import List, Tuple, Type, TYPE_CHECKING -from .plugin_base import PluginBase from src.common.logger import get_logger +from .plugin_base import PluginBase +from .component_types import EventHandlerInfo +if TYPE_CHECKING: + from src.plugin_system.base.base_events_handler import BaseEventHandler + +logger = get_logger("base_event_plugin") class BaseEventPlugin(PluginBase): """基于事件的插件基类 @@ -12,3 +18,42 @@ class BaseEventPlugin(PluginBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + + @abstractmethod + def get_plugin_components(self) -> List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]]: + """获取插件包含的事件组件 + + 子类必须实现此方法,返回事件组件 + + Returns: + List[Tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...] + """ + raise NotImplementedError("子类必须实现 get_plugin_components 方法") + + def register_plugin(self) -> bool: + """注册事件插件""" + from src.plugin_system.core.events_manager import events_manager + + components = self.get_plugin_components() + + # 检查依赖 + if not self._check_dependencies(): + logger.error(f"{self.log_prefix} 依赖检查失败,跳过注册") + return False + + registered_components = [] + for handler_info, handler_class in components: + handler_info.plugin_name = self.plugin_name + if events_manager.register_event_subscriber(handler_info, handler_class): + registered_components.append(handler_info) + else: + logger.error(f"{self.log_prefix} 事件处理器 {handler_info.name} 注册失败") + + self.plugin_info.components = registered_components + + if events_manager.register_plugins(self.plugin_info): + logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个事件处理器") + return True + else: + logger.error(f"{self.log_prefix} 插件注册失败") + return False \ No newline at end of file diff --git a/src/plugin_system/base/base_events_handler.py b/src/plugin_system/base/base_events_handler.py new file mode 100644 index 000000000..2541d9abb --- /dev/null +++ b/src/plugin_system/base/base_events_handler.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod +from typing import Tuple, Optional + +from src.common.logger import get_logger +from .component_types import MaiMessages, EventType + +logger = get_logger("base_event_handler") + + +class BaseEventHandler(ABC): + """事件处理器基类 + + 所有事件处理器都应该继承这个基类,提供事件处理的基本接口 + """ + + event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知 + handler_name: str = "" + handler_description: str = "" + weight: int = 0 # 权重,数值越大优先级越高 + intercept_message: bool = False # 是否拦截消息,默认为否 + + def __init__(self): + self.log_prefix = "[EventHandler]" + if self.event_type == EventType.UNKNOWN: + raise NotImplementedError("事件处理器必须指定 event_type") + + @abstractmethod + async def execute(self, message: MaiMessages) -> Tuple[bool, Optional[str]]: + """执行事件处理的抽象方法,子类必须实现 + + Returns: + Tuple[bool, Optional[str]]: (是否执行成功, 可选的返回消息) + """ + raise NotImplementedError("子类必须实现 execute 方法") diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 14025ed99..2b7636eb7 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -1,6 +1,7 @@ from enum import Enum -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional from dataclasses import dataclass, field +from maim_message import Seg # 组件类型枚举 @@ -12,6 +13,9 @@ class ComponentType(Enum): SCHEDULER = "scheduler" # 定时任务组件(预留) LISTENER = "listener" # 事件监听组件(预留) + def __str__(self) -> str: + return self.value + # 动作激活类型枚举 class ActionActivationType(Enum): @@ -46,12 +50,17 @@ class EventType(Enum): 事件类型枚举类 """ + ON_START = "on_start" # 启动事件,用于调用按时任务 ON_MESSAGE = "on_message" ON_PLAN = "on_plan" POST_LLM = "post_llm" AFTER_LLM = "after_llm" POST_SEND = "post_send" AFTER_SEND = "after_send" + UNKNOWN = "unknown" # 未知事件类型 + + def __str__(self) -> str: + return self.value @dataclass @@ -142,6 +151,19 @@ class CommandInfo(ComponentInfo): self.component_type = ComponentType.COMMAND +@dataclass +class EventHandlerInfo(ComponentInfo): + """事件处理器组件信息""" + + event_type: EventType = EventType.ON_MESSAGE # 监听事件类型 + intercept_message: bool = False # 是否拦截消息处理(默认不拦截) + weight: int = 0 # 事件处理器权重,决定执行顺序 + + def __post_init__(self): + super().__post_init__() + self.component_type = ComponentType.LISTENER + + @dataclass class PluginInfo: """插件信息""" @@ -198,3 +220,42 @@ class PluginInfo: def get_pip_requirements(self) -> List[str]: """获取所有pip安装格式的依赖""" return [dep.get_pip_requirement() for dep in self.python_dependencies] + + +@dataclass +class MaiMessages: + """MaiM插件消息""" + + message_segments: List[Seg] = field(default_factory=list) + """消息段列表,支持多段消息""" + + message_base_info: Dict[str, Any] = field(default_factory=dict) + """消息基本信息,包含平台,用户信息等数据""" + + plain_text: str = "" + """纯文本消息内容""" + + raw_message: Optional[str] = None + """原始消息内容""" + + is_group_message: bool = False + """是否为群组消息""" + + is_private_message: bool = False + """是否为私聊消息""" + + stream_id: Optional[str] = None + """流ID,用于标识消息流""" + + llm_prompt: Optional[str] = None + """LLM提示词""" + + llm_response: Optional[str] = None + """LLM响应内容""" + + additional_data: Dict[Any, Any] = field(default_factory=dict) + """附加数据,可以存储额外信息""" + + def __post_init__(self): + if self.message_segments is None: + self.message_segments = [] diff --git a/src/plugin_system/core/events_manager.py b/src/plugin_system/core/events_manager.py index 1b96da44c..5143d765c 100644 --- a/src/plugin_system/core/events_manager.py +++ b/src/plugin_system/core/events_manager.py @@ -1,11 +1,154 @@ -from typing import List, Dict, Type +import asyncio +from typing import List, Dict, Optional, Type -from src.plugin_system.base.component_types import EventType +from src.chat.message_receive.message import MessageRecv +from src.common.logger import get_logger +from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, PluginInfo +from src.plugin_system.base.base_events_handler import BaseEventHandler + +logger = get_logger("events_manager") class EventsManager: def __init__(self): # 有权重的 events 订阅者注册表 - self.events_subscribers: Dict[EventType, List[Dict[int, Type]]] = {event: [] for event in EventType} + self.events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType} + self.handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表 + self._plugins: Dict[str, PluginInfo] = {} # 插件注册表 -events_manager = EventsManager() \ No newline at end of file + def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool: + """注册事件处理器 + + Args: + handler_info (EventHandlerInfo): 事件处理器信息 + handler_class (Type[BaseEventHandler]): 事件处理器类 + + Returns: + bool: 是否注册成功 + """ + handler_name = handler_info.name + plugin_name = getattr(handler_info, "plugin_name", "unknown") + + namespace_name = f"{plugin_name}.{handler_name}" + if namespace_name in self.handler_mapping: + logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册") + return False + + if not issubclass(handler_class, BaseEventHandler): + logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类") + return False + + self.handler_mapping[namespace_name] = handler_class + + return self._insert_event_handler(handler_class) + + def register_plugins(self, plugin_info: PluginInfo) -> bool: + """注册插件 + + Args: + plugin_info (PluginInfo): 插件信息 + + Returns: + bool: 是否注册成功 + """ + if plugin_info.name in self._plugins: + logger.warning(f"插件 {plugin_info.name} 已存在,跳过注册") + return False + + self._plugins[plugin_info.name] = plugin_info + logger.debug(f"插件 {plugin_info.name} 注册成功") + return True + + async def handler_mai_events( + self, + event_type: EventType, + message: MessageRecv, + llm_prompt: Optional[str] = None, + llm_response: Optional[str] = None, + ) -> None: + """处理 events""" + transformed_message = self._transform_event_message(message, llm_prompt, llm_response) + for handler in self.events_subscribers.get(event_type, []): + if handler.intercept_message: + await handler.execute(transformed_message) + else: + asyncio.create_task(handler.execute(transformed_message)) + + def _insert_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool: + """插入事件处理器到对应的事件类型列表中""" + if handler_class.event_type == EventType.UNKNOWN: + logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册") + return False + + self.events_subscribers[handler_class.event_type].append(handler_class()) + self.events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True) + + return True + + def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool: + """从事件类型列表中移除事件处理器""" + if handler_class.event_type == EventType.UNKNOWN: + logger.warning(f"事件处理器 {handler_class.__name__} 的事件类型未知,不存在于处理器列表中") + return False + + handlers = self.events_subscribers[handler_class.event_type] + for i, handler in enumerate(handlers): + if isinstance(handler, handler_class): + del handlers[i] + logger.debug(f"事件处理器 {handler_class.__name__} 已移除") + return True + + logger.warning(f"未找到事件处理器 {handler_class.__name__},无法移除") + return False + + def _transform_event_message( + self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[str] = None + ) -> MaiMessages: + """转换事件消息格式""" + # 直接赋值部分内容 + transformed_message = MaiMessages( + llm_prompt=llm_prompt, + llm_response=llm_response, + raw_message=message.raw_message, + additional_data=message.message_info.additional_config or {}, + ) + + # 消息段处理 + if message.message_segment.type == "seglist": + transformed_message.message_segments = list(message.message_segment.data) # type: ignore + else: + transformed_message.message_segments = [message.message_segment] + + # stream_id 处理 + if hasattr(message, "chat_stream"): + transformed_message.stream_id = message.chat_stream.stream_id + + # 处理后文本 + transformed_message.plain_text = message.processed_plain_text + + # 基本信息 + if message.message_info.platform: + transformed_message.message_base_info["platform"] = message.message_info.platform + if message.message_info.group_info: + transformed_message.is_group_message = True + transformed_message.message_base_info.update( + { + "group_id": message.message_info.group_info.group_id, + "group_name": message.message_info.group_info.group_name, + } + ) + if message.message_info.user_info: + if not transformed_message.is_group_message: + transformed_message.is_private_message = True + transformed_message.message_base_info.update( + { + "user_id": message.message_info.user_info.user_id, + "user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称 + "user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名) + } + ) + + return transformed_message + + +events_manager = EventsManager() diff --git a/src/tools/not_using/lpmm_get_knowledge.py b/src/tools/not_using/lpmm_get_knowledge.py index 180c5e699..467db6ed1 100644 --- a/src/tools/not_using/lpmm_get_knowledge.py +++ b/src/tools/not_using/lpmm_get_knowledge.py @@ -33,7 +33,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool): Dict: 工具执行结果 """ try: - query = function_args.get("query") + query: str = function_args.get("query") # type: ignore # threshold = function_args.get("threshold", 0.4) # 检查LPMM知识库是否启用