修复代码格式和文件名大小写问题

This commit is contained in:
Windpicker-owo
2025-08-31 20:50:17 +08:00
parent a187130613
commit fe472dff60
213 changed files with 6897 additions and 8252 deletions

View File

@@ -48,7 +48,7 @@ from .utils.dependency_config import get_dependency_config, configure_dependency
from .apis import (
chat_api,
tool_api,
tool_api,
component_manage_api,
config_api,
database_api,
@@ -91,8 +91,8 @@ __all__ = [
# 增强命令系统
"PlusCommand",
"CommandArgs",
"PlusCommandAdapter",
"create_plus_command_adapter",
"PlusCommandAdapter",
"create_plus_command_adapter",
"create_plus_command_adapter",
# 类型定义
"ComponentType",

View File

@@ -9,19 +9,7 @@
注意此模块现在使用SQLAlchemy实现提供更好的连接管理和错误处理
"""
from src.common.database.sqlalchemy_database_api import (
db_query,
db_save,
db_get,
store_action_info,
MODEL_MAPPING
)
from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get, store_action_info, MODEL_MAPPING
# 保持向后兼容性
__all__ = [
'db_query',
'db_save',
'db_get',
'store_action_info',
'MODEL_MAPPING'
]
__all__ = ["db_query", "db_save", "db_get", "store_action_info", "MODEL_MAPPING"]

View File

@@ -72,7 +72,9 @@ async def generate_with_model(
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens)
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(
prompt, temperature=temperature, max_tokens=max_tokens
)
return True, response, reasoning_content, model_name
except Exception as e:
@@ -80,6 +82,7 @@ async def generate_with_model(
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", ""
async def generate_with_model_with_tools(
prompt: str,
model_config: TaskConfig,
@@ -109,10 +112,7 @@ async def generate_with_model_with_tools(
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
prompt,
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens
prompt, tools=tool_options, temperature=temperature, max_tokens=max_tokens
)
return True, response, reasoning_content, model_name, tool_call

View File

@@ -97,7 +97,9 @@ def get_messages_by_time_in_chat(
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command))
return filter_mai_messages(
get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
)
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
@@ -137,9 +139,13 @@ def get_messages_by_time_in_chat_inclusive(
raise ValueError("chat_id 必须是字符串类型")
if filter_mai:
return filter_mai_messages(
get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command)
get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id, start_time, end_time, limit, limit_mode, filter_command
)
)
return get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command)
return get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id, start_time, end_time, limit, limit_mode, filter_command
)
def get_messages_by_time_in_chat_for_users(

View File

@@ -17,12 +17,14 @@ logger = get_logger(__name__)
class PermissionLevel(Enum):
"""权限等级枚举"""
MASTER = "master" # 最高权限,无视所有权限节点
@dataclass
class PermissionNode:
"""权限节点数据类"""
node_name: str # 权限节点名称,如 "plugin.example.command.test"
description: str # 权限节点描述
plugin_name: str # 所属插件名称
@@ -32,13 +34,14 @@ class PermissionNode:
@dataclass
class UserInfo:
"""用户信息数据类"""
platform: str # 平台类型,如 "qq"
user_id: str # 用户ID
def __post_init__(self):
"""确保user_id是字符串类型"""
self.user_id = str(self.user_id)
def to_tuple(self) -> tuple[str, str]:
"""转换为元组格式"""
return (self.platform, self.user_id)
@@ -46,106 +49,106 @@ class UserInfo:
class IPermissionManager(ABC):
"""权限管理器接口"""
@abstractmethod
def check_permission(self, user: UserInfo, permission_node: str) -> bool:
"""
检查用户是否拥有指定权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 是否拥有权限
"""
pass
@abstractmethod
def is_master(self, user: UserInfo) -> bool:
"""
检查用户是否为Master用户
Args:
user: 用户信息
Returns:
bool: 是否为Master用户
"""
pass
@abstractmethod
def register_permission_node(self, node: PermissionNode) -> bool:
"""
注册权限节点
Args:
node: 权限节点
Returns:
bool: 注册是否成功
"""
pass
@abstractmethod
def grant_permission(self, user: UserInfo, permission_node: str) -> bool:
"""
授权用户权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 授权是否成功
"""
pass
@abstractmethod
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool:
"""
撤销用户权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 撤销是否成功
"""
pass
@abstractmethod
def get_user_permissions(self, user: UserInfo) -> List[str]:
"""
获取用户拥有的所有权限节点
Args:
user: 用户信息
Returns:
List[str]: 权限节点列表
"""
pass
@abstractmethod
def get_all_permission_nodes(self) -> List[PermissionNode]:
"""
获取所有已注册的权限节点
Returns:
List[PermissionNode]: 权限节点列表
"""
pass
@abstractmethod
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
"""
获取指定插件的所有权限节点
Args:
plugin_name: 插件名称
Returns:
List[PermissionNode]: 权限节点列表
"""
@@ -154,146 +157,144 @@ class IPermissionManager(ABC):
class PermissionAPI:
"""权限系统API类"""
def __init__(self):
self._permission_manager: Optional[IPermissionManager] = None
def set_permission_manager(self, manager: IPermissionManager):
"""设置权限管理器实例"""
self._permission_manager = manager
logger.info("权限管理器已设置")
def _ensure_manager(self):
"""确保权限管理器已设置"""
if self._permission_manager is None:
raise RuntimeError("权限管理器未设置,请先调用 set_permission_manager")
def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
"""
检查用户是否拥有指定权限节点
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
permission_node: 权限节点名称
Returns:
bool: 是否拥有权限
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id))
return self._permission_manager.check_permission(user, permission_node)
def is_master(self, platform: str, user_id: str) -> bool:
"""
检查用户是否为Master用户
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
Returns:
bool: 是否为Master用户
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id))
return self._permission_manager.is_master(user)
def register_permission_node(self, node_name: str, description: str, plugin_name: str,
default_granted: bool = False) -> bool:
def register_permission_node(
self, node_name: str, description: str, plugin_name: str, default_granted: bool = False
) -> bool:
"""
注册权限节点
Args:
node_name: 权限节点名称,如 "plugin.example.command.test"
description: 权限节点描述
plugin_name: 所属插件名称
default_granted: 默认是否授权
Returns:
bool: 注册是否成功
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager()
node = PermissionNode(
node_name=node_name,
description=description,
plugin_name=plugin_name,
default_granted=default_granted
node_name=node_name, description=description, plugin_name=plugin_name, default_granted=default_granted
)
return self._permission_manager.register_permission_node(node)
def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
"""
授权用户权限节点
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
permission_node: 权限节点名称
Returns:
bool: 授权是否成功
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id))
return self._permission_manager.grant_permission(user, permission_node)
def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
"""
撤销用户权限节点
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
permission_node: 权限节点名称
Returns:
bool: 撤销是否成功
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id))
return self._permission_manager.revoke_permission(user, permission_node)
def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
"""
获取用户拥有的所有权限节点
Args:
platform: 平台类型,如 "qq"
user_id: 用户ID
Returns:
List[str]: 权限节点列表
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
self._ensure_manager()
user = UserInfo(platform=platform, user_id=str(user_id))
return self._permission_manager.get_user_permissions(user)
def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
"""
获取所有已注册的权限节点
Returns:
List[Dict[str, Any]]: 权限节点列表,每个节点包含 node_name, description, plugin_name, default_granted
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
@@ -304,21 +305,21 @@ class PermissionAPI:
"node_name": node.node_name,
"description": node.description,
"plugin_name": node.plugin_name,
"default_granted": node.default_granted
"default_granted": node.default_granted,
}
for node in nodes
]
def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
"""
获取指定插件的所有权限节点
Args:
plugin_name: 插件名称
Returns:
List[Dict[str, Any]]: 权限节点列表
Raises:
RuntimeError: 权限管理器未设置时抛出
"""
@@ -329,7 +330,7 @@ class PermissionAPI:
"node_name": node.node_name,
"description": node.description,
"plugin_name": node.plugin_name,
"default_granted": node.default_granted
"default_granted": node.default_granted,
}
for node in nodes
]

View File

