增加样例插件,修复统计数据(部分),修复一个bug
This commit is contained in:
@@ -7,7 +7,11 @@ from src.plugin_system import (
|
|||||||
ComponentInfo,
|
ComponentInfo,
|
||||||
ActionActivationType,
|
ActionActivationType,
|
||||||
ConfigField,
|
ConfigField,
|
||||||
|
BaseEventPlugin,
|
||||||
|
BaseEventHandler,
|
||||||
|
EventType,
|
||||||
)
|
)
|
||||||
|
from src.plugin_system.base.component_types import MaiMessages
|
||||||
|
|
||||||
# ===== Action组件 =====
|
# ===== Action组件 =====
|
||||||
|
|
||||||
@@ -93,6 +97,20 @@ class TimeCommand(BaseCommand):
|
|||||||
return True, f"显示了当前时间: {time_str}"
|
return True, f"显示了当前时间: {time_str}"
|
||||||
|
|
||||||
|
|
||||||
|
class PrintMessage(BaseEventHandler):
|
||||||
|
"""打印消息事件处理器 - 处理打印消息事件"""
|
||||||
|
|
||||||
|
event_type = EventType.ON_MESSAGE
|
||||||
|
handler_name = "print_message_handler"
|
||||||
|
handler_description = "打印接收到的消息"
|
||||||
|
|
||||||
|
async def execute(self, message: MaiMessages) -> Tuple[bool, str | None]:
|
||||||
|
"""执行打印消息事件处理"""
|
||||||
|
# 打印接收到的消息
|
||||||
|
print(f"接收到消息: {message.raw_message}")
|
||||||
|
return True, "消息已打印"
|
||||||
|
|
||||||
|
|
||||||
# ===== 插件注册 =====
|
# ===== 插件注册 =====
|
||||||
|
|
||||||
|
|
||||||
@@ -130,3 +148,25 @@ class HelloWorldPlugin(BasePlugin):
|
|||||||
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
|
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
|
||||||
(TimeCommand.get_command_info(), TimeCommand),
|
(TimeCommand.get_command_info(), TimeCommand),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@register_plugin
|
||||||
|
class HelloWorldEventPlugin(BaseEventPlugin):
|
||||||
|
"""Hello World事件插件 - 处理问候和告别事件"""
|
||||||
|
|
||||||
|
plugin_name = "hello_world_event_plugin"
|
||||||
|
enable_plugin = False
|
||||||
|
dependencies = []
|
||||||
|
python_dependencies = []
|
||||||
|
config_file_name = "event_config.toml"
|
||||||
|
|
||||||
|
config_schema = {
|
||||||
|
"plugin": {
|
||||||
|
"name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"),
|
||||||
|
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||||
|
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||||
|
return [(PrintMessage.get_handler_info(), PrintMessage)]
|
||||||
|
|||||||
@@ -179,8 +179,7 @@ class HeartFChatting:
|
|||||||
await asyncio.sleep(10)
|
await asyncio.sleep(10)
|
||||||
if self.loop_mode == ChatMode.NORMAL:
|
if self.loop_mode == ChatMode.NORMAL:
|
||||||
self.energy_value -= 0.3
|
self.energy_value -= 0.3
|
||||||
if self.energy_value <= 0.3:
|
self.energy_value = max(self.energy_value, 0.3)
|
||||||
self.energy_value = 0.3
|
|
||||||
|
|
||||||
def print_cycle_info(self, cycle_timers):
|
def print_cycle_info(self, cycle_timers):
|
||||||
# 记录循环信息和计时器结果
|
# 记录循环信息和计时器结果
|
||||||
@@ -257,6 +256,7 @@ class HeartFChatting:
|
|||||||
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
||||||
|
|
||||||
async def _observe(self, message_data: Optional[Dict[str, Any]] = None):
|
async def _observe(self, message_data: Optional[Dict[str, Any]] = None):
|
||||||
|
# sourcery skip: hoist-statement-from-if, merge-comparisons, reintroduce-else
|
||||||
if not message_data:
|
if not message_data:
|
||||||
message_data = {}
|
message_data = {}
|
||||||
action_type = "no_action"
|
action_type = "no_action"
|
||||||
|
|||||||
@@ -629,7 +629,7 @@ class LLMRequest:
|
|||||||
)
|
)
|
||||||
# 安全地检查和记录请求详情
|
# 安全地检查和记录请求详情
|
||||||
handled_payload = await _safely_record(request_content, payload)
|
handled_payload = await _safely_record(request_content, payload)
|
||||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload[:100]}")
|
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}")
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
|
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
|
||||||
)
|
)
|
||||||
@@ -643,7 +643,7 @@ class LLMRequest:
|
|||||||
logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
|
logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
|
||||||
# 安全地检查和记录请求详情
|
# 安全地检查和记录请求详情
|
||||||
handled_payload = await _safely_record(request_content, payload)
|
handled_payload = await _safely_record(request_content, payload)
|
||||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload[:100]}")
|
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}")
|
||||||
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
|
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
|
||||||
|
|
||||||
async def _transform_parameters(self, params: dict) -> dict:
|
async def _transform_parameters(self, params: dict) -> dict:
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ logger = get_logger("plugin_register")
|
|||||||
def register_plugin(cls):
|
def register_plugin(cls):
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
from src.plugin_system.base.base_plugin import BasePlugin
|
from src.plugin_system.base.base_plugin import BasePlugin
|
||||||
|
from src.plugin_system.base.base_event_plugin import BaseEventPlugin
|
||||||
|
|
||||||
"""插件注册装饰器
|
"""插件注册装饰器
|
||||||
|
|
||||||
@@ -18,7 +19,7 @@ def register_plugin(cls):
|
|||||||
plugin_description = "我的插件"
|
plugin_description = "我的插件"
|
||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
if not issubclass(cls, BasePlugin):
|
if not issubclass(cls, BasePlugin) and not issubclass(cls, BaseEventPlugin):
|
||||||
logger.error(f"类 {cls.__name__} 不是 BasePlugin 的子类")
|
logger.error(f"类 {cls.__name__} 不是 BasePlugin 的子类")
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from .component_types import MaiMessages, EventType
|
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType
|
||||||
|
|
||||||
logger = get_logger("base_event_handler")
|
logger = get_logger("base_event_handler")
|
||||||
|
|
||||||
@@ -14,7 +14,7 @@ class BaseEventHandler(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知
|
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知
|
||||||
handler_name: str = ""
|
handler_name: str = "" # 处理器名称
|
||||||
handler_description: str = ""
|
handler_description: str = ""
|
||||||
weight: int = 0 # 权重,数值越大优先级越高
|
weight: int = 0 # 权重,数值越大优先级越高
|
||||||
intercept_message: bool = False # 是否拦截消息,默认为否
|
intercept_message: bool = False # 是否拦截消息,默认为否
|
||||||
@@ -32,3 +32,17 @@ class BaseEventHandler(ABC):
|
|||||||
Tuple[bool, Optional[str]]: (是否执行成功, 可选的返回消息)
|
Tuple[bool, Optional[str]]: (是否执行成功, 可选的返回消息)
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("子类必须实现 execute 方法")
|
raise NotImplementedError("子类必须实现 execute 方法")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_handler_info(cls) -> "EventHandlerInfo":
|
||||||
|
"""获取事件处理器的信息"""
|
||||||
|
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
||||||
|
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
|
||||||
|
return EventHandlerInfo(
|
||||||
|
name=name,
|
||||||
|
component_type=ComponentType.LISTENER,
|
||||||
|
description=getattr(cls, "handler_description", "events处理器"),
|
||||||
|
event_type=cls.event_type,
|
||||||
|
weight=cls.weight,
|
||||||
|
intercept_message=cls.intercept_message,
|
||||||
|
)
|
||||||
|
|||||||
@@ -512,6 +512,12 @@ class PluginManager:
|
|||||||
config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌"
|
config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌"
|
||||||
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
|
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
|
||||||
|
|
||||||
|
root_path = Path(__file__)
|
||||||
|
|
||||||
|
# 查找项目根目录
|
||||||
|
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
|
||||||
|
root_path = root_path.parent
|
||||||
|
|
||||||
# 显示目录统计
|
# 显示目录统计
|
||||||
logger.info("📂 加载目录统计:")
|
logger.info("📂 加载目录统计:")
|
||||||
for directory in self.plugin_directories:
|
for directory in self.plugin_directories:
|
||||||
@@ -519,7 +525,11 @@ class PluginManager:
|
|||||||
plugins_in_dir = []
|
plugins_in_dir = []
|
||||||
for plugin_name in self.loaded_plugins.keys():
|
for plugin_name in self.loaded_plugins.keys():
|
||||||
plugin_path = self.plugin_paths.get(plugin_name, "")
|
plugin_path = self.plugin_paths.get(plugin_name, "")
|
||||||
if plugin_path.startswith(directory):
|
if (
|
||||||
|
Path(plugin_path)
|
||||||
|
.resolve()
|
||||||
|
.is_relative_to(Path(os.path.join(str(root_path), directory)).resolve())
|
||||||
|
):
|
||||||
plugins_in_dir.append(plugin_name)
|
plugins_in_dir.append(plugin_name)
|
||||||
|
|
||||||
if plugins_in_dir:
|
if plugins_in_dir:
|
||||||
|
|||||||
Reference in New Issue
Block a user