events manager and some typing fix

This commit is contained in:
UnCLAS-Prommer
2025-07-18 14:50:15 +08:00
parent d2b5019c24
commit ffa88b5462
6 changed files with 318 additions and 33 deletions

View File

@@ -33,6 +33,7 @@ RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
DATA_PATH = os.path.join(ROOT_PATH, "data") DATA_PATH = os.path.join(ROOT_PATH, "data")
def _initialize_knowledge_local_storage(): def _initialize_knowledge_local_storage():
""" """
初始化知识库相关的本地存储配置 初始化知识库相关的本地存储配置
@@ -41,55 +42,58 @@ def _initialize_knowledge_local_storage():
# 定义所有需要初始化的配置项 # 定义所有需要初始化的配置项
default_configs = { default_configs = {
# 路径配置 # 路径配置
'root_path': ROOT_PATH, "root_path": ROOT_PATH,
'data_path': f"{ROOT_PATH}/data", "data_path": f"{ROOT_PATH}/data",
# 实体和命名空间配置 # 实体和命名空间配置
'lpmm_invalid_entity': INVALID_ENTITY, "lpmm_invalid_entity": INVALID_ENTITY,
'pg_namespace': PG_NAMESPACE, "pg_namespace": PG_NAMESPACE,
'ent_namespace': ENT_NAMESPACE, "ent_namespace": ENT_NAMESPACE,
'rel_namespace': REL_NAMESPACE, "rel_namespace": REL_NAMESPACE,
# RAG相关命名空间配置 # RAG相关命名空间配置
'rag_graph_namespace': RAG_GRAPH_NAMESPACE, "rag_graph_namespace": RAG_GRAPH_NAMESPACE,
'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE, "rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE,
'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE "rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE,
} }
# 日志级别映射重要配置用info其他用debug # 日志级别映射重要配置用info其他用debug
important_configs = {'root_path', 'data_path'} important_configs = {"root_path", "data_path"}
# 批量设置配置项 # 批量设置配置项
initialized_count = 0 initialized_count = 0
for key, default_value in default_configs.items(): for key, default_value in default_configs.items():
if local_storage[key] is None: if local_storage[key] is None:
local_storage[key] = default_value local_storage[key] = default_value
# 根据重要性选择日志级别 # 根据重要性选择日志级别
if key in important_configs: if key in important_configs:
logger.info(f"设置{key}: {default_value}") logger.info(f"设置{key}: {default_value}")
else: else:
logger.debug(f"设置{key}: {default_value}") logger.debug(f"设置{key}: {default_value}")
initialized_count += 1 initialized_count += 1
if initialized_count > 0: if initialized_count > 0:
logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置") logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
else: else:
logger.debug("知识库本地存储配置已存在,跳过初始化") logger.debug("知识库本地存储配置已存在,跳过初始化")
# 初始化本地存储路径 # 初始化本地存储路径
# sourcery skip: dict-comprehension
_initialize_knowledge_local_storage() _initialize_knowledge_local_storage()
qa_manager = None
inspire_manager = None
# 检查LPMM知识库是否启用 # 检查LPMM知识库是否启用
if bot_global_config.lpmm_knowledge.enable: if bot_global_config.lpmm_knowledge.enable:
logger.info("正在初始化Mai-LPMM") logger.info("正在初始化Mai-LPMM")
logger.info("创建LLM客户端") logger.info("创建LLM客户端")
llm_client_list = dict() llm_client_list = {}
for key in global_config["llm_providers"]: for key in global_config["llm_providers"]:
llm_client_list[key] = LLMClient( llm_client_list[key] = LLMClient(
global_config["llm_providers"][key]["base_url"], global_config["llm_providers"][key]["base_url"], # type: ignore
global_config["llm_providers"][key]["api_key"], global_config["llm_providers"][key]["api_key"], # type: ignore
) )
# 初始化Embedding库 # 初始化Embedding库
@@ -98,7 +102,7 @@ if bot_global_config.lpmm_knowledge.enable:
try: try:
embed_manager.load_from_file() embed_manager.load_from_file()
except Exception as e: except Exception as e:
logger.warning("此消息不会影响正常使用从文件加载Embedding库时{}".format(e)) logger.warning(f"此消息不会影响正常使用从文件加载Embedding库时{e}")
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("Embedding库加载完成") logger.info("Embedding库加载完成")
# 初始化KG # 初始化KG
@@ -107,7 +111,7 @@ if bot_global_config.lpmm_knowledge.enable:
try: try:
kg_manager.load_from_file() kg_manager.load_from_file()
except Exception as e: except Exception as e:
logger.warning("此消息不会影响正常使用从文件加载KG时{}".format(e)) logger.warning(f"此消息不会影响正常使用从文件加载KG时{e}")
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("KG加载完成") logger.info("KG加载完成")
@@ -116,7 +120,7 @@ if bot_global_config.lpmm_knowledge.enable:
# 数据比对Embedding库与KG的段落hash集合 # 数据比对Embedding库与KG的段落hash集合
for pg_hash in kg_manager.stored_paragraph_hashes: for pg_hash in kg_manager.stored_paragraph_hashes:
key = PG_NAMESPACE + "-" + pg_hash key = f"{PG_NAMESPACE}-{pg_hash}"
if key not in embed_manager.stored_pg_hashes: if key not in embed_manager.stored_pg_hashes:
logger.warning(f"KG中存在Embedding库中不存在的段落{key}") logger.warning(f"KG中存在Embedding库中不存在的段落{key}")
@@ -134,5 +138,3 @@ if bot_global_config.lpmm_knowledge.enable:
else: else:
logger.info("LPMM知识库已禁用跳过初始化") logger.info("LPMM知识库已禁用跳过初始化")
# 创建空的占位符对象,避免导入错误 # 创建空的占位符对象,避免导入错误
qa_manager = None
inspire_manager = None