@@ -34,7 +34,7 @@ def get_plugin_path(plugin_name: str) -> str:
Returns:
str: 插件目录的绝对路径。
Raises:
ValueError: 如果插件不存在。
"""

View File

@@ -2,7 +2,7 @@ from pathlib import Path
from src.common.logger import get_logger
logger = get_logger("plugin_manager") # 复用plugin_manager名称
logger = get_logger("plugin_manager") # 复用plugin_manager名称
def register_plugin(cls):

View File

@@ -64,7 +64,7 @@ async def wait_adapter_response(request_id: str, timeout: float = 30.0) -> dict:
"""等待适配器响应"""
future = asyncio.Future()
_adapter_response_pool[request_id] = future
try:
response = await asyncio.wait_for(future, timeout=timeout)
return response
@@ -369,10 +369,10 @@ async def adapter_command_to_stream(
platform: Optional[str] = "qq",
stream_id: Optional[str] = None,
timeout: float = 30.0,
storage_message: bool = False
storage_message: bool = False,
) -> dict:
"""向适配器发送命令并获取返回值
雅诺狐的耳朵特别软
Args:
@@ -388,20 +388,20 @@ async def adapter_command_to_stream(
- 成功: {"status": "ok", "data": {...}, "message": "..."}
- 失败: {"status": "failed", "message": "错误信息"}
- 错误: {"status": "error", "message": "错误信息"}
Raises:
ValueError: 当stream_id和platform都未提供时抛出
"""
if not stream_id and not platform:
raise ValueError("必须提供stream_id或platform参数")
try:
try:
logger.debug(f"[SendAPI] 向适配器发送命令: {action}")
# 如果没有提供stream_id则生成一个临时的
if stream_id is None:
import uuid
stream_id = f"adapter_temp_{uuid.uuid4().hex[:8]}"
logger.debug(f"[SendAPI] 自动生成临时stream_id: {stream_id}")
@@ -411,22 +411,15 @@ async def adapter_command_to_stream(
# 如果是自动生成的stream_id且找不到聊天流创建一个临时的虚拟流
if stream_id.startswith("adapter_temp_"):
logger.debug(f"[SendAPI] 创建临时虚拟聊天流: {stream_id}")
# 创建临时的用户信息和聊天流
temp_user_info = UserInfo(
user_id="system",
user_nickname="System",
platform=platform
)
temp_user_info = UserInfo(user_id="system", user_nickname="System", platform=platform)
temp_chat_stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=temp_user_info,
group_info=None
stream_id=stream_id, platform=platform, user_info=temp_user_info, group_info=None
)
target_stream = temp_chat_stream
else:
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}")
@@ -474,10 +467,7 @@ async def adapter_command_to_stream(
# 发送消息
sent_msg = await heart_fc_sender.send_message(
bot_message,
typing=False,
set_reply=False,
storage_message=storage_message
bot_message, typing=False, set_reply=False, storage_message=storage_message
)
if not sent_msg:
@@ -488,9 +478,9 @@ async def adapter_command_to_stream(
# 等待适配器响应
response = await wait_adapter_response(message_id, timeout)
logger.debug(f"[SendAPI] 收到适配器响应: {response}")
return response
except Exception as e:

View File

@@ -31,4 +31,4 @@ def get_llm_available_tool_definitions():
from src.plugin_system.core import component_registry
llm_available_tools = component_registry.get_llm_available_tools()
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]

View File

@@ -147,7 +147,7 @@ class BaseAction(ABC):
logger.debug(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
)
# 验证聊天类型限制
if not self._validate_chat_type():
logger.warning(
@@ -157,7 +157,7 @@ class BaseAction(ABC):
def _validate_chat_type(self) -> bool:
"""验证当前聊天类型是否允许执行此Action
Returns:
bool: 如果允许执行返回True否则返回False
"""
@@ -172,9 +172,9 @@ class BaseAction(ABC):
def is_chat_type_allowed(self) -> bool:
"""检查当前聊天类型是否允许执行此Action
这是一个公开的方法,供外部调用检查聊天类型限制
Returns:
bool: 如果允许执行返回True否则返回False
"""
@@ -240,9 +240,7 @@ class BaseAction(ABC):
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
return False, f"等待新消息失败: {str(e)}"
async def send_text(
self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None, typing: bool = False
) -> bool:
async def send_text(self, content: str, reply_to: str = "", typing: bool = False) -> bool:
"""发送文本消息
Args:

View File

@@ -46,10 +46,10 @@ class BaseCommand(ABC):
self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
logger.debug(f"{self.log_prefix} Command组件初始化完成")
# 验证聊天类型限制
if not self._validate_chat_type():
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
logger.warning(
f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: "
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
@@ -65,16 +65,16 @@ class BaseCommand(ABC):
def _validate_chat_type(self) -> bool:
"""验证当前聊天类型是否允许执行此Command
Returns:
bool: 如果允许执行返回True否则返回False
"""
if self.chat_type_allow == ChatType.ALL:
return True
# 检查是否为群聊消息
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
if self.chat_type_allow == ChatType.GROUP and is_group:
return True
elif self.chat_type_allow == ChatType.PRIVATE and not is_group:
@@ -84,9 +84,9 @@ class BaseCommand(ABC):
def is_chat_type_allowed(self) -> bool:
"""检查当前聊天类型是否允许执行此Command
这是一个公开的方法,供外部调用检查聊天类型限制
Returns:
bool: 如果允许执行返回True否则返回False
"""

View File

@@ -3,12 +3,14 @@ from typing import List, Dict, Any, Optional
from src.common.logger import get_logger
logger = get_logger("base_event")
class HandlerResult:
"""事件处理器执行结果
所有事件处理器必须返回此类的实例
"""
def __init__(self, success: bool, continue_process: bool, message: Any = None, handler_name: str = ""):
self.success = success
self.continue_process = continue_process
@@ -18,31 +20,32 @@ class HandlerResult:
def __repr__(self):
return f"HandlerResult(success={self.success}, continue_process={self.continue_process}, message='{self.message}', handler_name='{self.handler_name}')"
class HandlerResultsCollection:
"""HandlerResult集合提供便捷的查询方法"""
def __init__(self, results: List[HandlerResult]):
self.results = results
def all_continue_process(self) -> bool:
"""检查是否所有handler的continue_process都为True"""
return all(result.continue_process for result in self.results)
def get_all_results(self) -> List[HandlerResult]:
"""获取所有HandlerResult"""
return self.results
def get_failed_handlers(self) -> List[HandlerResult]:
"""获取执行失败的handler结果"""
return [result for result in self.results if not result.success]
def get_stopped_handlers(self) -> List[HandlerResult]:
"""获取continue_process为False的handler结果"""
return [result for result in self.results if not result.continue_process]
def get_message_result(self) -> Any:
"""获取handler的message
当只有一个handler的结果时直接返回那个handler结果中的message字段
否则用字典的形式{handler_name:message}返回
"""
@@ -52,22 +55,22 @@ class HandlerResultsCollection:
return self.results[0].message
else:
return {result.handler_name: result.message for result in self.results}
def get_handler_result(self, handler_name: str) -> Optional[HandlerResult]:
"""获取指定handler的结果"""
for result in self.results:
if result.handler_name == handler_name:
return result
return None
def get_success_count(self) -> int:
"""获取成功执行的handler数量"""
return sum(1 for result in self.results if result.success)
def get_failure_count(self) -> int:
"""获取执行失败的handler数量"""
return sum(1 for result in self.results if not result.success)
def get_summary(self) -> Dict[str, Any]:
"""获取执行摘要"""
return {
@@ -76,62 +79,63 @@ class HandlerResultsCollection:
"failure_count": self.get_failure_count(),
"continue_process": self.all_continue_process(),
"failed_handlers": [r.handler_name for r in self.get_failed_handlers()],
"stopped_handlers": [r.handler_name for r in self.get_stopped_handlers()]
"stopped_handlers": [r.handler_name for r in self.get_stopped_handlers()],
}
class BaseEvent:
def __init__(
self,
name: str,
allowed_subscribers: List[str] = None,
allowed_triggers: List[str] = None
):
def __init__(self, name: str, allowed_subscribers: List[str] = None, allowed_triggers: List[str] = None):
self.name = name
self.enabled = True
self.allowed_subscribers = allowed_subscribers # 记录事件处理器名
self.allowed_triggers = allowed_triggers # 记录插件名
from src.plugin_system.base.base_events_handler import BaseEventHandler
self.subscribers: List["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表
self.subscribers: List["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表
self.event_handle_lock = asyncio.Lock()
def __name__(self):
return self.name
async def activate(self, params: dict) -> HandlerResultsCollection:
"""激活事件,执行所有订阅的处理器
Args:
params: 传递给处理器的参数
Returns:
HandlerResultsCollection: 所有处理器的执行结果集合
"""
if not self.enabled:
return HandlerResultsCollection([])
# 使用锁确保同一个事件不能同时激活多次
async with self.event_handle_lock:
# 按权重从高到低排序订阅者
# 使用直接属性访问,-1代表自动权重
sorted_subscribers = sorted(self.subscribers, key=lambda h: h.weight if hasattr(h, 'weight') and h.weight != -1 else 0, reverse=True)
sorted_subscribers = sorted(
self.subscribers, key=lambda h: h.weight if hasattr(h, "weight") and h.weight != -1 else 0, reverse=True
)
# 并行执行所有订阅者
tasks = []
for subscriber in sorted_subscribers:
# 为每个订阅者创建执行任务
task = self._execute_subscriber(subscriber, params)
tasks.append(task)
# 等待所有任务完成
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理执行结果
processed_results = []
for i, result in enumerate(results):
subscriber = sorted_subscribers[i]
handler_name = subscriber.handler_name if hasattr(subscriber, 'handler_name') else subscriber.__class__.__name__
handler_name = (
subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
)
if result:
if isinstance(result, Exception):
# 处理执行异常
@@ -143,13 +147,13 @@ class BaseEvent:
# 补充handler_name
result.handler_name = handler_name
processed_results.append(result)
return HandlerResultsCollection(processed_results)
async def _execute_subscriber(self, subscriber, params: dict) -> HandlerResult:
"""执行单个订阅者处理器"""
try:
return await subscriber.execute(params)
except Exception as e:
# 异常会在 gather 中捕获,这里直接抛出让 gather 处理
raise e
raise e

View File

@@ -51,11 +51,11 @@ class BaseEventHandler(ABC):
event_name (str): 要订阅的事件名称
"""
from src.plugin_system.core.event_manager import event_manager
if not event_manager.subscribe_handler_to_event(self.handler_name, event_name):
logger.error(f"事件处理器 {self.handler_name} 订阅事件 {event_name} 失败")
return
logger.debug(f"{self.log_prefix} 订阅事件 {event_name}")
self.subscribed_events.append(event_name)
@@ -66,7 +66,7 @@ class BaseEventHandler(ABC):
event_name (str): 要取消订阅的事件名称
"""
from src.plugin_system.core.event_manager import event_manager
if event_manager.unsubscribe_handler_from_event(self.handler_name, event_name):
logger.debug(f"{self.log_prefix} 取消订阅事件 {event_name}")
if event_name in self.subscribed_events:

View File

@@ -9,32 +9,32 @@ import shlex
class CommandArgs:
"""命令参数解析类
提供方便的方法来处理命令参数
"""
def __init__(self, raw_args: str = ""):
"""初始化命令参数
Args:
raw_args: 原始参数字符串
"""
self._raw_args = raw_args.strip()
self._parsed_args: Optional[List[str]] = None
def get_raw(self) -> str:
"""获取完整的参数字符串
Returns:
str: 原始参数字符串
"""
return self._raw_args
def get_args(self) -> List[str]:
"""获取解析后的参数列表
将参数按空格分割,支持引号包围的参数
Returns:
List[str]: 参数列表
"""
@@ -48,25 +48,25 @@ class CommandArgs:
except ValueError:
# 如果shlex解析失败fallback到简单的split
self._parsed_args = self._raw_args.split()
return self._parsed_args
@property
def is_empty(self) -> bool:
"""检查参数是否为空
Returns:
bool: 如果没有参数返回True
"""
return len(self.get_args()) == 0
def get_arg(self, index: int, default: str = "") -> str:
"""获取指定索引的参数
Args:
index: 参数索引从0开始
default: 默认值
Returns:
str: 参数值或默认值
"""
@@ -78,21 +78,21 @@ class CommandArgs:
@property
def get_first(self, default: str = "") -> str:
"""获取第一个参数
Args:
default: 默认值
Returns:
str: 第一个参数或默认值
"""
return self.get_arg(0, default)
def get_remaining(self, start_index: int = 0) -> str:
"""获取从指定索引开始的剩余参数字符串
Args:
start_index: 起始索引
Returns:
str: 剩余参数组成的字符串
"""
@@ -100,45 +100,45 @@ class CommandArgs:
if start_index < len(args):
return " ".join(args[start_index:])
return ""
def count(self) -> int:
"""获取参数数量
Returns:
int: 参数数量
"""
return len(self.get_args())
def has_flag(self, flag: str) -> bool:
"""检查是否包含指定的标志参数
Args:
flag: 标志名(如 "--verbose""-v"
Returns:
bool: 如果包含该标志返回True
"""
return flag in self.get_args()
def get_flag_value(self, flag: str, default: str = "") -> str:
"""获取标志参数的值
查找 --key=value 或 --key value 形式的参数
Args:
flag: 标志名(如 "--output"
default: 默认值
Returns:
str: 标志的值或默认值
"""
args = self.get_args()
# 查找 --key=value 形式
for arg in args:
if arg.startswith(f"{flag}="):
return arg[len(flag) + 1:]
return arg[len(flag) + 1 :]
# 查找 --key value 形式
try:
flag_index = args.index(flag)
@@ -146,13 +146,13 @@ class CommandArgs:
return args[flag_index + 1]
except ValueError:
pass
return default
def __str__(self) -> str:
"""字符串表示"""
return self._raw_args
def __repr__(self) -> str:
"""调试表示"""
return f"CommandArgs(raw='{self._raw_args}', parsed={self.get_args()})"

View File

@@ -6,6 +6,7 @@ from maim_message import Seg
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
# 组件类型枚举
class ComponentType(Enum):
"""组件类型枚举"""
@@ -185,7 +186,9 @@ class PlusCommandInfo(ComponentInfo):
class ToolInfo(ComponentInfo):
"""工具组件信息"""
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(
default_factory=list
) # 工具参数定义
tool_description: str = "" # 工具描述
def __post_init__(self):
@@ -205,6 +208,7 @@ class EventHandlerInfo(ComponentInfo):
super().__post_init__()
self.component_type = ComponentType.EVENT_HANDLER
@dataclass
class EventInfo(ComponentInfo):
"""事件组件信息"""
@@ -213,6 +217,7 @@ class EventInfo(ComponentInfo):
super().__post_init__()
self.component_type = ComponentType.EVENT
# 事件类型枚举
class EventType(Enum):
"""
@@ -232,6 +237,7 @@ class EventType(Enum):
def __str__(self) -> str:
return self.value
@dataclass
class PluginInfo:
"""插件信息"""
@@ -320,16 +326,16 @@ class MaiMessages:
llm_response_content: Optional[str] = None
"""LLM响应内容"""
llm_response_reasoning: Optional[str] = None
"""LLM响应推理内容"""
llm_response_model: Optional[str] = None
"""LLM响应模型名称"""
llm_response_tool_call: Optional[List[ToolCall]] = None
"""LLM使用的工具调用"""
action_usage: Optional[List[str]] = None
"""使用的Action"""

View File

@@ -90,10 +90,10 @@ class PluginBase(ABC):
# 标准化Python依赖为PythonDependency对象
normalized_python_deps = self._normalize_python_dependencies(self.python_dependencies)
# 检查Python依赖
self._check_python_dependencies(normalized_python_deps)
# 创建插件信息对象
self.plugin_info = PluginInfo(
name=self.plugin_name,
@@ -560,7 +560,7 @@ class PluginBase(ABC):
def _normalize_python_dependencies(self, dependencies: Any) -> List[PythonDependency]:
"""将依赖列表标准化为PythonDependency对象"""
from packaging.requirements import Requirement
normalized = []
for dep in dependencies:
if isinstance(dep, str):
@@ -568,23 +568,22 @@ class PluginBase(ABC):
# 尝试解析为requirement格式 (如 "package>=1.0.0")
req = Requirement(dep)
version_spec = str(req.specifier) if req.specifier else ""
normalized.append(PythonDependency(
package_name=req.name,
version=version_spec,
install_name=dep # 保持原始的安装名称
))
normalized.append(
PythonDependency(
package_name=req.name,
version=version_spec,
install_name=dep, # 保持原始的安装名称
)
)
except Exception:
# 如果解析失败,作为简单包名处理
normalized.append(PythonDependency(
package_name=dep,
install_name=dep
))
normalized.append(PythonDependency(package_name=dep, install_name=dep))
elif isinstance(dep, PythonDependency):
normalized.append(dep)
else:
logger.warning(f"{self.log_prefix} 未知的依赖格式: {dep}")
return normalized
def _check_python_dependencies(self, dependencies: List[PythonDependency]) -> bool:
@@ -596,10 +595,10 @@ class PluginBase(ABC):
try:
# 延迟导入以避免循环依赖
from src.plugin_system.utils.dependency_manager import get_dependency_manager
dependency_manager = get_dependency_manager()
success, errors = dependency_manager.check_and_install_dependencies(dependencies, self.plugin_name)
if success:
logger.info(f"{self.log_prefix} Python依赖检查通过")
return True
@@ -608,7 +607,7 @@ class PluginBase(ABC):
for error in errors:
logger.error(f"{self.log_prefix} - {error}")
return False
except Exception as e:
logger.error(f"{self.log_prefix} Python依赖检查时发生异常: {e}", exc_info=True)
return False

View File

@@ -20,12 +20,12 @@ logger = get_logger("plus_command")
class PlusCommand(ABC):
"""增强版命令基类
提供更简单的命令定义方式,无需手写正则表达式
子类只需要定义:
- command_name: 命令名称
- command_description: 命令描述
- command_description: 命令描述
- command_aliases: 命令别名列表(可选)
- priority: 优先级(可选,数字越大优先级越高)
- chat_type_allow: 允许的聊天类型(可选)
@@ -35,19 +35,19 @@ class PlusCommand(ABC):
# 子类需要定义的属性
command_name: str = ""
"""命令名称,如 'echo'"""
command_description: str = ""
"""命令描述"""
command_aliases: List[str] = []
"""命令别名列表,如 ['say', 'repeat']"""
priority: int = 0
"""命令优先级,数字越大优先级越高"""
chat_type_allow: ChatType = ChatType.ALL
"""允许的聊天类型"""
intercept_message: bool = False
"""是否拦截消息,不进行后续处理"""
@@ -61,13 +61,13 @@ class PlusCommand(ABC):
self.message = message
self.plugin_config = plugin_config or {}
self.log_prefix = "[PlusCommand]"
# 解析命令参数
self._parse_command()
# 验证聊天类型限制
if not self._validate_chat_type():
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
logger.warning(
f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: "
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
@@ -75,59 +75,59 @@ class PlusCommand(ABC):
def _parse_command(self) -> None:
"""解析命令和参数"""
if not hasattr(self.message, 'plain_text') or not self.message.plain_text:
if not hasattr(self.message, "plain_text") or not self.message.plain_text:
self.args = CommandArgs("")
return
plain_text = self.message.plain_text.strip()
# 获取配置的命令前缀
prefixes = global_config.command.command_prefixes
# 检查是否以任何前缀开头
matched_prefix = None
for prefix in prefixes:
if plain_text.startswith(prefix):
matched_prefix = prefix
break
if not matched_prefix:
self.args = CommandArgs("")
return
# 移除前缀
command_part = plain_text[len(matched_prefix):].strip()
command_part = plain_text[len(matched_prefix) :].strip()
# 分离命令名和参数
parts = command_part.split(None, 1)
if not parts:
self.args = CommandArgs("")
return
command_word = parts[0].lower()
args_text = parts[1] if len(parts) > 1 else ""
# 检查命令名是否匹配
all_commands = [self.command_name.lower()] + [alias.lower() for alias in self.command_aliases]
if command_word not in all_commands:
self.args = CommandArgs("")
return
# 创建参数对象
self.args = CommandArgs(args_text)
def _validate_chat_type(self) -> bool:
"""验证当前聊天类型是否允许执行此命令
Returns:
bool: 如果允许执行返回True否则返回False
"""
if self.chat_type_allow == ChatType.ALL:
return True
# 检查是否为群聊消息
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
if self.chat_type_allow == ChatType.GROUP and is_group:
return True
elif self.chat_type_allow == ChatType.PRIVATE and not is_group:
@@ -137,7 +137,7 @@ class PlusCommand(ABC):
def is_chat_type_allowed(self) -> bool:
"""检查当前聊天类型是否允许执行此命令
Returns:
bool: 如果允许执行返回True否则返回False
"""
@@ -145,30 +145,30 @@ class PlusCommand(ABC):
def is_command_match(self) -> bool:
"""检查当前消息是否匹配此命令
Returns:
bool: 如果匹配返回True
"""
return not self.args.is_empty() or self._is_exact_command_call()
def _is_exact_command_call(self) -> bool:
"""检查是否是精确的命令调用(无参数)"""
if not hasattr(self.message, 'plain_text') or not self.message.plain_text:
if not hasattr(self.message, "plain_text") or not self.message.plain_text:
return False
plain_text = self.message.plain_text.strip()
# 获取配置的命令前缀
prefixes = global_config.command.command_prefixes
# 检查每个前缀
for prefix in prefixes:
if plain_text.startswith(prefix):
command_part = plain_text[len(prefix):].strip()
command_part = plain_text[len(prefix) :].strip()
all_commands = [self.command_name.lower()] + [alias.lower() for alias in self.command_aliases]
if command_part.lower() in all_commands:
return True
return False
@abstractmethod
@@ -298,10 +298,10 @@ class PlusCommand(ABC):
if "." in cls.command_name:
logger.error(f"命令名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"命令名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
# 生成正则表达式模式来匹配命令
command_pattern = cls._generate_command_pattern()
return CommandInfo(
name=cls.command_name,
component_type=ComponentType.COMMAND,
@@ -320,7 +320,7 @@ class PlusCommand(ABC):
if "." in cls.command_name:
logger.error(f"命令名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"命令名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
return PlusCommandInfo(
name=cls.command_name,
component_type=ComponentType.PLUS_COMMAND,
@@ -334,38 +334,38 @@ class PlusCommand(ABC):
@classmethod
def _generate_command_pattern(cls) -> str:
"""生成命令匹配的正则表达式
Returns:
str: 正则表达式字符串
"""
# 获取所有可能的命令名(主命令名 + 别名)
all_commands = [cls.command_name] + getattr(cls, 'command_aliases', [])
all_commands = [cls.command_name] + getattr(cls, "command_aliases", [])
# 转义特殊字符并创建选择组
escaped_commands = [re.escape(cmd) for cmd in all_commands]
commands_pattern = "|".join(escaped_commands)
# 获取默认前缀列表(这里先用硬编码,后续可以优化为动态获取)
default_prefixes = ["/", "!", ".", "#"]
escaped_prefixes = [re.escape(prefix) for prefix in default_prefixes]
prefixes_pattern = "|".join(escaped_prefixes)
# 生成完整的正则表达式
# 匹配: [前缀][命令名][可选空白][任意参数]
pattern = f"^(?P<prefix>{prefixes_pattern})(?P<command>{commands_pattern})(?P<args>\\s.*)?$"
return pattern
class PlusCommandAdapter(BaseCommand):
"""PlusCommand适配器
将PlusCommand适配到现有的插件系统继承BaseCommand
"""
def __init__(self, plus_command_class, message: MessageRecv, plugin_config: Optional[dict] = None):
"""初始化适配器
Args:
plus_command_class: PlusCommand子类
message: 消息对象
@@ -378,27 +378,27 @@ class PlusCommandAdapter(BaseCommand):
self.chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
self.priority = getattr(plus_command_class, "priority", 0)
self.intercept_message = getattr(plus_command_class, "intercept_message", False)
# 调用父类初始化
super().__init__(message, plugin_config)
# 创建PlusCommand实例
self.plus_command = plus_command_class(message, plugin_config)
async def execute(self) -> Tuple[bool, Optional[str], bool]:
"""执行命令
Returns:
Tuple[bool, Optional[str], bool]: 执行结果
"""
# 检查命令是否匹配
if not self.plus_command.is_command_match():
return False, "命令不匹配", False
# 检查聊天类型权限
if not self.plus_command.is_chat_type_allowed():
return False, "不支持当前聊天类型", self.intercept_message
# 执行命令
try:
return await self.plus_command.execute(self.plus_command.args)
@@ -409,49 +409,50 @@ class PlusCommandAdapter(BaseCommand):
def create_plus_command_adapter(plus_command_class):
"""创建PlusCommand适配器的工厂函数
Args:
plus_command_class: PlusCommand子类
Returns:
适配器类
"""
class AdapterClass(BaseCommand):
command_name = plus_command_class.command_name
command_description = plus_command_class.command_description
command_pattern = plus_command_class._generate_command_pattern()
chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
super().__init__(message, plugin_config)
self.plus_command = plus_command_class(message, plugin_config)
self.priority = getattr(plus_command_class, "priority", 0)
self.intercept_message = getattr(plus_command_class, "intercept_message", False)
async def execute(self) -> Tuple[bool, Optional[str], bool]:
"""执行命令"""
# 从BaseCommand的正则匹配结果中提取参数
args_text = ""
if hasattr(self, 'matched_groups') and self.matched_groups:
if hasattr(self, "matched_groups") and self.matched_groups:
# 从正则匹配组中获取参数部分
args_match = self.matched_groups.get('args', '')
args_match = self.matched_groups.get("args", "")
if args_match:
args_text = args_match.strip()
# 创建CommandArgs对象
command_args = CommandArgs(args_text)
# 检查聊天类型权限
if not self.plus_command.is_chat_type_allowed():
return False, "不支持当前聊天类型", self.intercept_message
# 执行命令,传递正确解析的参数
try:
return await self.plus_command.execute(command_args)
except Exception as e:
logger.error(f"执行命令时出错: {e}", exc_info=True)
return False, f"命令执行出错: {str(e)}", self.intercept_message
return AdapterClass

View File

@@ -34,7 +34,9 @@ class ComponentRegistry:
"""组件注册表 命名空间式组件名 -> 组件信息"""
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
"""类型 -> 组件原名称 -> 组件信息"""
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler, PlusCommand]]] = {}
self._components_classes: Dict[
str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler, PlusCommand]]
] = {}
"""命名空间式组件名 -> 组件类"""
# 插件注册表
@@ -166,7 +168,7 @@ class ComponentRegistry:
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
logger.error(f"注册失败: {action_name} 不是有效的Action")
return False
action_class.plugin_name = action_info.plugin_name
self._action_registry[action_name] = action_class
@@ -200,7 +202,9 @@ class ComponentRegistry:
return True
def _register_plus_command_component(self, plus_command_info: PlusCommandInfo, plus_command_class: Type[PlusCommand]) -> bool:
def _register_plus_command_component(
self, plus_command_info: PlusCommandInfo, plus_command_class: Type[PlusCommand]
) -> bool:
"""注册PlusCommand组件到特定注册表"""
plus_command_name = plus_command_info.name
@@ -212,7 +216,7 @@ class ComponentRegistry:
return False
# 创建专门的PlusCommand注册表如果还没有
if not hasattr(self, '_plus_command_registry'):
if not hasattr(self, "_plus_command_registry"):
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
plus_command_class.plugin_name = plus_command_info.plugin_name
@@ -249,10 +253,11 @@ class ComponentRegistry:
if not handler_info.enabled:
logger.warning(f"EventHandler组件 {handler_name} 未启用")
return True # 未启用,但是也是注册成功
handler_class.plugin_name = handler_info.plugin_name
# 使用EventManager进行事件处理器注册
from src.plugin_system.core.event_manager import event_manager
return event_manager.register_event_handler(handler_class)
# === 组件移除相关 ===
@@ -281,7 +286,7 @@ class ComponentRegistry:
case ComponentType.PLUS_COMMAND:
# 移除PlusCommand注册
if hasattr(self, '_plus_command_registry'):
if hasattr(self, "_plus_command_registry"):
self._plus_command_registry.pop(component_name, None)
logger.debug(f"已移除PlusCommand组件: {component_name}")
@@ -371,6 +376,7 @@ class ComponentRegistry:
assert issubclass(target_component_class, BaseEventHandler)
self._enabled_event_handlers[component_name] = target_component_class
from .event_manager import event_manager # 延迟导入防止循环导入问题
event_manager.register_event_handler(component_name)
namespaced_name = f"{component_type}.{component_name}"
@@ -572,7 +578,7 @@ class ComponentRegistry:
candidates[0].match(text).groupdict(), # type: ignore
command_info,
)
return None
# === Tool 特定查询方法 ===
@@ -599,7 +605,7 @@ class ComponentRegistry:
# === PlusCommand 特定查询方法 ===
def get_plus_command_registry(self) -> Dict[str, Type[PlusCommand]]:
"""获取PlusCommand注册表"""
if not hasattr(self, '_plus_command_registry'):
if not hasattr(self, "_plus_command_registry"):
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
return self._plus_command_registry.copy()

View File

@@ -2,55 +2,57 @@
事件管理器 - 实现Event和EventHandler的单例管理
提供统一的事件注册、管理和触发接口
"""
from typing import Dict, Type, List, Optional, Any, Union
from threading import Lock
from src.common.logger import get_logger
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection, HandlerResult
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.component_types import EventType
logger = get_logger("event_manager")
class EventManager:
"""事件管理器单例类
负责管理所有事件和事件处理器的注册、订阅、触发等操作
使用单例模式确保全局只有一个事件管理实例
"""
_instance: Optional['EventManager'] = None
_instance: Optional["EventManager"] = None
_lock = Lock()
def __new__(cls) -> 'EventManager':
def __new__(cls) -> "EventManager":
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self._events: Dict[str, BaseEvent] = {}
self._event_handlers: Dict[str, Type[BaseEventHandler]] = {}
self._pending_subscriptions: Dict[str, List[str]] = {} # 缓存失败的订阅
self._initialized = True
logger.info("EventManager 单例初始化完成")
def register_event(
self,
event_name: Union[EventType, str],
allowed_subscribers: List[str]=None,
allowed_triggers: List[str]=None
) -> bool:
self,
event_name: Union[EventType, str],
allowed_subscribers: List[str] = None,
allowed_triggers: List[str] = None,
) -> bool:
"""注册一个新的事件
Args:
event_name Union[EventType, str]: 事件名称
allowed_subscribers: List[str]: 事件订阅者白名单,
allowed_subscribers: List[str]: 事件订阅者白名单,
allowed_triggers: List[str]: 事件触发插件白名单
Returns:
bool: 注册成功返回True已存在返回False
@@ -62,57 +64,57 @@ class EventManager:
if event_name in self._events:
logger.warning(f"事件 {event_name} 已存在,跳过注册")
return False
event = BaseEvent(event_name,allowed_subscribers,allowed_triggers)
event = BaseEvent(event_name, allowed_subscribers, allowed_triggers)
self._events[event_name] = event
logger.info(f"事件 {event_name} 注册成功")
# 检查是否有缓存的订阅需要处理
self._process_pending_subscriptions(event_name)
return True
def get_event(self, event_name: Union[EventType, str]) -> Optional[BaseEvent]:
"""获取指定事件实例
Args:
event_name Union[EventType, str]: 事件名称
Returns:
BaseEvent: 事件实例不存在返回None
"""
return self._events.get(event_name)
def get_all_events(self) -> Dict[str, BaseEvent]:
"""获取所有已注册的事件
Returns:
Dict[str, BaseEvent]: 所有事件的字典
"""
return self._events.copy()
def get_enabled_events(self) -> Dict[str, BaseEvent]:
"""获取所有已启用的事件
Returns:
Dict[str, BaseEvent]: 已启用事件的字典
"""
return {name: event for name, event in self._events.items() if event.enabled}
def get_disabled_events(self) -> Dict[str, BaseEvent]:
"""获取所有已禁用的事件
Returns:
Dict[str, BaseEvent]: 已禁用事件的字典
"""
return {name: event for name, event in self._events.items() if not event.enabled}
def enable_event(self, event_name: Union[EventType, str]) -> bool:
"""启用指定事件
Args:
event_name Union[EventType, str]: 事件名称
Returns:
bool: 成功返回True事件不存在返回False
"""
@@ -120,17 +122,17 @@ class EventManager:
if event is None:
logger.error(f"事件 {event_name} 不存在,无法启用")
return False
event.enabled = True
logger.info(f"事件 {event_name} 已启用")
return True
def disable_event(self, event_name: Union[EventType, str]) -> bool:
"""禁用指定事件
Args:
event_name Union[EventType, str]: 事件名称
Returns:
bool: 成功返回True事件不存在返回False
"""
@@ -138,38 +140,38 @@ class EventManager:
if event is None:
logger.error(f"事件 {event_name} 不存在,无法禁用")
return False
event.enabled = False
logger.info(f"事件 {event_name} 已禁用")
return True
def register_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
"""注册事件处理器
Args:
handler_class (Type[BaseEventHandler]): 事件处理器类
Returns:
bool: 注册成功返回True已存在返回False
"""
handler_name = handler_class.handler_name or handler_class.__name__.lower().replace("handler", "")
if EventType.UNKNOWN in handler_class.init_subscribe:
logger.error(f"事件处理器 {handler_name} 不能订阅 UNKNOWN 事件")
return False
if handler_name in self._event_handlers:
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
return False
self._event_handlers[handler_name] = handler_class()
# 处理init_subscribe缓存失败的订阅
if self._event_handlers[handler_name].init_subscribe:
failed_subscriptions = []
for event_name in self._event_handlers[handler_name].init_subscribe:
if not self.subscribe_handler_to_event(handler_name, event_name):
failed_subscriptions.append(event_name)
# 缓存失败的订阅
if failed_subscriptions:
self._pending_subscriptions[handler_name] = failed_subscriptions
@@ -177,33 +179,33 @@ class EventManager:
logger.info(f"事件处理器 {handler_name} 注册成功")
return True
def get_event_handler(self, handler_name: str) -> Optional[Type[BaseEventHandler]]:
"""获取指定事件处理器实例
Args:
handler_name (str): 处理器名称
Returns:
Type[BaseEventHandler]: 处理器实例不存在返回None
"""
return self._event_handlers.get(handler_name)
def get_all_event_handlers(self) -> Dict[str, BaseEventHandler]:
"""获取所有已注册的事件处理器
Returns:
Dict[str, Type[BaseEventHandler]]: 所有处理器的字典
"""
return self._event_handlers.copy()
def subscribe_handler_to_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
"""订阅事件处理器到指定事件
Args:
handler_name (str): 处理器名称
event_name Union[EventType, str]: 事件名称
Returns:
bool: 订阅成功返回True
"""
@@ -211,36 +213,36 @@ class EventManager:
if handler_instance is None:
logger.error(f"事件处理器 {handler_name} 不存在,无法订阅到事件 {event_name}")
return False
event = self.get_event(event_name)
if event is None:
logger.error(f"事件 {event_name} 不存在,无法订阅事件处理器 {handler_name}")
return False
if handler_instance in event.subscribers:
logger.warning(f"事件处理器 {handler_name} 已经订阅了事件 {event_name},跳过重复订阅")
return True
# 白名单检查
if event.allowed_subscribers and handler_name not in event.allowed_subscribers:
logger.warning(f"事件处理器 {handler_name} 不在事件 {event_name} 的订阅者白名单中,无法订阅")
return False
event.subscribers.append(handler_instance)
# 按权重从高到低排序订阅者
event.subscribers.sort(key=lambda h: getattr(h, 'weight', 0), reverse=True)
event.subscribers.sort(key=lambda h: getattr(h, "weight", 0), reverse=True)
logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成")
return True
def unsubscribe_handler_from_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
"""从指定事件取消订阅事件处理器
Args:
handler_name (str): 处理器名称
event_name Union[EventType, str]: 事件名称
Returns:
bool: 取消订阅成功返回True
"""
@@ -248,55 +250,57 @@ class EventManager:
if event is None:
logger.error(f"事件 {event_name} 不存在,无法取消订阅")
return False
# 查找并移除处理器实例
removed = False
for subscriber in event.subscribers[:]:
if hasattr(subscriber, 'handler_name') and subscriber.handler_name == handler_name:
if hasattr(subscriber, "handler_name") and subscriber.handler_name == handler_name:
event.subscribers.remove(subscriber)
removed = True
break
if removed:
logger.info(f"事件处理器 {handler_name} 成功从事件 {event_name} 取消订阅")
else:
logger.warning(f"事件处理器 {handler_name} 未订阅事件 {event_name}")
return removed
def get_event_subscribers(self, event_name: Union[EventType, str]) -> Dict[str, BaseEventHandler]:
"""获取订阅指定事件的所有事件处理器
Args:
event_name Union[EventType, str]: 事件名称
Returns:
Dict[str, BaseEventHandler]: 处理器字典,键为处理器名称,值为处理器实例
"""
event = self.get_event(event_name)
if event is None:
return {}
return {handler.handler_name: handler for handler in event.subscribers}
async def trigger_event(self, event_name: Union[EventType, str], plugin_name: Optional[str]="", **kwargs) -> Optional[HandlerResultsCollection]:
async def trigger_event(
self, event_name: Union[EventType, str], plugin_name: Optional[str] = "", **kwargs
) -> Optional[HandlerResultsCollection]:
"""触发指定事件
Args:
event_name Union[EventType, str]: 事件名称
plugin_name str: 触发事件的插件名
**kwargs: 传递给处理器的参数
Returns:
HandlerResultsCollection: 所有处理器的执行结果事件不存在返回None
"""
params = kwargs or {}
event = self.get_event(event_name)
if event is None:
logger.error(f"事件 {event_name} 不存在,无法触发")
return None
# 插件白名单检查
if event.allowed_triggers and not plugin_name:
logger.warning(f"事件 {event_name} 存在触发者白名单缺少plugin_name无法验证权限已拒绝触发")
@@ -304,9 +308,9 @@ class EventManager:
elif event.allowed_triggers and plugin_name not in event.allowed_triggers:
logger.warning(f"插件 {plugin_name} 没有权限触发事件 {event_name},已拒绝触发!")
return None
return await event.activate(params)
def init_default_events(self) -> None:
"""初始化默认事件"""
default_events = [
@@ -317,29 +321,29 @@ class EventManager:
EventType.POST_LLM,
EventType.AFTER_LLM,
EventType.POST_SEND,
EventType.AFTER_SEND
EventType.AFTER_SEND,
]
for event_name in default_events:
self.register_event(event_name,allowed_triggers=["SYSTEM"])
self.register_event(event_name, allowed_triggers=["SYSTEM"])
logger.info("默认事件初始化完成")
def clear_all_events(self) -> None:
"""清除所有事件和处理器(主要用于测试)"""
self._events.clear()
self._event_handlers.clear()
logger.info("所有事件和处理器已清除")
def get_event_summary(self) -> Dict[str, Any]:
"""获取事件系统摘要
Returns:
Dict[str, Any]: 包含事件系统统计信息的字典
"""
enabled_events = self.get_enabled_events()
disabled_events = self.get_disabled_events()
return {
"total_events": len(self._events),
"enabled_events": len(enabled_events),
@@ -347,58 +351,58 @@ class EventManager:
"total_handlers": len(self._event_handlers),
"event_names": list(self._events.keys()),
"handler_names": list(self._event_handlers.keys()),
"pending_subscriptions": len(self._pending_subscriptions)
"pending_subscriptions": len(self._pending_subscriptions),
}
def _process_pending_subscriptions(self, event_name: Union[EventType, str]) -> None:
"""处理指定事件的缓存订阅
Args:
event_name Union[EventType, str]: 事件名称
"""
handlers_to_remove = []
for handler_name, pending_events in self._pending_subscriptions.items():
if event_name in pending_events:
if self.subscribe_handler_to_event(handler_name, event_name):
pending_events.remove(event_name)
logger.info(f"成功处理缓存订阅: {handler_name} -> {event_name}")
# 如果该处理器没有更多待处理订阅,标记为移除
if not pending_events:
handlers_to_remove.append(handler_name)
# 清理已完成的处理器缓存
for handler_name in handlers_to_remove:
del self._pending_subscriptions[handler_name]
def process_all_pending_subscriptions(self) -> int:
"""处理所有缓存的订阅
Returns:
int: 成功处理的订阅数量
"""
processed_count = 0
# 复制待处理订阅,避免在迭代时修改字典
pending_copy = dict(self._pending_subscriptions)
for handler_name, pending_events in pending_copy.items():
for event_name in pending_events[:]: # 使用切片避免修改列表
if self.subscribe_handler_to_event(handler_name, event_name):
pending_events.remove(event_name)
processed_count += 1
# 清理已完成的处理器缓存
handlers_to_remove = [name for name, events in self._pending_subscriptions.items() if not events]
for handler_name in handlers_to_remove:
del self._pending_subscriptions[handler_name]
if processed_count > 0:
logger.info(f"批量处理缓存订阅完成,共处理 {processed_count} 个订阅")
return processed_count
# 创建全局事件管理器实例
event_manager = EventManager()
event_manager = EventManager()

View File

@@ -88,7 +88,7 @@ class GlobalAnnouncementManager:
return False
self._user_disabled_tools[chat_id].append(tool_name)
return True
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
"""启用特定聊天的某个工具"""
if chat_id in self._user_disabled_tools:
@@ -111,7 +111,7 @@ class GlobalAnnouncementManager:
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有事件处理器"""
return self._user_disabled_event_handlers.get(chat_id, []).copy()
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有工具"""
return self._user_disabled_tools.get(chat_id, []).copy()

View File

@@ -19,14 +19,14 @@ logger = get_logger(__name__)
class PermissionManager(IPermissionManager):
"""权限管理器实现类"""
def __init__(self):
self.engine = get_engine()
self.SessionLocal = sessionmaker(bind=self.engine)
self._master_users: Set[Tuple[str, str]] = set()
self._load_master_users()
logger.info("权限管理器初始化完成")
def _load_master_users(self):
"""从配置文件加载Master用户列表"""
try:
@@ -40,19 +40,19 @@ class PermissionManager(IPermissionManager):
except Exception as e:
logger.warning(f"加载Master用户配置失败: {e}")
self._master_users = set()
def reload_master_users(self):
"""重新加载Master用户配置"""
self._load_master_users()
logger.info("Master用户配置已重新加载")
def is_master(self, user: UserInfo) -> bool:
"""
检查用户是否为Master用户
Args:
user: 用户信息
Returns:
bool: 是否为Master用户
"""
@@ -61,15 +61,15 @@ class PermissionManager(IPermissionManager):
if is_master:
logger.debug(f"用户 {user.platform}:{user.user_id} 是Master用户")
return is_master
def check_permission(self, user: UserInfo, permission_node: str) -> bool:
"""
检查用户是否拥有指定权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 是否拥有权限
"""
@@ -78,46 +78,50 @@ class PermissionManager(IPermissionManager):
if self.is_master(user):
logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}")
return True
with self.SessionLocal() as session:
# 检查权限节点是否存在
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first()
if not node:
logger.warning(f"权限节点 {permission_node} 不存在")
return False
# 检查用户是否有明确的权限设置
user_perm = session.query(UserPermissions).filter_by(
platform=user.platform,
user_id=user.user_id,
permission_node=permission_node
).first()
user_perm = (
session.query(UserPermissions)
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
.first()
)
if user_perm:
# 有明确设置,返回设置的值
result = user_perm.granted
logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}")
logger.debug(
f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}"
)
return result
else:
# 没有明确设置,使用默认值
result = node.default_granted
logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}")
logger.debug(
f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}"
)
return result
except SQLAlchemyError as e:
logger.error(f"检查权限时数据库错误: {e}")
return False
except Exception as e:
logger.error(f"检查权限时发生未知错误: {e}")
return False
def register_permission_node(self, node: PermissionNode) -> bool:
"""
注册权限节点
Args:
node: 权限节点
Returns:
bool: 注册是否成功
"""
@@ -133,20 +137,20 @@ class PermissionManager(IPermissionManager):
session.commit()
logger.debug(f"更新权限节点: {node.node_name}")
return True
# 创建新节点
new_node = PermissionNodes(
node_name=node.node_name,
description=node.description,
plugin_name=node.plugin_name,
default_granted=node.default_granted,
created_at=datetime.utcnow()
created_at=datetime.utcnow(),
)
session.add(new_node)
session.commit()
logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})")
return True
except IntegrityError as e:
logger.error(f"注册权限节点时发生完整性错误: {e}")
return False
@@ -156,15 +160,15 @@ class PermissionManager(IPermissionManager):
except Exception as e:
logger.error(f"注册权限节点时发生未知错误: {e}")
return False
def grant_permission(self, user: UserInfo, permission_node: str) -> bool:
"""
授权用户权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 授权是否成功
"""
@@ -175,14 +179,14 @@ class PermissionManager(IPermissionManager):
if not node:
logger.error(f"尝试授权不存在的权限节点: {permission_node}")
return False
# 检查是否已有权限记录
existing_perm = session.query(UserPermissions).filter_by(
platform=user.platform,
user_id=user.user_id,
permission_node=permission_node
).first()
existing_perm = (
session.query(UserPermissions)
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
.first()
)
if existing_perm:
# 更新现有记录
existing_perm.granted = True
@@ -194,29 +198,29 @@ class PermissionManager(IPermissionManager):
user_id=user.user_id,
permission_node=permission_node,
granted=True,
granted_at=datetime.utcnow()
granted_at=datetime.utcnow(),
)
session.add(new_perm)
session.commit()
logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}")
return True
except SQLAlchemyError as e:
logger.error(f"授权权限时数据库错误: {e}")
return False
except Exception as e:
logger.error(f"授权权限时发生未知错误: {e}")
return False
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool:
"""
撤销用户权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 撤销是否成功
"""
@@ -227,14 +231,14 @@ class PermissionManager(IPermissionManager):
if not node:
logger.error(f"尝试撤销不存在的权限节点: {permission_node}")
return False
# 检查是否已有权限记录
existing_perm = session.query(UserPermissions).filter_by(
platform=user.platform,
user_id=user.user_id,
permission_node=permission_node
).first()
existing_perm = (
session.query(UserPermissions)
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
.first()
)
if existing_perm:
# 更新现有记录
existing_perm.granted = False
@@ -246,28 +250,28 @@ class PermissionManager(IPermissionManager):
user_id=user.user_id,
permission_node=permission_node,
granted=False,
granted_at=datetime.utcnow()
granted_at=datetime.utcnow(),
)
session.add(new_perm)
session.commit()
logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}")
return True
except SQLAlchemyError as e:
logger.error(f"撤销权限时数据库错误: {e}")
return False
except Exception as e:
logger.error(f"撤销权限时发生未知错误: {e}")
return False
def get_user_permissions(self, user: UserInfo) -> List[str]:
"""
获取用户拥有的所有权限节点
Args:
user: 用户信息
Returns:
List[str]: 权限节点列表
"""
@@ -277,21 +281,21 @@ class PermissionManager(IPermissionManager):
with self.SessionLocal() as session:
all_nodes = session.query(PermissionNodes.node_name).all()
return [node.node_name for node in all_nodes]
permissions = []
with self.SessionLocal() as session:
# 获取所有权限节点
all_nodes = session.query(PermissionNodes).all()
for node in all_nodes:
# 检查用户是否有明确的权限设置
user_perm = session.query(UserPermissions).filter_by(
platform=user.platform,
user_id=user.user_id,
permission_node=node.node_name
).first()
user_perm = (
session.query(UserPermissions)
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=node.node_name)
.first()
)
if user_perm:
# 有明确设置,使用设置的值
if user_perm.granted:
@@ -300,20 +304,20 @@ class PermissionManager(IPermissionManager):
# 没有明确设置,使用默认值
if node.default_granted:
permissions.append(node.node_name)
return permissions
except SQLAlchemyError as e:
logger.error(f"获取用户权限时数据库错误: {e}")
return []
except Exception as e:
logger.error(f"获取用户权限时发生未知错误: {e}")
return []
def get_all_permission_nodes(self) -> List[PermissionNode]:
"""
获取所有已注册的权限节点
Returns:
List[PermissionNode]: 权限节点列表
"""
@@ -325,25 +329,25 @@ class PermissionManager(IPermissionManager):
node_name=node.node_name,
description=node.description,
plugin_name=node.plugin_name,
default_granted=node.default_granted
default_granted=node.default_granted,
)
for node in nodes
]
except SQLAlchemyError as e:
logger.error(f"获取所有权限节点时数据库错误: {e}")
return []
except Exception as e:
logger.error(f"获取所有权限节点时发生未知错误: {e}")
return []
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
"""
获取指定插件的所有权限节点
Args:
plugin_name: 插件名称
Returns:
List[PermissionNode]: 权限节点列表
"""
@@ -355,25 +359,25 @@ class PermissionManager(IPermissionManager):
node_name=node.node_name,
description=node.description,
plugin_name=node.plugin_name,
default_granted=node.default_granted
default_granted=node.default_granted,
)
for node in nodes
]
except SQLAlchemyError as e:
logger.error(f"获取插件权限节点时数据库错误: {e}")
return []
except Exception as e:
logger.error(f"获取插件权限节点时发生未知错误: {e}")
return []
def delete_plugin_permissions(self, plugin_name: str) -> bool:
"""
删除指定插件的所有权限节点(用于插件卸载时清理)
Args:
plugin_name: 插件名称
Returns:
bool: 删除是否成功
"""
@@ -382,68 +386,71 @@ class PermissionManager(IPermissionManager):
# 获取插件的所有权限节点
plugin_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all()
node_names = [node.node_name for node in plugin_nodes]
if not node_names:
logger.info(f"插件 {plugin_name} 没有注册任何权限节点")
return True
# 删除用户权限记录
deleted_user_perms = session.query(UserPermissions).filter(
UserPermissions.permission_node.in_(node_names)
).delete(synchronize_session=False)
deleted_user_perms = (
session.query(UserPermissions)
.filter(UserPermissions.permission_node.in_(node_names))
.delete(synchronize_session=False)
)
# 删除权限节点
deleted_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).delete()
session.commit()
logger.info(f"已删除插件 {plugin_name}{deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录")
logger.info(
f"已删除插件 {plugin_name}{deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录"
)
return True
except SQLAlchemyError as e:
logger.error(f"删除插件权限时数据库错误: {e}")
return False
except Exception as e:
logger.error(f"删除插件权限时发生未知错误: {e}")
return False
def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]:
"""
获取拥有指定权限的所有用户
Args:
permission_node: 权限节点名称
Returns:
List[Tuple[str, str]]: 用户列表,格式为 [(platform, user_id), ...]
"""
try:
users = []
with self.SessionLocal() as session:
# 检查权限节点是否存在
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first()
if not node:
logger.warning(f"权限节点 {permission_node} 不存在")
return users
# 获取明确授权的用户
granted_users = session.query(UserPermissions).filter_by(
permission_node=permission_node,
granted=True
).all()
granted_users = (
session.query(UserPermissions).filter_by(permission_node=permission_node, granted=True).all()
)
for user_perm in granted_users:
users.append((user_perm.platform, user_perm.user_id))
# 如果是默认授权的权限节点,还需要考虑没有明确设置的用户
# 但这里我们只返回明确授权的用户,避免返回所有用户
# 添加Master用户他们拥有所有权限
users.extend(list(self._master_users))
# 去重
return list(set(users))
except SQLAlchemyError as e:
logger.error(f"获取拥有权限的用户时数据库错误: {e}")
return []

View File

@@ -36,21 +36,21 @@ class PluginFileHandler(FileSystemEventHandler):
"""文件修改事件"""
if not event.is_directory:
file_path = str(event.src_path)
if file_path.endswith(('.py', '.toml')):
if file_path.endswith((".py", ".toml")):
self._handle_file_change(file_path, "modified")
def on_created(self, event):
"""文件创建事件"""
if not event.is_directory:
file_path = str(event.src_path)
if file_path.endswith(('.py', '.toml')):
if file_path.endswith((".py", ".toml")):
self._handle_file_change(file_path, "created")
def on_deleted(self, event):
"""文件删除事件"""
if not event.is_directory:
file_path = str(event.src_path)
if file_path.endswith(('.py', '.toml')):
if file_path.endswith((".py", ".toml")):
self._handle_file_change(file_path, "deleted")
def _handle_file_change(self, file_path: str, change_type: str):
@@ -63,14 +63,14 @@ class PluginFileHandler(FileSystemEventHandler):
plugin_name, source_type = plugin_info
current_time = time.time()
# 文件变化缓存,避免重复处理同一文件的快速连续变化
file_cache_key = f"{file_path}_{change_type}"
last_file_time = self.file_change_cache.get(file_cache_key, 0)
if current_time - last_file_time < 0.5: # 0.5秒内的重复文件变化忽略
return
self.file_change_cache[file_cache_key] = current_time
# 插件级别的防抖处理
last_plugin_time = self.last_reload_time.get(plugin_name, 0)
if current_time - last_plugin_time < self.debounce_delay:
@@ -85,20 +85,28 @@ class PluginFileHandler(FileSystemEventHandler):
if change_type == "deleted":
# 解析实际的插件名称
actual_plugin_name = self.hot_reload_manager._resolve_plugin_name(plugin_name)
if file_name == "plugin.py":
if actual_plugin_name in plugin_manager.loaded_plugins:
logger.info(f"🗑️ 插件主文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]")
logger.info(
f"🗑️ 插件主文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
self.hot_reload_manager._unload_plugin(actual_plugin_name)
else:
logger.info(f"🗑️ 插件主文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]")
logger.info(
f"🗑️ 插件主文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
return
elif file_name in ("manifest.toml", "_manifest.json"):
if actual_plugin_name in plugin_manager.loaded_plugins:
logger.info(f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]")
logger.info(
f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
self.hot_reload_manager._unload_plugin(actual_plugin_name)
else:
logger.info(f"🗑️ 插件配置文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]")
logger.info(
f"🗑️ 插件配置文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
return
# 对于修改和创建事件,都进行重载
@@ -108,9 +116,7 @@ class PluginFileHandler(FileSystemEventHandler):
# 延迟重载,确保文件写入完成
reload_thread = Thread(
target=self._delayed_reload,
args=(plugin_name, source_type, current_time),
daemon=True
target=self._delayed_reload, args=(plugin_name, source_type, current_time), daemon=True
)
reload_thread.start()
@@ -126,14 +132,14 @@ class PluginFileHandler(FileSystemEventHandler):
# 检查是否还需要重载(可能在等待期间有更新的变化)
if plugin_name not in self.pending_reloads:
return
# 检查是否有更新的重载请求
if self.last_reload_time.get(plugin_name, 0) > trigger_time:
return
self.pending_reloads.discard(plugin_name)
logger.info(f"🔄 开始延迟重载插件: {plugin_name} [{source_type}]")
# 执行深度重载
success = self.hot_reload_manager._deep_reload_plugin(plugin_name)
if success:
@@ -146,7 +152,7 @@ class PluginFileHandler(FileSystemEventHandler):
def _get_plugin_info_from_path(self, file_path: str) -> Optional[Tuple[str, str]]:
"""从文件路径获取插件信息
Returns:
tuple[插件名称, 源类型] 或 None
"""
@@ -162,12 +168,12 @@ class PluginFileHandler(FileSystemEventHandler):
source_type = "built-in"
else:
source_type = "external"
# 获取插件目录名(插件名)
relative_path = path.relative_to(plugin_root)
if len(relative_path.parts) == 0:
continue
plugin_name = relative_path.parts[0]
# 确认这是一个有效的插件目录
@@ -175,9 +181,10 @@ class PluginFileHandler(FileSystemEventHandler):
if plugin_dir.is_dir():
# 检查是否有插件主文件或配置文件
has_plugin_py = (plugin_dir / "plugin.py").exists()
has_manifest = ((plugin_dir / "manifest.toml").exists() or
(plugin_dir / "_manifest.json").exists())
has_manifest = (plugin_dir / "manifest.toml").exists() or (
plugin_dir / "_manifest.json"
).exists()
if has_plugin_py or has_manifest:
return plugin_name, source_type
@@ -195,11 +202,11 @@ class PluginHotReloadManager:
# 默认监听两个目录:根目录下的 plugins 和 src 下的插件目录
self.watch_directories = [
os.path.join(os.getcwd(), "plugins"), # 外部插件目录
os.path.join(os.getcwd(), "src", "plugins", "built_in") # 内置插件目录
os.path.join(os.getcwd(), "src", "plugins", "built_in"), # 内置插件目录
]
else:
self.watch_directories = watch_directories
self.observers = []
self.file_handlers = []
self.is_running = False
@@ -221,13 +228,9 @@ class PluginHotReloadManager:
for watch_dir in self.watch_directories:
observer = Observer()
file_handler = PluginFileHandler(self)
observer.schedule(
file_handler,
watch_dir,
recursive=True
)
observer.schedule(file_handler, watch_dir, recursive=True)
observer.start()
self.observers.append(observer)
self.file_handlers.append(file_handler)
@@ -296,26 +299,26 @@ class PluginHotReloadManager:
if folder_name in plugin_manager.plugin_classes:
logger.debug(f"🔍 直接匹配插件名: {folder_name}")
return folder_name
# 如果没有直接匹配,搜索路径映射,并优先返回在插件类中存在的名称
matched_plugins = []
for plugin_name, plugin_path in plugin_manager.plugin_paths.items():
# 检查路径是否包含该文件夹名
if folder_name in plugin_path:
matched_plugins.append((plugin_name, plugin_path))
# 在匹配的插件中,优先选择在插件类中存在的
for plugin_name, plugin_path in matched_plugins:
if plugin_name in plugin_manager.plugin_classes:
logger.debug(f"🔍 文件夹名 '{folder_name}' 映射到插件名 '{plugin_name}' (路径: {plugin_path})")
return plugin_name
# 如果还是没找到在插件类中存在的,返回第一个匹配项
if matched_plugins:
plugin_name, plugin_path = matched_plugins[0]
logger.warning(f"⚠️ 文件夹 '{folder_name}' 映射到 '{plugin_name}',但该插件类不存在")
return plugin_name
# 如果还是没找到,返回原文件夹名
logger.warning(f"⚠️ 无法找到文件夹 '{folder_name}' 对应的插件名,使用原名称")
return folder_name
@@ -326,13 +329,13 @@ class PluginHotReloadManager:
# 解析实际的插件名称
actual_plugin_name = self._resolve_plugin_name(plugin_name)
logger.info(f"🔄 开始深度重载插件: {plugin_name} -> {actual_plugin_name}")
# 强制清理相关模块缓存
self._force_clear_plugin_modules(plugin_name)
# 使用插件管理器的强制重载功能
success = plugin_manager.force_reload_plugin(actual_plugin_name)
if success:
logger.info(f"✅ 插件深度重载成功: {actual_plugin_name}")
return True
@@ -348,15 +351,15 @@ class PluginHotReloadManager:
def _force_clear_plugin_modules(self, plugin_name: str):
"""强制清理插件相关的模块缓存"""
# 找到所有相关的模块名
modules_to_remove = []
plugin_module_prefix = f"src.plugins.built_in.{plugin_name}"
for module_name in list(sys.modules.keys()):
if plugin_module_prefix in module_name:
modules_to_remove.append(module_name)
# 删除模块缓存
for module_name in modules_to_remove:
if module_name in sys.modules:
@@ -369,7 +372,7 @@ class PluginHotReloadManager:
# 使用插件管理器的重载功能
success = plugin_manager.reload_plugin(plugin_name)
return success
except Exception as e:
logger.error(f"❌ 强制重新导入插件 {plugin_name} 时发生错误: {e}", exc_info=True)
return False
@@ -378,7 +381,7 @@ class PluginHotReloadManager:
"""卸载指定插件"""
try:
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
if plugin_manager.unload_plugin(plugin_name):
logger.info(f"✅ 插件卸载成功: {plugin_name}")
return True
@@ -409,7 +412,7 @@ class PluginHotReloadManager:
fail_count += 1
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count}")
# 清理全局缓存
importlib.invalidate_caches()
@@ -420,21 +423,21 @@ class PluginHotReloadManager:
"""手动强制重载指定插件(委托给插件管理器)"""
try:
logger.info(f"🔄 手动强制重载插件: {plugin_name}")
# 清理待重载列表中的该插件(避免重复重载)
for handler in self.file_handlers:
handler.pending_reloads.discard(plugin_name)
# 使用插件管理器的强制重载功能
success = plugin_manager.force_reload_plugin(plugin_name)
if success:
logger.info(f"✅ 手动强制重载成功: {plugin_name}")
else:
logger.error(f"❌ 手动强制重载失败: {plugin_name}")
return success
except Exception as e:
logger.error(f"❌ 手动强制重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
return False
@@ -457,19 +460,15 @@ class PluginHotReloadManager:
try:
observer = Observer()
file_handler = PluginFileHandler(self)
observer.schedule(
file_handler,
directory,
recursive=True
)
observer.schedule(file_handler, directory, recursive=True)
observer.start()
self.observers.append(observer)
self.file_handlers.append(file_handler)
logger.info(f"📂 已添加新的监听目录: {directory}")
except Exception as e:
logger.error(f"❌ 添加监听目录 {directory} 失败: {e}")
self.watch_directories.remove(directory)
@@ -480,7 +479,7 @@ class PluginHotReloadManager:
if self.file_handlers:
for handler in self.file_handlers:
pending_reloads.update(handler.pending_reloads)
return {
"is_running": self.is_running,
"watch_directories": self.watch_directories,
@@ -495,11 +494,11 @@ class PluginHotReloadManager:
"""清理所有Python模块缓存"""
try:
logger.info("🧹 开始清理所有Python模块缓存...")
# 重新扫描所有插件目录,这会重新加载模块
plugin_manager.rescan_plugin_directory()
logger.info("✅ 模块缓存清理完成")
except Exception as e:
logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True)

View File

@@ -104,7 +104,7 @@ class PluginManager:
return False # 目标文件不存在,视为不同
# 使用 'rb' 模式以二进制方式读取文件,确保哈希值计算的一致性
with open(file1, 'rb') as f1, open(file2, 'rb') as f2:
with open(file1, "rb") as f1, open(file2, "rb") as f2:
return hashlib.md5(f1.read()).hexdigest() == hashlib.md5(f2.read()).hexdigest()
# === 插件目录管理 ===
@@ -300,7 +300,7 @@ class PluginManager:
list: 已注册的插件类名称列表。
"""
return list(self.plugin_classes.keys())
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
"""
获取指定插件的路径。
@@ -366,7 +366,7 @@ class PluginManager:
# 生成模块名和插件信息
plugin_path = Path(plugin_file)
plugin_dir = plugin_path.parent # 插件目录
plugin_name = plugin_dir.name # 插件名称
plugin_name = plugin_dir.name # 插件名称
module_name = ".".join(plugin_path.parent.parts)
try:
@@ -386,7 +386,7 @@ class PluginManager:
except Exception as e:
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
logger.error(error_msg)
self.failed_plugins[plugin_name if 'plugin_name' in locals() else module_name] = error_msg
self.failed_plugins[plugin_name if "plugin_name" in locals() else module_name] = error_msg
return False
# == 兼容性检查 ==
@@ -478,9 +478,7 @@ class PluginManager:
command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
]
tool_components = [
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
]
tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL]
event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
]
@@ -591,7 +589,7 @@ class PluginManager:
plugin_instance = self.loaded_plugins[plugin_name]
# 调用插件的清理方法(如果有的话)
if hasattr(plugin_instance, 'on_unload'):
if hasattr(plugin_instance, "on_unload"):
plugin_instance.on_unload()
# 从组件注册表中移除插件的所有组件
@@ -654,10 +652,10 @@ class PluginManager:
def force_reload_plugin(self, plugin_name: str) -> bool:
"""强制重载插件(使用简化的方法)
Args:
plugin_name: 插件名称
Returns:
bool: 重载是否成功
"""

View File

@@ -129,17 +129,17 @@ class ToolExecutor:
if not tool_calls:
logger.debug(f"{self.log_prefix}无需执行工具")
return [], []
# 提取tool_calls中的函数名称
func_names = []
for call in tool_calls:
try:
if hasattr(call, 'func_name'):
if hasattr(call, "func_name"):
func_names.append(call.func_name)
except Exception as e:
logger.error(f"{self.log_prefix}获取工具名称失败: {e}")
continue
if func_names:
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
else:
@@ -185,9 +185,11 @@ class ToolExecutor:
return tool_results, used_tools
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
async def execute_tool_call(
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
) -> Optional[Dict[str, Any]]:
"""执行单个工具调用,并处理缓存"""
function_args = tool_call.args or {}
tool_instance = tool_instance or get_tool_instance(tool_call.func_name)
@@ -206,7 +208,7 @@ class ToolExecutor:
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
semantic_query=semantic_query
semantic_query=semantic_query,
)
if cached_result:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
@@ -223,14 +225,14 @@ class ToolExecutor:
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
await tool_cache.set(
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
data=result,
ttl=tool_instance.cache_ttl,
semantic_query=semantic_query
semantic_query=semantic_query,
)
except Exception as e:
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
@@ -238,12 +240,16 @@ class ToolExecutor:
return result
async def _original_execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
async def _original_execute_tool_call(
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
) -> Optional[Dict[str, Any]]:
"""执行单个工具调用的原始逻辑"""
try:
function_name = tool_call.func_name
function_args = tool_call.args or {}
logger.info(f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
logger.info(
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
)
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = tool_instance or get_tool_instance(function_name)
@@ -261,7 +267,7 @@ class ToolExecutor:
"role": "tool",
"name": function_name,
"type": "function",
"content": result.get("content", "")
"content": result.get("content", ""),
}
logger.warning(f"{self.log_prefix}工具 {function_name} 返回空结果")
return None
@@ -308,7 +314,6 @@ class ToolExecutor:
return None
"""
ToolExecutor使用示例

View File

@@ -23,112 +23,105 @@
INSTALL_NAME_TO_IMPORT_NAME = {
# ============== 数据科学与机器学习 (Data Science & Machine Learning) ==============
"scikit-learn": "sklearn", # 机器学习库
"scikit-image": "skimage", # 图像处理库
"opencv-python": "cv2", # OpenCV 计算机视觉库
"opencv-contrib-python": "cv2", # OpenCV 扩展模块
"tensorflow-gpu": "tensorflow", # TensorFlow GPU版本
"tensorboardx": "tensorboardX", # TensorBoard 的封装
"torchvision": "torchvision", # PyTorch 视觉库 (通常与 torch 一起)
"torchaudio": "torchaudio", # PyTorch 音频库
"catboost": "catboost", # CatBoost 梯度提升库
"lightgbm": "lightgbm", # LightGBM 梯度提升库
"xgboost": "xgboost", # XGBoost 梯度提升库
"imbalanced-learn": "imblearn", # 处理不平衡数据集
"seqeval": "seqeval", # 序列标注评估
"gensim": "gensim", # 主题建模和NLP
"nltk": "nltk", # 自然语言工具包
"spacy": "spacy", # 工业级自然语言处理
"fuzzywuzzy": "fuzzywuzzy", # 模糊字符串匹配
"python-levenshtein": "Levenshtein", # Levenshtein 距离计算
"scikit-learn": "sklearn", # 机器学习库
"scikit-image": "skimage", # 图像处理库
"opencv-python": "cv2", # OpenCV 计算机视觉库
"opencv-contrib-python": "cv2", # OpenCV 扩展模块
"tensorflow-gpu": "tensorflow", # TensorFlow GPU版本
"tensorboardx": "tensorboardX", # TensorBoard 的封装
"torchvision": "torchvision", # PyTorch 视觉库 (通常与 torch 一起)
"torchaudio": "torchaudio", # PyTorch 音频库
"catboost": "catboost", # CatBoost 梯度提升库
"lightgbm": "lightgbm", # LightGBM 梯度提升库
"xgboost": "xgboost", # XGBoost 梯度提升库
"imbalanced-learn": "imblearn", # 处理不平衡数据集
"seqeval": "seqeval", # 序列标注评估
"gensim": "gensim", # 主题建模和NLP
"nltk": "nltk", # 自然语言工具包
"spacy": "spacy", # 工业级自然语言处理
"fuzzywuzzy": "fuzzywuzzy", # 模糊字符串匹配
"python-levenshtein": "Levenshtein", # Levenshtein 距离计算
# ============== Web开发与API (Web Development & API) ==============
"python-socketio": "socketio", # Socket.IO 服务器和客户端
"python-engineio": "engineio", # Engine.IO 底层库
"aiohttp": "aiohttp", # 异步HTTP客户端/服务器
"python-multipart": "multipart", # 解析 multipart/form-data
"uvloop": "uvloop", # 高性能asyncio事件循环
"httptools": "httptools", # 高性能HTTP解析器
"websockets": "websockets", # WebSocket实现
"fastapi": "fastapi", # 高性能Web框架
"starlette": "starlette", # ASGI框架
"uvicorn": "uvicorn", # ASGI服务器
"gunicorn": "gunicorn", # WSGI服务器
"django-rest-framework": "rest_framework", # Django REST框架
"django-cors-headers": "corsheaders", # Django CORS处理
"flask-jwt-extended": "flask_jwt_extended", # Flask JWT扩展
"flask-sqlalchemy": "flask_sqlalchemy", # Flask SQLAlchemy扩展
"flask-migrate": "flask_migrate", # Flask Alembic迁移扩展
"python-jose": "jose", # JOSE (JWT, JWS, JWE) 实现
"passlib": "passlib", # 密码哈希库
"bcrypt": "bcrypt", # Bcrypt密码哈希
"python-socketio": "socketio", # Socket.IO 服务器和客户端
"python-engineio": "engineio", # Engine.IO 底层库
"aiohttp": "aiohttp", # 异步HTTP客户端/服务器
"python-multipart": "multipart", # 解析 multipart/form-data
"uvloop": "uvloop", # 高性能asyncio事件循环
"httptools": "httptools", # 高性能HTTP解析器
"websockets": "websockets", # WebSocket实现
"fastapi": "fastapi", # 高性能Web框架
"starlette": "starlette", # ASGI框架
"uvicorn": "uvicorn", # ASGI服务器
"gunicorn": "gunicorn", # WSGI服务器
"django-rest-framework": "rest_framework", # Django REST框架
"django-cors-headers": "corsheaders", # Django CORS处理
"flask-jwt-extended": "flask_jwt_extended", # Flask JWT扩展
"flask-sqlalchemy": "flask_sqlalchemy", # Flask SQLAlchemy扩展
"flask-migrate": "flask_migrate", # Flask Alembic迁移扩展
"python-jose": "jose", # JOSE (JWT, JWS, JWE) 实现
"passlib": "passlib", # 密码哈希库
"bcrypt": "bcrypt", # Bcrypt密码哈希
# ============== 数据库 (Database) ==============
"mysql-connector-python": "mysql.connector", # MySQL官方驱动
"psycopg2-binary": "psycopg2", # PostgreSQL驱动 (二进制)
"pymongo": "pymongo", # MongoDB驱动
"redis": "redis", # Redis客户端
"aioredis": "aioredis", # 异步Redis客户端
"sqlalchemy": "sqlalchemy", # SQL工具包和ORM
"alembic": "alembic", # SQLAlchemy数据库迁移工具
"tortoise-orm": "tortoise", # 异步ORM
"mysql-connector-python": "mysql.connector", # MySQL官方驱动
"psycopg2-binary": "psycopg2", # PostgreSQL驱动 (二进制)
"pymongo": "pymongo", # MongoDB驱动
"redis": "redis", # Redis客户端
"aioredis": "aioredis", # 异步Redis客户端
"sqlalchemy": "sqlalchemy", # SQL工具包和ORM
"alembic": "alembic", # SQLAlchemy数据库迁移工具
"tortoise-orm": "tortoise", # 异步ORM
# ============== 图像与多媒体 (Image & Multimedia) ==============
"Pillow": "PIL", # Python图像处理库 (PIL Fork)
"moviepy": "moviepy", # 视频编辑库
"pydub": "pydub", # 音频处理库
"pycairo": "cairo", # Cairo 2D图形库的Python绑定
"wand": "wand", # ImageMagick的Python绑定
"Pillow": "PIL", # Python图像处理库 (PIL Fork)
"moviepy": "moviepy", # 视频编辑库
"pydub": "pydub", # 音频处理库
"pycairo": "cairo", # Cairo 2D图形库的Python绑定
"wand": "wand", # ImageMagick的Python绑定
# ============== 解析与序列化 (Parsing & Serialization) ==============
"beautifulsoup4": "bs4", # HTML/XML解析库
"lxml": "lxml", # 高性能HTML/XML解析库
"PyYAML": "yaml", # YAML解析库
"python-dotenv": "dotenv", # .env文件解析
"python-dateutil": "dateutil", # 强大的日期时间解析
"protobuf": "google.protobuf", # Protocol Buffers
"msgpack": "msgpack", # MessagePack序列化
"orjson": "orjson", # 高性能JSON库
"pydantic": "pydantic", # 数据验证和设置管理
"beautifulsoup4": "bs4", # HTML/XML解析库
"lxml": "lxml", # 高性能HTML/XML解析库
"PyYAML": "yaml", # YAML解析库
"python-dotenv": "dotenv", # .env文件解析
"python-dateutil": "dateutil", # 强大的日期时间解析
"protobuf": "google.protobuf", # Protocol Buffers
"msgpack": "msgpack", # MessagePack序列化
"orjson": "orjson", # 高性能JSON库
"pydantic": "pydantic", # 数据验证和设置管理
# ============== 系统与硬件 (System & Hardware) ==============
"pyserial": "serial", # 串口通信
"pyusb": "usb", # USB访问
"pybluez": "bluetooth", # 蓝牙通信 (可能因平台而异)
"psutil": "psutil", # 系统信息和进程管理
"watchdog": "watchdog", # 文件系统事件监控
"python-gnupg": "gnupg", # GnuPG的Python接口
"pyserial": "serial", # 串口通信
"pyusb": "usb", # USB访问
"pybluez": "bluetooth", # 蓝牙通信 (可能因平台而异)
"psutil": "psutil", # 系统信息和进程管理
"watchdog": "watchdog", # 文件系统事件监控
"python-gnupg": "gnupg", # GnuPG的Python接口
# ============== 加密与安全 (Cryptography & Security) ==============
"pycrypto": "Crypto", # 加密库 (较旧)
"pycryptodome": "Crypto", # PyCrypto的现代分支
"cryptography": "cryptography", # 现代加密库
"pyopenssl": "OpenSSL", # OpenSSL的Python接口
"service-identity": "service_identity", # 服务身份验证
"pycrypto": "Crypto", # 加密库 (较旧)
"pycryptodome": "Crypto", # PyCrypto的现代分支
"cryptography": "cryptography", # 现代加密库
"pyopenssl": "OpenSSL", # OpenSSL的Python接口
"service-identity": "service_identity", # 服务身份验证
# ============== 工具与杂项 (Utilities & Miscellaneous) ==============
"setuptools": "setuptools", # 打包工具
"pip": "pip", # 包安装器
"tqdm": "tqdm", # 进度条
"regex": "regex", # 替代的正则表达式引擎
"colorama": "colorama", # 跨平台彩色终端文本
"termcolor": "termcolor", # 终端颜色格式化
"requests-oauthlib": "requests_oauthlib", # OAuth for Requests
"oauthlib": "oauthlib", # 通用OAuth库
"authlib": "authlib", # OAuth和OpenID Connect客户端/服务器
"pyjwt": "jwt", # JSON Web Token实现
"python-editor": "editor", # 程序化地调用编辑器
"prompt-toolkit": "prompt_toolkit", # 构建交互式命令行
"pygments": "pygments", # 语法高亮
"tabulate": "tabulate", # 生成漂亮的表格
"nats-client": "nats", # NATS客户端
"gitpython": "git", # Git的Python接口
"pygithub": "github", # GitHub API v3的Python接口
"python-gitlab": "gitlab", # GitLab API的Python接口
"jira": "jira", # JIRA API的Python接口
"python-jenkins": "jenkins", # Jenkins API的Python接口
"huggingface-hub": "huggingface_hub", # Hugging Face Hub API
"apache-airflow": "airflow", # Airflow工作流管理
"pandas-stubs": "pandas-stubs", # Pandas的类型存根
"data-science-types": "data_science_types", # 数据科学类型
}
"setuptools": "setuptools", # 打包工具
"pip": "pip", # 包安装器
"tqdm": "tqdm", # 进度条
"regex": "regex", # 替代的正则表达式引擎
"colorama": "colorama", # 跨平台彩色终端文本
"termcolor": "termcolor", # 终端颜色格式化
"requests-oauthlib": "requests_oauthlib", # OAuth for Requests
"oauthlib": "oauthlib", # 通用OAuth库
"authlib": "authlib", # OAuth和OpenID Connect客户端/服务器
"pyjwt": "jwt", # JSON Web Token实现
"python-editor": "editor", # 程序化地调用编辑器
"prompt-toolkit": "prompt_toolkit", # 构建交互式命令行
"pygments": "pygments", # 语法高亮
"tabulate": "tabulate", # 生成漂亮的表格
"nats-client": "nats", # NATS客户端
"gitpython": "git", # Git的Python接口
"pygithub": "github", # GitHub API v3的Python接口
"python-gitlab": "gitlab", # GitLab API的Python接口
"jira": "jira", # JIRA API的Python接口
"python-jenkins": "jenkins", # Jenkins API的Python接口
"huggingface-hub": "huggingface_hub", # Hugging Face Hub API
"apache-airflow": "airflow", # Airflow工作流管理
"pandas-stubs": "pandas-stubs", # Pandas的类型存根
"data-science-types": "data_science_types", # 数据科学类型
}

View File

@@ -6,62 +6,61 @@ logger = get_logger("dependency_config")
class DependencyConfig:
"""依赖管理配置类 - 现在使用全局配置"""
def __init__(self, global_config=None):
self._global_config = global_config
def _get_config(self):
"""获取全局配置对象"""
if self._global_config is not None:
return self._global_config
# 延迟导入以避免循环依赖
try:
from src.config.config import global_config
return global_config
except ImportError:
logger.warning("无法导入全局配置,使用默认设置")
return None
@property
def auto_install(self) -> bool:
"""是否启用自动安装"""
config = self._get_config()
if config and hasattr(config, 'dependency_management'):
if config and hasattr(config, "dependency_management"):
return config.dependency_management.auto_install
return True
@property
def use_mirror(self) -> bool:
"""是否使用PyPI镜像源"""
config = self._get_config()
if config and hasattr(config, 'dependency_management'):
if config and hasattr(config, "dependency_management"):
return config.dependency_management.use_mirror
return False
@property
def mirror_url(self) -> str:
"""PyPI镜像源URL"""
config = self._get_config()
if config and hasattr(config, 'dependency_management'):
if config and hasattr(config, "dependency_management"):
return config.dependency_management.mirror_url
return ""
@property
def install_timeout(self) -> int:
"""安装超时时间(秒)"""
config = self._get_config()
if config and hasattr(config, 'dependency_management'):
if config and hasattr(config, "dependency_management"):
return config.dependency_management.auto_install_timeout
return 300
@property
def prompt_before_install(self) -> bool:
"""安装前是否提示用户"""
config = self._get_config()
if config and hasattr(config, 'dependency_management'):
if config and hasattr(config, "dependency_management"):
return config.dependency_management.prompt_before_install
return False
@@ -82,4 +81,4 @@ def configure_dependency_settings(**kwargs) -> None:
"""配置依赖管理设置 - 注意这个函数现在仅用于兼容性实际配置需要修改bot_config.toml"""
logger.info("依赖管理设置现在通过 bot_config.toml 的 [dependency_management] 节进行配置")
logger.info(f"请求的配置更改: {kwargs}")
logger.warning("configure_dependency_settings 函数仅用于兼容性,配置更改不会持久化")
logger.warning("configure_dependency_settings 函数仅用于兼容性,配置更改不会持久化")

View File

@@ -15,13 +15,13 @@ logger = get_logger("dependency_manager")
class DependencyManager:
"""Python包依赖管理器
负责检查和自动安装插件的Python包依赖
"""
def __init__(self, auto_install: bool = True, use_mirror: bool = False, mirror_url: Optional[str] = None):
"""初始化依赖管理器
Args:
auto_install: 是否自动安装缺失的依赖
use_mirror: 是否使用PyPI镜像源
@@ -30,38 +30,39 @@ class DependencyManager:
# 延迟导入配置以避免循环依赖
try:
from src.plugin_system.utils.dependency_config import get_dependency_config
config = get_dependency_config()
# 优先使用配置文件中的设置,参数作为覆盖
self.auto_install = config.auto_install if auto_install is True else auto_install
self.use_mirror = config.use_mirror if use_mirror is False else use_mirror
self.mirror_url = config.mirror_url if mirror_url is None else mirror_url
self.install_timeout = config.install_timeout
except Exception as e:
logger.warning(f"无法加载依赖配置,使用默认设置: {e}")
self.auto_install = auto_install
self.use_mirror = use_mirror or False
self.mirror_url = mirror_url or ""
self.install_timeout = 300
def check_dependencies(self, dependencies: Any, plugin_name: str = "") -> Tuple[bool, List[str], List[str]]:
"""检查依赖包是否满足要求
Args:
dependencies: 依赖列表支持字符串或PythonDependency对象
plugin_name: 插件名称,用于日志记录
Returns:
Tuple[bool, List[str], List[str]]: (是否全部满足, 缺失的包, 错误信息)
"""
missing_packages = []
error_messages = []
log_prefix = f"[Plugin:{plugin_name}] " if plugin_name else ""
# 标准化依赖格式
normalized_deps = self._normalize_dependencies(dependencies)
for dep in normalized_deps:
try:
if not self._check_single_dependency(dep):
@@ -71,38 +72,40 @@ class DependencyManager:
error_msg = f"检查依赖 {dep.package_name} 时发生错误: {str(e)}"
error_messages.append(error_msg)
logger.error(f"{log_prefix}{error_msg}")
all_satisfied = len(missing_packages) == 0 and len(error_messages) == 0
if all_satisfied:
logger.debug(f"{log_prefix}所有Python依赖检查通过")
else:
logger.warning(f"{log_prefix}Python依赖检查失败: 缺失{len(missing_packages)}个包, {len(error_messages)}个错误")
logger.warning(
f"{log_prefix}Python依赖检查失败: 缺失{len(missing_packages)}个包, {len(error_messages)}个错误"
)
return all_satisfied, missing_packages, error_messages
def install_dependencies(self, packages: List[str], plugin_name: str = "") -> Tuple[bool, List[str]]:
"""自动安装缺失的依赖包
Args:
packages: 要安装的包列表
plugin_name: 插件名称,用于日志记录
Returns:
Tuple[bool, List[str]]: (是否全部安装成功, 失败的包列表)
"""
if not packages:
return True, []
if not self.auto_install:
logger.info(f"[Plugin:{plugin_name}] 自动安装已禁用,跳过安装: {packages}")
return False, packages
log_prefix = f"[Plugin:{plugin_name}] " if plugin_name else ""
logger.info(f"{log_prefix}开始自动安装Python依赖: {packages}")
failed_packages = []
for package in packages:
try:
if self._install_single_package(package, plugin_name):
@@ -113,37 +116,37 @@ class DependencyManager:
except Exception as e:
failed_packages.append(package)
logger.error(f"{log_prefix}❌ 安装 {package} 时发生异常: {str(e)}")
success = len(failed_packages) == 0
if success:
logger.info(f"{log_prefix}🎉 所有依赖安装完成")
else:
logger.error(f"{log_prefix}⚠️ 部分依赖安装失败: {failed_packages}")
return success, failed_packages
def check_and_install_dependencies(self, dependencies: Any, plugin_name: str = "") -> Tuple[bool, List[str]]:
"""检查并自动安装依赖(组合操作)
Args:
dependencies: 依赖列表
plugin_name: 插件名称
Returns:
Tuple[bool, List[str]]: (是否全部满足, 错误信息列表)
"""
# 第一步:检查依赖
all_satisfied, missing_packages, check_errors = self.check_dependencies(dependencies, plugin_name)
if all_satisfied:
return True, []
all_errors = check_errors.copy()
# 第二步:尝试安装缺失的包
if missing_packages and self.auto_install:
install_success, failed_packages = self.install_dependencies(missing_packages, plugin_name)
if not install_success:
all_errors.extend([f"安装失败: {pkg}" for pkg in failed_packages])
else:
@@ -156,13 +159,13 @@ class DependencyManager:
return True, []
else:
all_errors.extend([f"缺失依赖: {pkg}" for pkg in missing_packages])
return False, all_errors
def _normalize_dependencies(self, dependencies: Any) -> List[PythonDependency]:
"""将依赖列表标准化为PythonDependency对象"""
normalized = []
for dep in dependencies:
if isinstance(dep, str):
# 解析字符串格式的依赖
@@ -170,28 +173,27 @@ class DependencyManager:
# 尝试解析为requirement格式 (如 "package>=1.0.0")
req = Requirement(dep)
version_spec = str(req.specifier) if req.specifier else ""
normalized.append(PythonDependency(
package_name=req.name,
version=version_spec,
install_name=dep # 保持原始的安装名称
))
normalized.append(
PythonDependency(
package_name=req.name,
version=version_spec,
install_name=dep, # 保持原始的安装名称
)
)
except Exception:
# 如果解析失败,作为简单包名处理
normalized.append(PythonDependency(
package_name=dep,
install_name=dep
))
normalized.append(PythonDependency(package_name=dep, install_name=dep))
elif isinstance(dep, PythonDependency):
normalized.append(dep)
else:
logger.warning(f"未知的依赖格式: {dep}")
return normalized
def _check_single_dependency(self, dep: PythonDependency) -> bool:
"""检查单个依赖是否满足要求"""
def _try_check(import_name: str) -> bool:
"""尝试使用给定的导入名进行检查"""
try:
@@ -206,11 +208,11 @@ class DependencyManager:
# 检查版本要求
try:
module = importlib.import_module(import_name)
installed_version = getattr(module, '__version__', None)
installed_version = getattr(module, "__version__", None)
if installed_version is None:
# 尝试其他常见的版本属性
installed_version = getattr(module, 'VERSION', None)
installed_version = getattr(module, "VERSION", None)
if installed_version is None:
logger.debug(f"无法获取包 {import_name} 的版本信息,假设满足要求")
return True
@@ -243,33 +245,27 @@ class DependencyManager:
# 3. 如果别名也失败了,或者没有别名,最终确认失败
return False
def _install_single_package(self, package: str, plugin_name: str = "") -> bool:
"""安装单个包"""
try:
cmd = [sys.executable, "-m", "pip", "install", package]
# 添加镜像源设置
if self.use_mirror and self.mirror_url:
cmd.extend(["-i", self.mirror_url])
logger.debug(f"[Plugin:{plugin_name}] 使用PyPI镜像源: {self.mirror_url}")
logger.debug(f"[Plugin:{plugin_name}] 执行安装命令: {' '.join(cmd)}")
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=self.install_timeout,
check=False
)
result = subprocess.run(cmd, capture_output=True, text=True, timeout=self.install_timeout, check=False)
if result.returncode == 0:
return True
else:
logger.error(f"[Plugin:{plugin_name}] pip安装失败: {result.stderr}")
return False
except subprocess.TimeoutExpired:
logger.error(f"[Plugin:{plugin_name}] 安装 {package} 超时")
return False
@@ -294,7 +290,5 @@ def configure_dependency_manager(auto_install: bool = True, use_mirror: bool = F
"""配置全局依赖管理器"""
global _global_dependency_manager
_global_dependency_manager = DependencyManager(
auto_install=auto_install,
use_mirror=use_mirror,
mirror_url=mirror_url
)
auto_install=auto_install, use_mirror=use_mirror, mirror_url=mirror_url
)

View File

@@ -19,65 +19,64 @@ logger = get_logger(__name__)
def require_permission(permission_node: str, deny_message: Optional[str] = None):
"""
权限检查装饰器
用于装饰需要特定权限才能执行的函数。如果用户没有权限,会发送拒绝消息并阻止函数执行。
Args:
permission_node: 所需的权限节点名称
deny_message: 权限不足时的提示消息如果为None则使用默认消息
Example:
@require_permission("plugin.example.admin")
async def admin_command(message: Message, chat_stream: ChatStream):
# 只有拥有 plugin.example.admin 权限的用户才能执行
pass
"""
def decorator(func: Callable):
@wraps(func)
async def async_wrapper(*args, **kwargs):
# 尝试从参数中提取 ChatStream 对象
chat_stream = None
# 首先检查位置参数中的 ChatStream
for arg in args:
if isinstance(arg, ChatStream):
chat_stream = arg
break
# 如果在位置参数中没找到,尝试从关键字参数中查找
if chat_stream is None:
chat_stream = kwargs.get('chat_stream')
chat_stream = kwargs.get("chat_stream")
# 如果还没找到,检查是否是 PlusCommand 方法调用
if chat_stream is None and args:
# 检查第一个参数是否有 message.chat_stream 属性PlusCommand 实例)
instance = args[0]
if hasattr(instance, 'message') and hasattr(instance.message, 'chat_stream'):
if hasattr(instance, "message") and hasattr(instance.message, "chat_stream"):
chat_stream = instance.message.chat_stream
if chat_stream is None:
logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
return
# 检查权限
has_permission = permission_api.check_permission(
chat_stream.platform,
chat_stream.user_info.user_id,
permission_node
chat_stream.platform, chat_stream.user_info.user_id, permission_node
)
if not has_permission:
# 权限不足,发送拒绝消息
message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}"
await text_to_stream(message, chat_stream.stream_id)
# 对于PlusCommand的execute方法需要返回适当的元组
if func.__name__ == 'execute' and hasattr(args[0], 'send_text'):
if func.__name__ == "execute" and hasattr(args[0], "send_text"):
return False, "权限不足", True
return
# 权限检查通过,执行原函数
return await func(*args, **kwargs)
def sync_wrapper(*args, **kwargs):
# 对于同步函数,我们不能发送异步消息,只能记录日志
chat_stream = None
@@ -85,95 +84,93 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
if isinstance(arg, ChatStream):
chat_stream = arg
break
if chat_stream is None:
chat_stream = kwargs.get('chat_stream')
chat_stream = kwargs.get("chat_stream")
if chat_stream is None:
logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
return
# 检查权限
has_permission = permission_api.check_permission(
chat_stream.platform,
chat_stream.user_info.user_id,
permission_node
chat_stream.platform, chat_stream.user_info.user_id, permission_node
)
if not has_permission:
logger.warning(f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 没有权限 {permission_node}")
logger.warning(
f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 没有权限 {permission_node}"
)
return
# 权限检查通过,执行原函数
return func(*args, **kwargs)
# 根据函数类型选择包装器
if iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator
def require_master(deny_message: Optional[str] = None):
"""
Master权限检查装饰器
用于装饰只有Master用户才能执行的函数。
Args:
deny_message: 权限不足时的提示消息如果为None则使用默认消息
Example:
@require_master()
async def master_only_command(message: Message, chat_stream: ChatStream):
# 只有Master用户才能执行
pass
"""
def decorator(func: Callable):
@wraps(func)
async def async_wrapper(*args, **kwargs):
# 尝试从参数中提取 ChatStream 对象
chat_stream = None
# 首先检查位置参数中的 ChatStream
for arg in args:
if isinstance(arg, ChatStream):
chat_stream = arg
break
# 如果在位置参数中没找到,尝试从关键字参数中查找
if chat_stream is None:
chat_stream = kwargs.get('chat_stream')
chat_stream = kwargs.get("chat_stream")
# 如果还没找到,检查是否是 PlusCommand 方法调用
if chat_stream is None and args:
# 检查第一个参数是否有 message.chat_stream 属性PlusCommand 实例)
instance = args[0]
if hasattr(instance, 'message') and hasattr(instance.message, 'chat_stream'):
if hasattr(instance, "message") and hasattr(instance.message, "chat_stream"):
chat_stream = instance.message.chat_stream
if chat_stream is None:
logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
return
# 检查是否为Master用户
is_master = permission_api.is_master(
chat_stream.platform,
chat_stream.user_info.user_id
)
is_master = permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id)
if not is_master:
message = deny_message or "❌ 此操作仅限Master用户执行"
await text_to_stream(message, chat_stream.stream_id)
if func.__name__ == 'execute' and hasattr(args[0], 'send_text'):
if func.__name__ == "execute" and hasattr(args[0], "send_text"):
return False, "需要Master权限", True
return
# 权限检查通过,执行原函数
return await func(*args, **kwargs)
def sync_wrapper(*args, **kwargs):
# 对于同步函数,我们不能发送异步消息,只能记录日志
chat_stream = None
@@ -181,116 +178,106 @@ def require_master(deny_message: Optional[str] = None):
if isinstance(arg, ChatStream):
chat_stream = arg
break
if chat_stream is None:
chat_stream = kwargs.get('chat_stream')
chat_stream = kwargs.get("chat_stream")
if chat_stream is None:
logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
return
# 检查是否为Master用户
is_master = permission_api.is_master(
chat_stream.platform,
chat_stream.user_info.user_id
)
is_master = permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id)
if not is_master:
logger.warning(f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 不是Master用户")
return
# 权限检查通过,执行原函数
return func(*args, **kwargs)
# 根据函数类型选择包装器
if iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator
class PermissionChecker:
"""
权限检查工具类
提供一些便捷的权限检查方法,用于在代码中进行权限验证。
"""
@staticmethod
def check_permission(chat_stream: ChatStream, permission_node: str) -> bool:
"""
检查用户是否拥有指定权限
Args:
chat_stream: 聊天流对象
permission_node: 权限节点名称
Returns:
bool: 是否拥有权限
"""
return permission_api.check_permission(
chat_stream.platform,
chat_stream.user_info.user_id,
permission_node
)
return permission_api.check_permission(chat_stream.platform, chat_stream.user_info.user_id, permission_node)
@staticmethod
def is_master(chat_stream: ChatStream) -> bool:
"""
检查用户是否为Master用户
Args:
chat_stream: 聊天流对象
Returns:
bool: 是否为Master用户
"""
return permission_api.is_master(
chat_stream.platform,
chat_stream.user_info.user_id
)
return permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id)
@staticmethod
async def ensure_permission(chat_stream: ChatStream, permission_node: str,
deny_message: Optional[str] = None) -> bool:
async def ensure_permission(
chat_stream: ChatStream, permission_node: str, deny_message: Optional[str] = None
) -> bool:
"""
确保用户拥有指定权限如果没有权限会发送消息并返回False
Args:
chat_stream: 聊天流对象
permission_node: 权限节点名称
deny_message: 权限不足时的提示消息
Returns:
bool: 是否拥有权限
"""
has_permission = PermissionChecker.check_permission(chat_stream, permission_node)
if not has_permission:
message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}"
await text_to_stream(message, chat_stream.stream_id)
return has_permission
@staticmethod
async def ensure_master(chat_stream: ChatStream,
deny_message: Optional[str] = None) -> bool:
async def ensure_master(chat_stream: ChatStream, deny_message: Optional[str] = None) -> bool:
"""
确保用户为Master用户如果不是会发送消息并返回False
Args:
chat_stream: 聊天流对象
deny_message: 权限不足时的提示消息
Returns:
bool: 是否为Master用户
"""
is_master = PermissionChecker.is_master(chat_stream)
if not is_master:
message = deny_message or "❌ 此操作仅限Master用户执行"
await text_to_stream(message, chat_stream.stream_id)
return is_master