View File

@@ -1,8 +1,14 @@
from abc import abstractmethod from abc import abstractmethod
from typing import List, Tuple, Type, TYPE_CHECKING
from .plugin_base import PluginBase
from src.common.logger import get_logger from src.common.logger import get_logger
from .plugin_base import PluginBase
from .component_types import EventHandlerInfo
if TYPE_CHECKING:
from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("base_event_plugin")
class BaseEventPlugin(PluginBase): class BaseEventPlugin(PluginBase):
"""基于事件的插件基类 """基于事件的插件基类
@@ -12,3 +18,42 @@ class BaseEventPlugin(PluginBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@abstractmethod
def get_plugin_components(self) -> List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]]:
"""获取插件包含的事件组件
子类必须实现此方法,返回事件组件
Returns:
List[Tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...]
"""
raise NotImplementedError("子类必须实现 get_plugin_components 方法")
def register_plugin(self) -> bool:
"""注册事件插件"""
from src.plugin_system.core.events_manager import events_manager
components = self.get_plugin_components()
# 检查依赖
if not self._check_dependencies():
logger.error(f"{self.log_prefix} 依赖检查失败,跳过注册")
return False
registered_components = []
for handler_info, handler_class in components:
handler_info.plugin_name = self.plugin_name
if events_manager.register_event_subscriber(handler_info, handler_class):
registered_components.append(handler_info)
else:
logger.error(f"{self.log_prefix} 事件处理器 {handler_info.name} 注册失败")
self.plugin_info.components = registered_components
if events_manager.register_plugins(self.plugin_info):
logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个事件处理器")
return True
else:
logger.error(f"{self.log_prefix} 插件注册失败")
return False

View File

@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
from typing import Tuple, Optional
from src.common.logger import get_logger
from .component_types import MaiMessages, EventType
logger = get_logger("base_event_handler")
class BaseEventHandler(ABC):
"""事件处理器基类
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
"""
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知
handler_name: str = ""
handler_description: str = ""
weight: int = 0 # 权重,数值越大优先级越高
intercept_message: bool = False # 是否拦截消息,默认为否
def __init__(self):
self.log_prefix = "[EventHandler]"
if self.event_type == EventType.UNKNOWN:
raise NotImplementedError("事件处理器必须指定 event_type")
@abstractmethod
async def execute(self, message: MaiMessages) -> Tuple[bool, Optional[str]]:
"""执行事件处理的抽象方法,子类必须实现
Returns:
Tuple[bool, Optional[str]]: (是否执行成功, 可选的返回消息)
"""
raise NotImplementedError("子类必须实现 execute 方法")

View File

@@ -1,6 +1,7 @@
from enum import Enum from enum import Enum
from typing import Dict, Any, List from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from maim_message import Seg
# 组件类型枚举 # 组件类型枚举
@@ -12,6 +13,9 @@ class ComponentType(Enum):
SCHEDULER = "scheduler" # 定时任务组件(预留) SCHEDULER = "scheduler" # 定时任务组件(预留)
LISTENER = "listener" # 事件监听组件(预留) LISTENER = "listener" # 事件监听组件(预留)
def __str__(self) -> str:
return self.value
# 动作激活类型枚举 # 动作激活类型枚举
class ActionActivationType(Enum): class ActionActivationType(Enum):
@@ -46,12 +50,17 @@ class EventType(Enum):
事件类型枚举类 事件类型枚举类
""" """
ON_START = "on_start" # 启动事件,用于调用按时任务
ON_MESSAGE = "on_message" ON_MESSAGE = "on_message"
ON_PLAN = "on_plan" ON_PLAN = "on_plan"
POST_LLM = "post_llm" POST_LLM = "post_llm"
AFTER_LLM = "after_llm" AFTER_LLM = "after_llm"
POST_SEND = "post_send" POST_SEND = "post_send"
AFTER_SEND = "after_send" AFTER_SEND = "after_send"
UNKNOWN = "unknown" # 未知事件类型
def __str__(self) -> str:
return self.value
@dataclass @dataclass
@@ -142,6 +151,19 @@ class CommandInfo(ComponentInfo):
self.component_type = ComponentType.COMMAND self.component_type = ComponentType.COMMAND
@dataclass
class EventHandlerInfo(ComponentInfo):
"""事件处理器组件信息"""
event_type: EventType = EventType.ON_MESSAGE # 监听事件类型
intercept_message: bool = False # 是否拦截消息处理(默认不拦截)
weight: int = 0 # 事件处理器权重,决定执行顺序
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.LISTENER
@dataclass @dataclass
class PluginInfo: class PluginInfo:
"""插件信息""" """插件信息"""
@@ -198,3 +220,42 @@ class PluginInfo:
def get_pip_requirements(self) -> List[str]: def get_pip_requirements(self) -> List[str]:
"""获取所有pip安装格式的依赖""" """获取所有pip安装格式的依赖"""
return [dep.get_pip_requirement() for dep in self.python_dependencies] return [dep.get_pip_requirement() for dep in self.python_dependencies]
@dataclass
class MaiMessages:
"""MaiM插件消息"""
message_segments: List[Seg] = field(default_factory=list)
"""消息段列表,支持多段消息"""
message_base_info: Dict[str, Any] = field(default_factory=dict)
"""消息基本信息,包含平台,用户信息等数据"""
plain_text: str = ""
"""纯文本消息内容"""
raw_message: Optional[str] = None
"""原始消息内容"""
is_group_message: bool = False
"""是否为群组消息"""
is_private_message: bool = False
"""是否为私聊消息"""
stream_id: Optional[str] = None
"""流ID用于标识消息流"""
llm_prompt: Optional[str] = None
"""LLM提示词"""
llm_response: Optional[str] = None
"""LLM响应内容"""
additional_data: Dict[Any, Any] = field(default_factory=dict)
"""附加数据,可以存储额外信息"""
def __post_init__(self):
if self.message_segments is None:
self.message_segments = []

View File

@@ -1,11 +1,154 @@
from typing import List, Dict, Type import asyncio
from typing import List, Dict, Optional, Type
from src.plugin_system.base.component_types import EventType from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, PluginInfo
from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("events_manager")
class EventsManager: class EventsManager:
def __init__(self): def __init__(self):
# 有权重的 events 订阅者注册表 # 有权重的 events 订阅者注册表
self.events_subscribers: Dict[EventType, List[Dict[int, Type]]] = {event: [] for event in EventType} self.events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
self.handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
self._plugins: Dict[str, PluginInfo] = {} # 插件注册表
events_manager = EventsManager() def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
"""注册事件处理器
Args:
handler_info (EventHandlerInfo): 事件处理器信息
handler_class (Type[BaseEventHandler]): 事件处理器类
Returns:
bool: 是否注册成功
"""
handler_name = handler_info.name
plugin_name = getattr(handler_info, "plugin_name", "unknown")
namespace_name = f"{plugin_name}.{handler_name}"
if namespace_name in self.handler_mapping:
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
return False
if not issubclass(handler_class, BaseEventHandler):
logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类")
return False
self.handler_mapping[namespace_name] = handler_class
return self._insert_event_handler(handler_class)
def register_plugins(self, plugin_info: PluginInfo) -> bool:
"""注册插件
Args:
plugin_info (PluginInfo): 插件信息
Returns:
bool: 是否注册成功
"""
if plugin_info.name in self._plugins:
logger.warning(f"插件 {plugin_info.name} 已存在,跳过注册")
return False
self._plugins[plugin_info.name] = plugin_info
logger.debug(f"插件 {plugin_info.name} 注册成功")
return True
async def handler_mai_events(
self,
event_type: EventType,
message: MessageRecv,
llm_prompt: Optional[str] = None,
llm_response: Optional[str] = None,
) -> None:
"""处理 events"""
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
for handler in self.events_subscribers.get(event_type, []):
if handler.intercept_message:
await handler.execute(transformed_message)
else:
asyncio.create_task(handler.execute(transformed_message))
def _insert_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
"""插入事件处理器到对应的事件类型列表中"""
if handler_class.event_type == EventType.UNKNOWN:
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
return False
self.events_subscribers[handler_class.event_type].append(handler_class())
self.events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
return True
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
"""从事件类型列表中移除事件处理器"""
if handler_class.event_type == EventType.UNKNOWN:
logger.warning(f"事件处理器 {handler_class.__name__} 的事件类型未知,不存在于处理器列表中")
return False
handlers = self.events_subscribers[handler_class.event_type]
for i, handler in enumerate(handlers):
if isinstance(handler, handler_class):
del handlers[i]
logger.debug(f"事件处理器 {handler_class.__name__} 已移除")
return True
logger.warning(f"未找到事件处理器 {handler_class.__name__},无法移除")
return False
def _transform_event_message(
self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[str] = None
) -> MaiMessages:
"""转换事件消息格式"""
# 直接赋值部分内容
transformed_message = MaiMessages(
llm_prompt=llm_prompt,
llm_response=llm_response,
raw_message=message.raw_message,
additional_data=message.message_info.additional_config or {},
)
# 消息段处理
if message.message_segment.type == "seglist":
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
else:
transformed_message.message_segments = [message.message_segment]
# stream_id 处理
if hasattr(message, "chat_stream"):
transformed_message.stream_id = message.chat_stream.stream_id
# 处理后文本
transformed_message.plain_text = message.processed_plain_text
# 基本信息
if message.message_info.platform:
transformed_message.message_base_info["platform"] = message.message_info.platform
if message.message_info.group_info:
transformed_message.is_group_message = True
transformed_message.message_base_info.update(
{
"group_id": message.message_info.group_info.group_id,
"group_name": message.message_info.group_info.group_name,
}
)
if message.message_info.user_info:
if not transformed_message.is_group_message:
transformed_message.is_private_message = True
transformed_message.message_base_info.update(
{
"user_id": message.message_info.user_info.user_id,
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
}
)
return transformed_message
events_manager = EventsManager()

View File

@@ -33,7 +33,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
Dict: 工具执行结果 Dict: 工具执行结果
""" """
try: try:
query = function_args.get("query") query: str = function_args.get("query") # type: ignore
# threshold = function_args.get("threshold", 0.4) # threshold = function_args.get("threshold", 0.4)
# 检查LPMM知识库是否启用 # 检查LPMM知识库是否启用