修复代码格式和文件名大小写问题
This commit is contained in:
@@ -48,7 +48,7 @@ from .utils.dependency_config import get_dependency_config, configure_dependency
|
||||
|
||||
from .apis import (
|
||||
chat_api,
|
||||
tool_api,
|
||||
tool_api,
|
||||
component_manage_api,
|
||||
config_api,
|
||||
database_api,
|
||||
@@ -91,8 +91,8 @@ __all__ = [
|
||||
# 增强命令系统
|
||||
"PlusCommand",
|
||||
"CommandArgs",
|
||||
"PlusCommandAdapter",
|
||||
"create_plus_command_adapter",
|
||||
"PlusCommandAdapter",
|
||||
"create_plus_command_adapter",
|
||||
"create_plus_command_adapter",
|
||||
# 类型定义
|
||||
"ComponentType",
|
||||
|
||||
@@ -9,19 +9,7 @@
|
||||
注意:此模块现在使用SQLAlchemy实现,提供更好的连接管理和错误处理
|
||||
"""
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import (
|
||||
db_query,
|
||||
db_save,
|
||||
db_get,
|
||||
store_action_info,
|
||||
MODEL_MAPPING
|
||||
)
|
||||
from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get, store_action_info, MODEL_MAPPING
|
||||
|
||||
# 保持向后兼容性
|
||||
__all__ = [
|
||||
'db_query',
|
||||
'db_save',
|
||||
'db_get',
|
||||
'store_action_info',
|
||||
'MODEL_MAPPING'
|
||||
]
|
||||
__all__ = ["db_query", "db_save", "db_get", "store_action_info", "MODEL_MAPPING"]
|
||||
|
||||
@@ -72,7 +72,9 @@ async def generate_with_model(
|
||||
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens)
|
||||
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(
|
||||
prompt, temperature=temperature, max_tokens=max_tokens
|
||||
)
|
||||
return True, response, reasoning_content, model_name
|
||||
|
||||
except Exception as e:
|
||||
@@ -80,6 +82,7 @@ async def generate_with_model(
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
|
||||
|
||||
async def generate_with_model_with_tools(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
@@ -109,10 +112,7 @@ async def generate_with_model_with_tools(
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
|
||||
prompt,
|
||||
tools=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
prompt, tools=tool_options, temperature=temperature, max_tokens=max_tokens
|
||||
)
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
|
||||
@@ -97,7 +97,9 @@ def get_messages_by_time_in_chat(
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command))
|
||||
return filter_mai_messages(
|
||||
get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
)
|
||||
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
|
||||
|
||||
@@ -137,9 +139,13 @@ def get_messages_by_time_in_chat_inclusive(
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id, start_time, end_time, limit, limit_mode, filter_command
|
||||
)
|
||||
)
|
||||
return get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
return get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id, start_time, end_time, limit, limit_mode, filter_command
|
||||
)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat_for_users(
|
||||
|
||||
@@ -17,12 +17,14 @@ logger = get_logger(__name__)
|
||||
|
||||
class PermissionLevel(Enum):
|
||||
"""权限等级枚举"""
|
||||
|
||||
MASTER = "master" # 最高权限,无视所有权限节点
|
||||
|
||||
|
||||
@dataclass
|
||||
class PermissionNode:
|
||||
"""权限节点数据类"""
|
||||
|
||||
node_name: str # 权限节点名称,如 "plugin.example.command.test"
|
||||
description: str # 权限节点描述
|
||||
plugin_name: str # 所属插件名称
|
||||
@@ -32,13 +34,14 @@ class PermissionNode:
|
||||
@dataclass
|
||||
class UserInfo:
|
||||
"""用户信息数据类"""
|
||||
|
||||
platform: str # 平台类型,如 "qq"
|
||||
user_id: str # 用户ID
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
"""确保user_id是字符串类型"""
|
||||
self.user_id = str(self.user_id)
|
||||
|
||||
|
||||
def to_tuple(self) -> tuple[str, str]:
|
||||
"""转换为元组格式"""
|
||||
return (self.platform, self.user_id)
|
||||
@@ -46,106 +49,106 @@ class UserInfo:
|
||||
|
||||
class IPermissionManager(ABC):
|
||||
"""权限管理器接口"""
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def check_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||
"""
|
||||
检查用户是否拥有指定权限节点
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否拥有权限
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def is_master(self, user: UserInfo) -> bool:
|
||||
"""
|
||||
检查用户是否为Master用户
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否为Master用户
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def register_permission_node(self, node: PermissionNode) -> bool:
|
||||
"""
|
||||
注册权限节点
|
||||
|
||||
|
||||
Args:
|
||||
node: 权限节点
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 注册是否成功
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def grant_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||
"""
|
||||
授权用户权限节点
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 授权是否成功
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||
"""
|
||||
撤销用户权限节点
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 撤销是否成功
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_user_permissions(self, user: UserInfo) -> List[str]:
|
||||
"""
|
||||
获取用户拥有的所有权限节点
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 权限节点列表
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_all_permission_nodes(self) -> List[PermissionNode]:
|
||||
"""
|
||||
获取所有已注册的权限节点
|
||||
|
||||
|
||||
Returns:
|
||||
List[PermissionNode]: 权限节点列表
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
|
||||
"""
|
||||
获取指定插件的所有权限节点
|
||||
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
|
||||
Returns:
|
||||
List[PermissionNode]: 权限节点列表
|
||||
"""
|
||||
@@ -154,146 +157,144 @@ class IPermissionManager(ABC):
|
||||
|
||||
class PermissionAPI:
|
||||
"""权限系统API类"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._permission_manager: Optional[IPermissionManager] = None
|
||||
|
||||
|
||||
def set_permission_manager(self, manager: IPermissionManager):
|
||||
"""设置权限管理器实例"""
|
||||
self._permission_manager = manager
|
||||
logger.info("权限管理器已设置")
|
||||
|
||||
|
||||
def _ensure_manager(self):
|
||||
"""确保权限管理器已设置"""
|
||||
if self._permission_manager is None:
|
||||
raise RuntimeError("权限管理器未设置,请先调用 set_permission_manager")
|
||||
|
||||
|
||||
def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||
"""
|
||||
检查用户是否拥有指定权限节点
|
||||
|
||||
|
||||
Args:
|
||||
platform: 平台类型,如 "qq"
|
||||
user_id: 用户ID
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否拥有权限
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
self._ensure_manager()
|
||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||
return self._permission_manager.check_permission(user, permission_node)
|
||||
|
||||
|
||||
def is_master(self, platform: str, user_id: str) -> bool:
|
||||
"""
|
||||
检查用户是否为Master用户
|
||||
|
||||
|
||||
Args:
|
||||
platform: 平台类型,如 "qq"
|
||||
user_id: 用户ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否为Master用户
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
self._ensure_manager()
|
||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||
return self._permission_manager.is_master(user)
|
||||
|
||||
def register_permission_node(self, node_name: str, description: str, plugin_name: str,
|
||||
default_granted: bool = False) -> bool:
|
||||
|
||||
def register_permission_node(
|
||||
self, node_name: str, description: str, plugin_name: str, default_granted: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
注册权限节点
|
||||
|
||||
|
||||
Args:
|
||||
node_name: 权限节点名称,如 "plugin.example.command.test"
|
||||
description: 权限节点描述
|
||||
plugin_name: 所属插件名称
|
||||
default_granted: 默认是否授权
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 注册是否成功
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
self._ensure_manager()
|
||||
node = PermissionNode(
|
||||
node_name=node_name,
|
||||
description=description,
|
||||
plugin_name=plugin_name,
|
||||
default_granted=default_granted
|
||||
node_name=node_name, description=description, plugin_name=plugin_name, default_granted=default_granted
|
||||
)
|
||||
return self._permission_manager.register_permission_node(node)
|
||||
|
||||
|
||||
def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||
"""
|
||||
授权用户权限节点
|
||||
|
||||
|
||||
Args:
|
||||
platform: 平台类型,如 "qq"
|
||||
user_id: 用户ID
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 授权是否成功
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
self._ensure_manager()
|
||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||
return self._permission_manager.grant_permission(user, permission_node)
|
||||
|
||||
|
||||
def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||
"""
|
||||
撤销用户权限节点
|
||||
|
||||
|
||||
Args:
|
||||
platform: 平台类型,如 "qq"
|
||||
user_id: 用户ID
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 撤销是否成功
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
self._ensure_manager()
|
||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||
return self._permission_manager.revoke_permission(user, permission_node)
|
||||
|
||||
|
||||
def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
|
||||
"""
|
||||
获取用户拥有的所有权限节点
|
||||
|
||||
|
||||
Args:
|
||||
platform: 平台类型,如 "qq"
|
||||
user_id: 用户ID
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 权限节点列表
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
self._ensure_manager()
|
||||
user = UserInfo(platform=platform, user_id=str(user_id))
|
||||
return self._permission_manager.get_user_permissions(user)
|
||||
|
||||
|
||||
def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取所有已注册的权限节点
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 权限节点列表,每个节点包含 node_name, description, plugin_name, default_granted
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
@@ -304,21 +305,21 @@ class PermissionAPI:
|
||||
"node_name": node.node_name,
|
||||
"description": node.description,
|
||||
"plugin_name": node.plugin_name,
|
||||
"default_granted": node.default_granted
|
||||
"default_granted": node.default_granted,
|
||||
}
|
||||
for node in nodes
|
||||
]
|
||||
|
||||
|
||||
def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定插件的所有权限节点
|
||||
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 权限节点列表
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: 权限管理器未设置时抛出
|
||||
"""
|
||||
@@ -329,7 +330,7 @@ class PermissionAPI:
|
||||
"node_name": node.node_name,
|
||||
"description": node.description,
|
||||
"plugin_name": node.plugin_name,
|
||||
"default_granted": node.default_granted
|
||||
"default_granted": node.default_granted,
|
||||
}
|
||||
for node in nodes
|
||||
]
|
||||
|
||||
@@ -34,7 +34,7 @@ def get_plugin_path(plugin_name: str) -> str:
|
||||
|
||||
Returns:
|
||||
str: 插件目录的绝对路径。
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 如果插件不存在。
|
||||
"""
|
||||
|
||||
@@ -2,7 +2,7 @@ from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plugin_manager") # 复用plugin_manager名称
|
||||
logger = get_logger("plugin_manager") # 复用plugin_manager名称
|
||||
|
||||
|
||||
def register_plugin(cls):
|
||||
|
||||
@@ -64,7 +64,7 @@ async def wait_adapter_response(request_id: str, timeout: float = 30.0) -> dict:
|
||||
"""等待适配器响应"""
|
||||
future = asyncio.Future()
|
||||
_adapter_response_pool[request_id] = future
|
||||
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
return response
|
||||
@@ -369,10 +369,10 @@ async def adapter_command_to_stream(
|
||||
platform: Optional[str] = "qq",
|
||||
stream_id: Optional[str] = None,
|
||||
timeout: float = 30.0,
|
||||
storage_message: bool = False
|
||||
storage_message: bool = False,
|
||||
) -> dict:
|
||||
"""向适配器发送命令并获取返回值
|
||||
|
||||
|
||||
雅诺狐的耳朵特别软
|
||||
|
||||
Args:
|
||||
@@ -388,20 +388,20 @@ async def adapter_command_to_stream(
|
||||
- 成功: {"status": "ok", "data": {...}, "message": "..."}
|
||||
- 失败: {"status": "failed", "message": "错误信息"}
|
||||
- 错误: {"status": "error", "message": "错误信息"}
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 当stream_id和platform都未提供时抛出
|
||||
"""
|
||||
if not stream_id and not platform:
|
||||
raise ValueError("必须提供stream_id或platform参数")
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
logger.debug(f"[SendAPI] 向适配器发送命令: {action}")
|
||||
|
||||
# 如果没有提供stream_id,则生成一个临时的
|
||||
if stream_id is None:
|
||||
import uuid
|
||||
|
||||
stream_id = f"adapter_temp_{uuid.uuid4().hex[:8]}"
|
||||
logger.debug(f"[SendAPI] 自动生成临时stream_id: {stream_id}")
|
||||
|
||||
@@ -411,22 +411,15 @@ async def adapter_command_to_stream(
|
||||
# 如果是自动生成的stream_id且找不到聊天流,创建一个临时的虚拟流
|
||||
if stream_id.startswith("adapter_temp_"):
|
||||
logger.debug(f"[SendAPI] 创建临时虚拟聊天流: {stream_id}")
|
||||
|
||||
|
||||
# 创建临时的用户信息和聊天流
|
||||
|
||||
temp_user_info = UserInfo(
|
||||
user_id="system",
|
||||
user_nickname="System",
|
||||
platform=platform
|
||||
)
|
||||
|
||||
temp_user_info = UserInfo(user_id="system", user_nickname="System", platform=platform)
|
||||
|
||||
temp_chat_stream = ChatStream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=temp_user_info,
|
||||
group_info=None
|
||||
stream_id=stream_id, platform=platform, user_info=temp_user_info, group_info=None
|
||||
)
|
||||
|
||||
|
||||
target_stream = temp_chat_stream
|
||||
else:
|
||||
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}")
|
||||
@@ -474,10 +467,7 @@ async def adapter_command_to_stream(
|
||||
|
||||
# 发送消息
|
||||
sent_msg = await heart_fc_sender.send_message(
|
||||
bot_message,
|
||||
typing=False,
|
||||
set_reply=False,
|
||||
storage_message=storage_message
|
||||
bot_message, typing=False, set_reply=False, storage_message=storage_message
|
||||
)
|
||||
|
||||
if not sent_msg:
|
||||
@@ -488,9 +478,9 @@ async def adapter_command_to_stream(
|
||||
|
||||
# 等待适配器响应
|
||||
response = await wait_adapter_response(message_id, timeout)
|
||||
|
||||
|
||||
logger.debug(f"[SendAPI] 收到适配器响应: {response}")
|
||||
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -31,4 +31,4 @@ def get_llm_available_tool_definitions():
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
llm_available_tools = component_registry.get_llm_available_tools()
|
||||
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
|
||||
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
|
||||
|
||||
@@ -147,7 +147,7 @@ class BaseAction(ABC):
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
||||
)
|
||||
|
||||
|
||||
# 验证聊天类型限制
|
||||
if not self._validate_chat_type():
|
||||
logger.warning(
|
||||
@@ -157,7 +157,7 @@ class BaseAction(ABC):
|
||||
|
||||
def _validate_chat_type(self) -> bool:
|
||||
"""验证当前聊天类型是否允许执行此Action
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
@@ -172,9 +172,9 @@ class BaseAction(ABC):
|
||||
|
||||
def is_chat_type_allowed(self) -> bool:
|
||||
"""检查当前聊天类型是否允许执行此Action
|
||||
|
||||
|
||||
这是一个公开的方法,供外部调用检查聊天类型限制
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
@@ -240,9 +240,7 @@ class BaseAction(ABC):
|
||||
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
||||
return False, f"等待新消息失败: {str(e)}"
|
||||
|
||||
async def send_text(
|
||||
self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None, typing: bool = False
|
||||
) -> bool:
|
||||
async def send_text(self, content: str, reply_to: str = "", typing: bool = False) -> bool:
|
||||
"""发送文本消息
|
||||
|
||||
Args:
|
||||
|
||||
@@ -46,10 +46,10 @@ class BaseCommand(ABC):
|
||||
self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
||||
|
||||
logger.debug(f"{self.log_prefix} Command组件初始化完成")
|
||||
|
||||
|
||||
# 验证聊天类型限制
|
||||
if not self._validate_chat_type():
|
||||
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
|
||||
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
|
||||
logger.warning(
|
||||
f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: "
|
||||
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
||||
@@ -65,16 +65,16 @@ class BaseCommand(ABC):
|
||||
|
||||
def _validate_chat_type(self) -> bool:
|
||||
"""验证当前聊天类型是否允许执行此Command
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
if self.chat_type_allow == ChatType.ALL:
|
||||
return True
|
||||
|
||||
|
||||
# 检查是否为群聊消息
|
||||
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
|
||||
|
||||
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
|
||||
|
||||
if self.chat_type_allow == ChatType.GROUP and is_group:
|
||||
return True
|
||||
elif self.chat_type_allow == ChatType.PRIVATE and not is_group:
|
||||
@@ -84,9 +84,9 @@ class BaseCommand(ABC):
|
||||
|
||||
def is_chat_type_allowed(self) -> bool:
|
||||
"""检查当前聊天类型是否允许执行此Command
|
||||
|
||||
|
||||
这是一个公开的方法,供外部调用检查聊天类型限制
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
|
||||
@@ -3,12 +3,14 @@ from typing import List, Dict, Any, Optional
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("base_event")
|
||||
|
||||
|
||||
|
||||
class HandlerResult:
|
||||
"""事件处理器执行结果
|
||||
|
||||
|
||||
所有事件处理器必须返回此类的实例
|
||||
"""
|
||||
|
||||
def __init__(self, success: bool, continue_process: bool, message: Any = None, handler_name: str = ""):
|
||||
self.success = success
|
||||
self.continue_process = continue_process
|
||||
@@ -18,31 +20,32 @@ class HandlerResult:
|
||||
def __repr__(self):
|
||||
return f"HandlerResult(success={self.success}, continue_process={self.continue_process}, message='{self.message}', handler_name='{self.handler_name}')"
|
||||
|
||||
|
||||
class HandlerResultsCollection:
|
||||
"""HandlerResult集合,提供便捷的查询方法"""
|
||||
|
||||
|
||||
def __init__(self, results: List[HandlerResult]):
|
||||
self.results = results
|
||||
|
||||
|
||||
def all_continue_process(self) -> bool:
|
||||
"""检查是否所有handler的continue_process都为True"""
|
||||
return all(result.continue_process for result in self.results)
|
||||
|
||||
|
||||
def get_all_results(self) -> List[HandlerResult]:
|
||||
"""获取所有HandlerResult"""
|
||||
return self.results
|
||||
|
||||
|
||||
def get_failed_handlers(self) -> List[HandlerResult]:
|
||||
"""获取执行失败的handler结果"""
|
||||
return [result for result in self.results if not result.success]
|
||||
|
||||
|
||||
def get_stopped_handlers(self) -> List[HandlerResult]:
|
||||
"""获取continue_process为False的handler结果"""
|
||||
return [result for result in self.results if not result.continue_process]
|
||||
|
||||
|
||||
def get_message_result(self) -> Any:
|
||||
"""获取handler的message
|
||||
|
||||
|
||||
当只有一个handler的结果时,直接返回那个handler结果中的message字段
|
||||
否则用字典的形式{handler_name:message}返回
|
||||
"""
|
||||
@@ -52,22 +55,22 @@ class HandlerResultsCollection:
|
||||
return self.results[0].message
|
||||
else:
|
||||
return {result.handler_name: result.message for result in self.results}
|
||||
|
||||
|
||||
def get_handler_result(self, handler_name: str) -> Optional[HandlerResult]:
|
||||
"""获取指定handler的结果"""
|
||||
for result in self.results:
|
||||
if result.handler_name == handler_name:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def get_success_count(self) -> int:
|
||||
"""获取成功执行的handler数量"""
|
||||
return sum(1 for result in self.results if result.success)
|
||||
|
||||
|
||||
def get_failure_count(self) -> int:
|
||||
"""获取执行失败的handler数量"""
|
||||
return sum(1 for result in self.results if not result.success)
|
||||
|
||||
|
||||
def get_summary(self) -> Dict[str, Any]:
|
||||
"""获取执行摘要"""
|
||||
return {
|
||||
@@ -76,62 +79,63 @@ class HandlerResultsCollection:
|
||||
"failure_count": self.get_failure_count(),
|
||||
"continue_process": self.all_continue_process(),
|
||||
"failed_handlers": [r.handler_name for r in self.get_failed_handlers()],
|
||||
"stopped_handlers": [r.handler_name for r in self.get_stopped_handlers()]
|
||||
"stopped_handlers": [r.handler_name for r in self.get_stopped_handlers()],
|
||||
}
|
||||
|
||||
|
||||
class BaseEvent:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
allowed_subscribers: List[str] = None,
|
||||
allowed_triggers: List[str] = None
|
||||
):
|
||||
def __init__(self, name: str, allowed_subscribers: List[str] = None, allowed_triggers: List[str] = None):
|
||||
self.name = name
|
||||
self.enabled = True
|
||||
self.allowed_subscribers = allowed_subscribers # 记录事件处理器名
|
||||
self.allowed_triggers = allowed_triggers # 记录插件名
|
||||
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
self.subscribers: List["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表
|
||||
|
||||
self.subscribers: List["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表
|
||||
|
||||
self.event_handle_lock = asyncio.Lock()
|
||||
|
||||
def __name__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
async def activate(self, params: dict) -> HandlerResultsCollection:
|
||||
"""激活事件,执行所有订阅的处理器
|
||||
|
||||
|
||||
Args:
|
||||
params: 传递给处理器的参数
|
||||
|
||||
|
||||
Returns:
|
||||
HandlerResultsCollection: 所有处理器的执行结果集合
|
||||
"""
|
||||
if not self.enabled:
|
||||
return HandlerResultsCollection([])
|
||||
|
||||
|
||||
# 使用锁确保同一个事件不能同时激活多次
|
||||
async with self.event_handle_lock:
|
||||
# 按权重从高到低排序订阅者
|
||||
# 使用直接属性访问,-1代表自动权重
|
||||
sorted_subscribers = sorted(self.subscribers, key=lambda h: h.weight if hasattr(h, 'weight') and h.weight != -1 else 0, reverse=True)
|
||||
|
||||
sorted_subscribers = sorted(
|
||||
self.subscribers, key=lambda h: h.weight if hasattr(h, "weight") and h.weight != -1 else 0, reverse=True
|
||||
)
|
||||
|
||||
# 并行执行所有订阅者
|
||||
tasks = []
|
||||
for subscriber in sorted_subscribers:
|
||||
# 为每个订阅者创建执行任务
|
||||
task = self._execute_subscriber(subscriber, params)
|
||||
tasks.append(task)
|
||||
|
||||
|
||||
# 等待所有任务完成
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 处理执行结果
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
subscriber = sorted_subscribers[i]
|
||||
handler_name = subscriber.handler_name if hasattr(subscriber, 'handler_name') else subscriber.__class__.__name__
|
||||
handler_name = (
|
||||
subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
|
||||
)
|
||||
if result:
|
||||
if isinstance(result, Exception):
|
||||
# 处理执行异常
|
||||
@@ -143,13 +147,13 @@ class BaseEvent:
|
||||
# 补充handler_name
|
||||
result.handler_name = handler_name
|
||||
processed_results.append(result)
|
||||
|
||||
|
||||
return HandlerResultsCollection(processed_results)
|
||||
|
||||
|
||||
async def _execute_subscriber(self, subscriber, params: dict) -> HandlerResult:
|
||||
"""执行单个订阅者处理器"""
|
||||
try:
|
||||
return await subscriber.execute(params)
|
||||
except Exception as e:
|
||||
# 异常会在 gather 中捕获,这里直接抛出让 gather 处理
|
||||
raise e
|
||||
raise e
|
||||
|
||||
@@ -51,11 +51,11 @@ class BaseEventHandler(ABC):
|
||||
event_name (str): 要订阅的事件名称
|
||||
"""
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
|
||||
if not event_manager.subscribe_handler_to_event(self.handler_name, event_name):
|
||||
logger.error(f"事件处理器 {self.handler_name} 订阅事件 {event_name} 失败")
|
||||
return
|
||||
|
||||
|
||||
logger.debug(f"{self.log_prefix} 订阅事件 {event_name}")
|
||||
self.subscribed_events.append(event_name)
|
||||
|
||||
@@ -66,7 +66,7 @@ class BaseEventHandler(ABC):
|
||||
event_name (str): 要取消订阅的事件名称
|
||||
"""
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
|
||||
if event_manager.unsubscribe_handler_from_event(self.handler_name, event_name):
|
||||
logger.debug(f"{self.log_prefix} 取消订阅事件 {event_name}")
|
||||
if event_name in self.subscribed_events:
|
||||
|
||||
@@ -9,32 +9,32 @@ import shlex
|
||||
|
||||
class CommandArgs:
|
||||
"""命令参数解析类
|
||||
|
||||
|
||||
提供方便的方法来处理命令参数
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, raw_args: str = ""):
|
||||
"""初始化命令参数
|
||||
|
||||
|
||||
Args:
|
||||
raw_args: 原始参数字符串
|
||||
"""
|
||||
self._raw_args = raw_args.strip()
|
||||
self._parsed_args: Optional[List[str]] = None
|
||||
|
||||
|
||||
def get_raw(self) -> str:
|
||||
"""获取完整的参数字符串
|
||||
|
||||
|
||||
Returns:
|
||||
str: 原始参数字符串
|
||||
"""
|
||||
return self._raw_args
|
||||
|
||||
|
||||
def get_args(self) -> List[str]:
|
||||
"""获取解析后的参数列表
|
||||
|
||||
|
||||
将参数按空格分割,支持引号包围的参数
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 参数列表
|
||||
"""
|
||||
@@ -48,25 +48,25 @@ class CommandArgs:
|
||||
except ValueError:
|
||||
# 如果shlex解析失败,fallback到简单的split
|
||||
self._parsed_args = self._raw_args.split()
|
||||
|
||||
|
||||
return self._parsed_args
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""检查参数是否为空
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果没有参数返回True
|
||||
"""
|
||||
return len(self.get_args()) == 0
|
||||
|
||||
|
||||
def get_arg(self, index: int, default: str = "") -> str:
|
||||
"""获取指定索引的参数
|
||||
|
||||
|
||||
Args:
|
||||
index: 参数索引(从0开始)
|
||||
default: 默认值
|
||||
|
||||
|
||||
Returns:
|
||||
str: 参数值或默认值
|
||||
"""
|
||||
@@ -78,21 +78,21 @@ class CommandArgs:
|
||||
@property
|
||||
def get_first(self, default: str = "") -> str:
|
||||
"""获取第一个参数
|
||||
|
||||
|
||||
Args:
|
||||
default: 默认值
|
||||
|
||||
|
||||
Returns:
|
||||
str: 第一个参数或默认值
|
||||
"""
|
||||
return self.get_arg(0, default)
|
||||
|
||||
|
||||
def get_remaining(self, start_index: int = 0) -> str:
|
||||
"""获取从指定索引开始的剩余参数字符串
|
||||
|
||||
|
||||
Args:
|
||||
start_index: 起始索引
|
||||
|
||||
|
||||
Returns:
|
||||
str: 剩余参数组成的字符串
|
||||
"""
|
||||
@@ -100,45 +100,45 @@ class CommandArgs:
|
||||
if start_index < len(args):
|
||||
return " ".join(args[start_index:])
|
||||
return ""
|
||||
|
||||
|
||||
def count(self) -> int:
|
||||
"""获取参数数量
|
||||
|
||||
|
||||
Returns:
|
||||
int: 参数数量
|
||||
"""
|
||||
return len(self.get_args())
|
||||
|
||||
|
||||
def has_flag(self, flag: str) -> bool:
|
||||
"""检查是否包含指定的标志参数
|
||||
|
||||
|
||||
Args:
|
||||
flag: 标志名(如 "--verbose" 或 "-v")
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果包含该标志返回True
|
||||
"""
|
||||
return flag in self.get_args()
|
||||
|
||||
|
||||
def get_flag_value(self, flag: str, default: str = "") -> str:
|
||||
"""获取标志参数的值
|
||||
|
||||
|
||||
查找 --key=value 或 --key value 形式的参数
|
||||
|
||||
|
||||
Args:
|
||||
flag: 标志名(如 "--output")
|
||||
default: 默认值
|
||||
|
||||
|
||||
Returns:
|
||||
str: 标志的值或默认值
|
||||
"""
|
||||
args = self.get_args()
|
||||
|
||||
|
||||
# 查找 --key=value 形式
|
||||
for arg in args:
|
||||
if arg.startswith(f"{flag}="):
|
||||
return arg[len(flag) + 1:]
|
||||
|
||||
return arg[len(flag) + 1 :]
|
||||
|
||||
# 查找 --key value 形式
|
||||
try:
|
||||
flag_index = args.index(flag)
|
||||
@@ -146,13 +146,13 @@ class CommandArgs:
|
||||
return args[flag_index + 1]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示"""
|
||||
return self._raw_args
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""调试表示"""
|
||||
return f"CommandArgs(raw='{self._raw_args}', parsed={self.get_args()})"
|
||||
|
||||
@@ -6,6 +6,7 @@ from maim_message import Seg
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
||||
|
||||
|
||||
# 组件类型枚举
|
||||
class ComponentType(Enum):
|
||||
"""组件类型枚举"""
|
||||
@@ -185,7 +186,9 @@ class PlusCommandInfo(ComponentInfo):
|
||||
class ToolInfo(ComponentInfo):
|
||||
"""工具组件信息"""
|
||||
|
||||
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
|
||||
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(
|
||||
default_factory=list
|
||||
) # 工具参数定义
|
||||
tool_description: str = "" # 工具描述
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -205,6 +208,7 @@ class EventHandlerInfo(ComponentInfo):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.EVENT_HANDLER
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventInfo(ComponentInfo):
|
||||
"""事件组件信息"""
|
||||
@@ -213,6 +217,7 @@ class EventInfo(ComponentInfo):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.EVENT
|
||||
|
||||
|
||||
# 事件类型枚举
|
||||
class EventType(Enum):
|
||||
"""
|
||||
@@ -232,6 +237,7 @@ class EventType(Enum):
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginInfo:
|
||||
"""插件信息"""
|
||||
@@ -320,16 +326,16 @@ class MaiMessages:
|
||||
|
||||
llm_response_content: Optional[str] = None
|
||||
"""LLM响应内容"""
|
||||
|
||||
|
||||
llm_response_reasoning: Optional[str] = None
|
||||
"""LLM响应推理内容"""
|
||||
|
||||
|
||||
llm_response_model: Optional[str] = None
|
||||
"""LLM响应模型名称"""
|
||||
|
||||
|
||||
llm_response_tool_call: Optional[List[ToolCall]] = None
|
||||
"""LLM使用的工具调用"""
|
||||
|
||||
|
||||
action_usage: Optional[List[str]] = None
|
||||
"""使用的Action"""
|
||||
|
||||
|
||||
@@ -90,10 +90,10 @@ class PluginBase(ABC):
|
||||
|
||||
# 标准化Python依赖为PythonDependency对象
|
||||
normalized_python_deps = self._normalize_python_dependencies(self.python_dependencies)
|
||||
|
||||
|
||||
# 检查Python依赖
|
||||
self._check_python_dependencies(normalized_python_deps)
|
||||
|
||||
|
||||
# 创建插件信息对象
|
||||
self.plugin_info = PluginInfo(
|
||||
name=self.plugin_name,
|
||||
@@ -560,7 +560,7 @@ class PluginBase(ABC):
|
||||
def _normalize_python_dependencies(self, dependencies: Any) -> List[PythonDependency]:
|
||||
"""将依赖列表标准化为PythonDependency对象"""
|
||||
from packaging.requirements import Requirement
|
||||
|
||||
|
||||
normalized = []
|
||||
for dep in dependencies:
|
||||
if isinstance(dep, str):
|
||||
@@ -568,23 +568,22 @@ class PluginBase(ABC):
|
||||
# 尝试解析为requirement格式 (如 "package>=1.0.0")
|
||||
req = Requirement(dep)
|
||||
version_spec = str(req.specifier) if req.specifier else ""
|
||||
|
||||
normalized.append(PythonDependency(
|
||||
package_name=req.name,
|
||||
version=version_spec,
|
||||
install_name=dep # 保持原始的安装名称
|
||||
))
|
||||
|
||||
normalized.append(
|
||||
PythonDependency(
|
||||
package_name=req.name,
|
||||
version=version_spec,
|
||||
install_name=dep, # 保持原始的安装名称
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 如果解析失败,作为简单包名处理
|
||||
normalized.append(PythonDependency(
|
||||
package_name=dep,
|
||||
install_name=dep
|
||||
))
|
||||
normalized.append(PythonDependency(package_name=dep, install_name=dep))
|
||||
elif isinstance(dep, PythonDependency):
|
||||
normalized.append(dep)
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 未知的依赖格式: {dep}")
|
||||
|
||||
|
||||
return normalized
|
||||
|
||||
def _check_python_dependencies(self, dependencies: List[PythonDependency]) -> bool:
|
||||
@@ -596,10 +595,10 @@ class PluginBase(ABC):
|
||||
try:
|
||||
# 延迟导入以避免循环依赖
|
||||
from src.plugin_system.utils.dependency_manager import get_dependency_manager
|
||||
|
||||
|
||||
dependency_manager = get_dependency_manager()
|
||||
success, errors = dependency_manager.check_and_install_dependencies(dependencies, self.plugin_name)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} Python依赖检查通过")
|
||||
return True
|
||||
@@ -608,7 +607,7 @@ class PluginBase(ABC):
|
||||
for error in errors:
|
||||
logger.error(f"{self.log_prefix} - {error}")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} Python依赖检查时发生异常: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@@ -20,12 +20,12 @@ logger = get_logger("plus_command")
|
||||
|
||||
class PlusCommand(ABC):
|
||||
"""增强版命令基类
|
||||
|
||||
|
||||
提供更简单的命令定义方式,无需手写正则表达式
|
||||
|
||||
|
||||
子类只需要定义:
|
||||
- command_name: 命令名称
|
||||
- command_description: 命令描述
|
||||
- command_description: 命令描述
|
||||
- command_aliases: 命令别名列表(可选)
|
||||
- priority: 优先级(可选,数字越大优先级越高)
|
||||
- chat_type_allow: 允许的聊天类型(可选)
|
||||
@@ -35,19 +35,19 @@ class PlusCommand(ABC):
|
||||
# 子类需要定义的属性
|
||||
command_name: str = ""
|
||||
"""命令名称,如 'echo'"""
|
||||
|
||||
|
||||
command_description: str = ""
|
||||
"""命令描述"""
|
||||
|
||||
|
||||
command_aliases: List[str] = []
|
||||
"""命令别名列表,如 ['say', 'repeat']"""
|
||||
|
||||
|
||||
priority: int = 0
|
||||
"""命令优先级,数字越大优先级越高"""
|
||||
|
||||
|
||||
chat_type_allow: ChatType = ChatType.ALL
|
||||
"""允许的聊天类型"""
|
||||
|
||||
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,不进行后续处理"""
|
||||
|
||||
@@ -61,13 +61,13 @@ class PlusCommand(ABC):
|
||||
self.message = message
|
||||
self.plugin_config = plugin_config or {}
|
||||
self.log_prefix = "[PlusCommand]"
|
||||
|
||||
|
||||
# 解析命令参数
|
||||
self._parse_command()
|
||||
|
||||
|
||||
# 验证聊天类型限制
|
||||
if not self._validate_chat_type():
|
||||
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
|
||||
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
|
||||
logger.warning(
|
||||
f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: "
|
||||
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
||||
@@ -75,59 +75,59 @@ class PlusCommand(ABC):
|
||||
|
||||
def _parse_command(self) -> None:
|
||||
"""解析命令和参数"""
|
||||
if not hasattr(self.message, 'plain_text') or not self.message.plain_text:
|
||||
if not hasattr(self.message, "plain_text") or not self.message.plain_text:
|
||||
self.args = CommandArgs("")
|
||||
return
|
||||
|
||||
|
||||
plain_text = self.message.plain_text.strip()
|
||||
|
||||
|
||||
# 获取配置的命令前缀
|
||||
prefixes = global_config.command.command_prefixes
|
||||
|
||||
|
||||
# 检查是否以任何前缀开头
|
||||
matched_prefix = None
|
||||
for prefix in prefixes:
|
||||
if plain_text.startswith(prefix):
|
||||
matched_prefix = prefix
|
||||
break
|
||||
|
||||
|
||||
if not matched_prefix:
|
||||
self.args = CommandArgs("")
|
||||
return
|
||||
|
||||
|
||||
# 移除前缀
|
||||
command_part = plain_text[len(matched_prefix):].strip()
|
||||
|
||||
command_part = plain_text[len(matched_prefix) :].strip()
|
||||
|
||||
# 分离命令名和参数
|
||||
parts = command_part.split(None, 1)
|
||||
if not parts:
|
||||
self.args = CommandArgs("")
|
||||
return
|
||||
|
||||
|
||||
command_word = parts[0].lower()
|
||||
args_text = parts[1] if len(parts) > 1 else ""
|
||||
|
||||
|
||||
# 检查命令名是否匹配
|
||||
all_commands = [self.command_name.lower()] + [alias.lower() for alias in self.command_aliases]
|
||||
if command_word not in all_commands:
|
||||
self.args = CommandArgs("")
|
||||
return
|
||||
|
||||
|
||||
# 创建参数对象
|
||||
self.args = CommandArgs(args_text)
|
||||
|
||||
def _validate_chat_type(self) -> bool:
|
||||
"""验证当前聊天类型是否允许执行此命令
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
if self.chat_type_allow == ChatType.ALL:
|
||||
return True
|
||||
|
||||
|
||||
# 检查是否为群聊消息
|
||||
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
|
||||
|
||||
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
|
||||
|
||||
if self.chat_type_allow == ChatType.GROUP and is_group:
|
||||
return True
|
||||
elif self.chat_type_allow == ChatType.PRIVATE and not is_group:
|
||||
@@ -137,7 +137,7 @@ class PlusCommand(ABC):
|
||||
|
||||
def is_chat_type_allowed(self) -> bool:
|
||||
"""检查当前聊天类型是否允许执行此命令
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
@@ -145,30 +145,30 @@ class PlusCommand(ABC):
|
||||
|
||||
def is_command_match(self) -> bool:
|
||||
"""检查当前消息是否匹配此命令
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果匹配返回True
|
||||
"""
|
||||
return not self.args.is_empty() or self._is_exact_command_call()
|
||||
|
||||
|
||||
def _is_exact_command_call(self) -> bool:
|
||||
"""检查是否是精确的命令调用(无参数)"""
|
||||
if not hasattr(self.message, 'plain_text') or not self.message.plain_text:
|
||||
if not hasattr(self.message, "plain_text") or not self.message.plain_text:
|
||||
return False
|
||||
|
||||
|
||||
plain_text = self.message.plain_text.strip()
|
||||
|
||||
|
||||
# 获取配置的命令前缀
|
||||
prefixes = global_config.command.command_prefixes
|
||||
|
||||
|
||||
# 检查每个前缀
|
||||
for prefix in prefixes:
|
||||
if plain_text.startswith(prefix):
|
||||
command_part = plain_text[len(prefix):].strip()
|
||||
command_part = plain_text[len(prefix) :].strip()
|
||||
all_commands = [self.command_name.lower()] + [alias.lower() for alias in self.command_aliases]
|
||||
if command_part.lower() in all_commands:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
@@ -298,10 +298,10 @@ class PlusCommand(ABC):
|
||||
if "." in cls.command_name:
|
||||
logger.error(f"命令名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"命令名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
|
||||
|
||||
# 生成正则表达式模式来匹配命令
|
||||
command_pattern = cls._generate_command_pattern()
|
||||
|
||||
|
||||
return CommandInfo(
|
||||
name=cls.command_name,
|
||||
component_type=ComponentType.COMMAND,
|
||||
@@ -320,7 +320,7 @@ class PlusCommand(ABC):
|
||||
if "." in cls.command_name:
|
||||
logger.error(f"命令名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"命令名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
|
||||
|
||||
return PlusCommandInfo(
|
||||
name=cls.command_name,
|
||||
component_type=ComponentType.PLUS_COMMAND,
|
||||
@@ -334,38 +334,38 @@ class PlusCommand(ABC):
|
||||
@classmethod
|
||||
def _generate_command_pattern(cls) -> str:
|
||||
"""生成命令匹配的正则表达式
|
||||
|
||||
|
||||
Returns:
|
||||
str: 正则表达式字符串
|
||||
"""
|
||||
# 获取所有可能的命令名(主命令名 + 别名)
|
||||
all_commands = [cls.command_name] + getattr(cls, 'command_aliases', [])
|
||||
|
||||
all_commands = [cls.command_name] + getattr(cls, "command_aliases", [])
|
||||
|
||||
# 转义特殊字符并创建选择组
|
||||
escaped_commands = [re.escape(cmd) for cmd in all_commands]
|
||||
commands_pattern = "|".join(escaped_commands)
|
||||
|
||||
|
||||
# 获取默认前缀列表(这里先用硬编码,后续可以优化为动态获取)
|
||||
default_prefixes = ["/", "!", ".", "#"]
|
||||
escaped_prefixes = [re.escape(prefix) for prefix in default_prefixes]
|
||||
prefixes_pattern = "|".join(escaped_prefixes)
|
||||
|
||||
|
||||
# 生成完整的正则表达式
|
||||
# 匹配: [前缀][命令名][可选空白][任意参数]
|
||||
pattern = f"^(?P<prefix>{prefixes_pattern})(?P<command>{commands_pattern})(?P<args>\\s.*)?$"
|
||||
|
||||
|
||||
return pattern
|
||||
|
||||
|
||||
class PlusCommandAdapter(BaseCommand):
|
||||
"""PlusCommand适配器
|
||||
|
||||
|
||||
将PlusCommand适配到现有的插件系统,继承BaseCommand
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, plus_command_class, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
"""初始化适配器
|
||||
|
||||
|
||||
Args:
|
||||
plus_command_class: PlusCommand子类
|
||||
message: 消息对象
|
||||
@@ -378,27 +378,27 @@ class PlusCommandAdapter(BaseCommand):
|
||||
self.chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
|
||||
self.priority = getattr(plus_command_class, "priority", 0)
|
||||
self.intercept_message = getattr(plus_command_class, "intercept_message", False)
|
||||
|
||||
|
||||
# 调用父类初始化
|
||||
super().__init__(message, plugin_config)
|
||||
|
||||
|
||||
# 创建PlusCommand实例
|
||||
self.plus_command = plus_command_class(message, plugin_config)
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
"""执行命令
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str], bool]: 执行结果
|
||||
"""
|
||||
# 检查命令是否匹配
|
||||
if not self.plus_command.is_command_match():
|
||||
return False, "命令不匹配", False
|
||||
|
||||
|
||||
# 检查聊天类型权限
|
||||
if not self.plus_command.is_chat_type_allowed():
|
||||
return False, "不支持当前聊天类型", self.intercept_message
|
||||
|
||||
|
||||
# 执行命令
|
||||
try:
|
||||
return await self.plus_command.execute(self.plus_command.args)
|
||||
@@ -409,49 +409,50 @@ class PlusCommandAdapter(BaseCommand):
|
||||
|
||||
def create_plus_command_adapter(plus_command_class):
|
||||
"""创建PlusCommand适配器的工厂函数
|
||||
|
||||
|
||||
Args:
|
||||
plus_command_class: PlusCommand子类
|
||||
|
||||
|
||||
Returns:
|
||||
适配器类
|
||||
"""
|
||||
|
||||
class AdapterClass(BaseCommand):
|
||||
command_name = plus_command_class.command_name
|
||||
command_description = plus_command_class.command_description
|
||||
command_pattern = plus_command_class._generate_command_pattern()
|
||||
chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
|
||||
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
super().__init__(message, plugin_config)
|
||||
self.plus_command = plus_command_class(message, plugin_config)
|
||||
self.priority = getattr(plus_command_class, "priority", 0)
|
||||
self.intercept_message = getattr(plus_command_class, "intercept_message", False)
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
"""执行命令"""
|
||||
# 从BaseCommand的正则匹配结果中提取参数
|
||||
args_text = ""
|
||||
if hasattr(self, 'matched_groups') and self.matched_groups:
|
||||
if hasattr(self, "matched_groups") and self.matched_groups:
|
||||
# 从正则匹配组中获取参数部分
|
||||
args_match = self.matched_groups.get('args', '')
|
||||
args_match = self.matched_groups.get("args", "")
|
||||
if args_match:
|
||||
args_text = args_match.strip()
|
||||
|
||||
|
||||
# 创建CommandArgs对象
|
||||
command_args = CommandArgs(args_text)
|
||||
|
||||
|
||||
# 检查聊天类型权限
|
||||
if not self.plus_command.is_chat_type_allowed():
|
||||
return False, "不支持当前聊天类型", self.intercept_message
|
||||
|
||||
|
||||
# 执行命令,传递正确解析的参数
|
||||
try:
|
||||
return await self.plus_command.execute(command_args)
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {e}", exc_info=True)
|
||||
return False, f"命令执行出错: {str(e)}", self.intercept_message
|
||||
|
||||
|
||||
return AdapterClass
|
||||
|
||||
|
||||
|
||||
@@ -34,7 +34,9 @@ class ComponentRegistry:
|
||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
||||
"""类型 -> 组件原名称 -> 组件信息"""
|
||||
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler, PlusCommand]]] = {}
|
||||
self._components_classes: Dict[
|
||||
str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler, PlusCommand]]
|
||||
] = {}
|
||||
"""命名空间式组件名 -> 组件类"""
|
||||
|
||||
# 插件注册表
|
||||
@@ -166,7 +168,7 @@ class ComponentRegistry:
|
||||
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
|
||||
logger.error(f"注册失败: {action_name} 不是有效的Action")
|
||||
return False
|
||||
|
||||
|
||||
action_class.plugin_name = action_info.plugin_name
|
||||
self._action_registry[action_name] = action_class
|
||||
|
||||
@@ -200,7 +202,9 @@ class ComponentRegistry:
|
||||
|
||||
return True
|
||||
|
||||
def _register_plus_command_component(self, plus_command_info: PlusCommandInfo, plus_command_class: Type[PlusCommand]) -> bool:
|
||||
def _register_plus_command_component(
|
||||
self, plus_command_info: PlusCommandInfo, plus_command_class: Type[PlusCommand]
|
||||
) -> bool:
|
||||
"""注册PlusCommand组件到特定注册表"""
|
||||
plus_command_name = plus_command_info.name
|
||||
|
||||
@@ -212,7 +216,7 @@ class ComponentRegistry:
|
||||
return False
|
||||
|
||||
# 创建专门的PlusCommand注册表(如果还没有)
|
||||
if not hasattr(self, '_plus_command_registry'):
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
||||
|
||||
plus_command_class.plugin_name = plus_command_info.plugin_name
|
||||
@@ -249,10 +253,11 @@ class ComponentRegistry:
|
||||
if not handler_info.enabled:
|
||||
logger.warning(f"EventHandler组件 {handler_name} 未启用")
|
||||
return True # 未启用,但是也是注册成功
|
||||
|
||||
|
||||
handler_class.plugin_name = handler_info.plugin_name
|
||||
# 使用EventManager进行事件处理器注册
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
return event_manager.register_event_handler(handler_class)
|
||||
|
||||
# === 组件移除相关 ===
|
||||
@@ -281,7 +286,7 @@ class ComponentRegistry:
|
||||
|
||||
case ComponentType.PLUS_COMMAND:
|
||||
# 移除PlusCommand注册
|
||||
if hasattr(self, '_plus_command_registry'):
|
||||
if hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry.pop(component_name, None)
|
||||
logger.debug(f"已移除PlusCommand组件: {component_name}")
|
||||
|
||||
@@ -371,6 +376,7 @@ class ComponentRegistry:
|
||||
assert issubclass(target_component_class, BaseEventHandler)
|
||||
self._enabled_event_handlers[component_name] = target_component_class
|
||||
from .event_manager import event_manager # 延迟导入防止循环导入问题
|
||||
|
||||
event_manager.register_event_handler(component_name)
|
||||
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
@@ -572,7 +578,7 @@ class ComponentRegistry:
|
||||
candidates[0].match(text).groupdict(), # type: ignore
|
||||
command_info,
|
||||
)
|
||||
|
||||
|
||||
return None
|
||||
|
||||
# === Tool 特定查询方法 ===
|
||||
@@ -599,7 +605,7 @@ class ComponentRegistry:
|
||||
# === PlusCommand 特定查询方法 ===
|
||||
def get_plus_command_registry(self) -> Dict[str, Type[PlusCommand]]:
|
||||
"""获取PlusCommand注册表"""
|
||||
if not hasattr(self, '_plus_command_registry'):
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
||||
return self._plus_command_registry.copy()
|
||||
|
||||
|
||||
@@ -2,55 +2,57 @@
|
||||
事件管理器 - 实现Event和EventHandler的单例管理
|
||||
提供统一的事件注册、管理和触发接口
|
||||
"""
|
||||
|
||||
from typing import Dict, Type, List, Optional, Any, Union
|
||||
from threading import Lock
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection, HandlerResult
|
||||
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
|
||||
logger = get_logger("event_manager")
|
||||
|
||||
|
||||
class EventManager:
|
||||
"""事件管理器单例类
|
||||
|
||||
|
||||
负责管理所有事件和事件处理器的注册、订阅、触发等操作
|
||||
使用单例模式确保全局只有一个事件管理实例
|
||||
"""
|
||||
|
||||
_instance: Optional['EventManager'] = None
|
||||
|
||||
_instance: Optional["EventManager"] = None
|
||||
_lock = Lock()
|
||||
|
||||
def __new__(cls) -> 'EventManager':
|
||||
|
||||
def __new__(cls) -> "EventManager":
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
|
||||
self._events: Dict[str, BaseEvent] = {}
|
||||
self._event_handlers: Dict[str, Type[BaseEventHandler]] = {}
|
||||
self._pending_subscriptions: Dict[str, List[str]] = {} # 缓存失败的订阅
|
||||
self._initialized = True
|
||||
logger.info("EventManager 单例初始化完成")
|
||||
|
||||
|
||||
def register_event(
|
||||
self,
|
||||
event_name: Union[EventType, str],
|
||||
allowed_subscribers: List[str]=None,
|
||||
allowed_triggers: List[str]=None
|
||||
) -> bool:
|
||||
self,
|
||||
event_name: Union[EventType, str],
|
||||
allowed_subscribers: List[str] = None,
|
||||
allowed_triggers: List[str] = None,
|
||||
) -> bool:
|
||||
"""注册一个新的事件
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
allowed_subscribers: List[str]: 事件订阅者白名单,
|
||||
allowed_subscribers: List[str]: 事件订阅者白名单,
|
||||
allowed_triggers: List[str]: 事件触发插件白名单
|
||||
Returns:
|
||||
bool: 注册成功返回True,已存在返回False
|
||||
@@ -62,57 +64,57 @@ class EventManager:
|
||||
if event_name in self._events:
|
||||
logger.warning(f"事件 {event_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
event = BaseEvent(event_name,allowed_subscribers,allowed_triggers)
|
||||
|
||||
event = BaseEvent(event_name, allowed_subscribers, allowed_triggers)
|
||||
self._events[event_name] = event
|
||||
logger.info(f"事件 {event_name} 注册成功")
|
||||
|
||||
|
||||
# 检查是否有缓存的订阅需要处理
|
||||
self._process_pending_subscriptions(event_name)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_event(self, event_name: Union[EventType, str]) -> Optional[BaseEvent]:
|
||||
"""获取指定事件实例
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
|
||||
Returns:
|
||||
BaseEvent: 事件实例,不存在返回None
|
||||
"""
|
||||
return self._events.get(event_name)
|
||||
|
||||
|
||||
def get_all_events(self) -> Dict[str, BaseEvent]:
|
||||
"""获取所有已注册的事件
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseEvent]: 所有事件的字典
|
||||
"""
|
||||
return self._events.copy()
|
||||
|
||||
|
||||
def get_enabled_events(self) -> Dict[str, BaseEvent]:
|
||||
"""获取所有已启用的事件
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseEvent]: 已启用事件的字典
|
||||
"""
|
||||
return {name: event for name, event in self._events.items() if event.enabled}
|
||||
|
||||
|
||||
def get_disabled_events(self) -> Dict[str, BaseEvent]:
|
||||
"""获取所有已禁用的事件
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseEvent]: 已禁用事件的字典
|
||||
"""
|
||||
return {name: event for name, event in self._events.items() if not event.enabled}
|
||||
|
||||
|
||||
def enable_event(self, event_name: Union[EventType, str]) -> bool:
|
||||
"""启用指定事件
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 成功返回True,事件不存在返回False
|
||||
"""
|
||||
@@ -120,17 +122,17 @@ class EventManager:
|
||||
if event is None:
|
||||
logger.error(f"事件 {event_name} 不存在,无法启用")
|
||||
return False
|
||||
|
||||
|
||||
event.enabled = True
|
||||
logger.info(f"事件 {event_name} 已启用")
|
||||
return True
|
||||
|
||||
|
||||
def disable_event(self, event_name: Union[EventType, str]) -> bool:
|
||||
"""禁用指定事件
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 成功返回True,事件不存在返回False
|
||||
"""
|
||||
@@ -138,38 +140,38 @@ class EventManager:
|
||||
if event is None:
|
||||
logger.error(f"事件 {event_name} 不存在,无法禁用")
|
||||
return False
|
||||
|
||||
|
||||
event.enabled = False
|
||||
logger.info(f"事件 {event_name} 已禁用")
|
||||
return True
|
||||
|
||||
|
||||
def register_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""注册事件处理器
|
||||
|
||||
|
||||
Args:
|
||||
handler_class (Type[BaseEventHandler]): 事件处理器类
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 注册成功返回True,已存在返回False
|
||||
"""
|
||||
handler_name = handler_class.handler_name or handler_class.__name__.lower().replace("handler", "")
|
||||
|
||||
|
||||
if EventType.UNKNOWN in handler_class.init_subscribe:
|
||||
logger.error(f"事件处理器 {handler_name} 不能订阅 UNKNOWN 事件")
|
||||
return False
|
||||
if handler_name in self._event_handlers:
|
||||
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
|
||||
self._event_handlers[handler_name] = handler_class()
|
||||
|
||||
|
||||
# 处理init_subscribe,缓存失败的订阅
|
||||
if self._event_handlers[handler_name].init_subscribe:
|
||||
failed_subscriptions = []
|
||||
for event_name in self._event_handlers[handler_name].init_subscribe:
|
||||
if not self.subscribe_handler_to_event(handler_name, event_name):
|
||||
failed_subscriptions.append(event_name)
|
||||
|
||||
|
||||
# 缓存失败的订阅
|
||||
if failed_subscriptions:
|
||||
self._pending_subscriptions[handler_name] = failed_subscriptions
|
||||
@@ -177,33 +179,33 @@ class EventManager:
|
||||
|
||||
logger.info(f"事件处理器 {handler_name} 注册成功")
|
||||
return True
|
||||
|
||||
|
||||
def get_event_handler(self, handler_name: str) -> Optional[Type[BaseEventHandler]]:
|
||||
"""获取指定事件处理器实例
|
||||
|
||||
|
||||
Args:
|
||||
handler_name (str): 处理器名称
|
||||
|
||||
|
||||
Returns:
|
||||
Type[BaseEventHandler]: 处理器实例,不存在返回None
|
||||
"""
|
||||
return self._event_handlers.get(handler_name)
|
||||
|
||||
|
||||
def get_all_event_handlers(self) -> Dict[str, BaseEventHandler]:
|
||||
"""获取所有已注册的事件处理器
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Type[BaseEventHandler]]: 所有处理器的字典
|
||||
"""
|
||||
return self._event_handlers.copy()
|
||||
|
||||
|
||||
def subscribe_handler_to_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
|
||||
"""订阅事件处理器到指定事件
|
||||
|
||||
|
||||
Args:
|
||||
handler_name (str): 处理器名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 订阅成功返回True
|
||||
"""
|
||||
@@ -211,36 +213,36 @@ class EventManager:
|
||||
if handler_instance is None:
|
||||
logger.error(f"事件处理器 {handler_name} 不存在,无法订阅到事件 {event_name}")
|
||||
return False
|
||||
|
||||
|
||||
event = self.get_event(event_name)
|
||||
if event is None:
|
||||
logger.error(f"事件 {event_name} 不存在,无法订阅事件处理器 {handler_name}")
|
||||
return False
|
||||
|
||||
|
||||
if handler_instance in event.subscribers:
|
||||
logger.warning(f"事件处理器 {handler_name} 已经订阅了事件 {event_name},跳过重复订阅")
|
||||
return True
|
||||
|
||||
|
||||
# 白名单检查
|
||||
if event.allowed_subscribers and handler_name not in event.allowed_subscribers:
|
||||
logger.warning(f"事件处理器 {handler_name} 不在事件 {event_name} 的订阅者白名单中,无法订阅")
|
||||
return False
|
||||
|
||||
|
||||
event.subscribers.append(handler_instance)
|
||||
|
||||
|
||||
# 按权重从高到低排序订阅者
|
||||
event.subscribers.sort(key=lambda h: getattr(h, 'weight', 0), reverse=True)
|
||||
|
||||
event.subscribers.sort(key=lambda h: getattr(h, "weight", 0), reverse=True)
|
||||
|
||||
logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成")
|
||||
return True
|
||||
|
||||
|
||||
def unsubscribe_handler_from_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
|
||||
"""从指定事件取消订阅事件处理器
|
||||
|
||||
|
||||
Args:
|
||||
handler_name (str): 处理器名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 取消订阅成功返回True
|
||||
"""
|
||||
@@ -248,55 +250,57 @@ class EventManager:
|
||||
if event is None:
|
||||
logger.error(f"事件 {event_name} 不存在,无法取消订阅")
|
||||
return False
|
||||
|
||||
|
||||
# 查找并移除处理器实例
|
||||
removed = False
|
||||
for subscriber in event.subscribers[:]:
|
||||
if hasattr(subscriber, 'handler_name') and subscriber.handler_name == handler_name:
|
||||
if hasattr(subscriber, "handler_name") and subscriber.handler_name == handler_name:
|
||||
event.subscribers.remove(subscriber)
|
||||
removed = True
|
||||
break
|
||||
|
||||
|
||||
if removed:
|
||||
logger.info(f"事件处理器 {handler_name} 成功从事件 {event_name} 取消订阅")
|
||||
else:
|
||||
logger.warning(f"事件处理器 {handler_name} 未订阅事件 {event_name}")
|
||||
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
def get_event_subscribers(self, event_name: Union[EventType, str]) -> Dict[str, BaseEventHandler]:
|
||||
"""获取订阅指定事件的所有事件处理器
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseEventHandler]: 处理器字典,键为处理器名称,值为处理器实例
|
||||
"""
|
||||
event = self.get_event(event_name)
|
||||
if event is None:
|
||||
return {}
|
||||
|
||||
|
||||
return {handler.handler_name: handler for handler in event.subscribers}
|
||||
|
||||
async def trigger_event(self, event_name: Union[EventType, str], plugin_name: Optional[str]="", **kwargs) -> Optional[HandlerResultsCollection]:
|
||||
|
||||
async def trigger_event(
|
||||
self, event_name: Union[EventType, str], plugin_name: Optional[str] = "", **kwargs
|
||||
) -> Optional[HandlerResultsCollection]:
|
||||
"""触发指定事件
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
plugin_name str: 触发事件的插件名
|
||||
**kwargs: 传递给处理器的参数
|
||||
|
||||
|
||||
Returns:
|
||||
HandlerResultsCollection: 所有处理器的执行结果,事件不存在返回None
|
||||
"""
|
||||
params = kwargs or {}
|
||||
|
||||
|
||||
event = self.get_event(event_name)
|
||||
if event is None:
|
||||
logger.error(f"事件 {event_name} 不存在,无法触发")
|
||||
return None
|
||||
|
||||
|
||||
# 插件白名单检查
|
||||
if event.allowed_triggers and not plugin_name:
|
||||
logger.warning(f"事件 {event_name} 存在触发者白名单,缺少plugin_name无法验证权限,已拒绝触发!")
|
||||
@@ -304,9 +308,9 @@ class EventManager:
|
||||
elif event.allowed_triggers and plugin_name not in event.allowed_triggers:
|
||||
logger.warning(f"插件 {plugin_name} 没有权限触发事件 {event_name},已拒绝触发!")
|
||||
return None
|
||||
|
||||
|
||||
return await event.activate(params)
|
||||
|
||||
|
||||
def init_default_events(self) -> None:
|
||||
"""初始化默认事件"""
|
||||
default_events = [
|
||||
@@ -317,29 +321,29 @@ class EventManager:
|
||||
EventType.POST_LLM,
|
||||
EventType.AFTER_LLM,
|
||||
EventType.POST_SEND,
|
||||
EventType.AFTER_SEND
|
||||
EventType.AFTER_SEND,
|
||||
]
|
||||
|
||||
|
||||
for event_name in default_events:
|
||||
self.register_event(event_name,allowed_triggers=["SYSTEM"])
|
||||
|
||||
self.register_event(event_name, allowed_triggers=["SYSTEM"])
|
||||
|
||||
logger.info("默认事件初始化完成")
|
||||
|
||||
|
||||
def clear_all_events(self) -> None:
|
||||
"""清除所有事件和处理器(主要用于测试)"""
|
||||
self._events.clear()
|
||||
self._event_handlers.clear()
|
||||
logger.info("所有事件和处理器已清除")
|
||||
|
||||
|
||||
def get_event_summary(self) -> Dict[str, Any]:
|
||||
"""获取事件系统摘要
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含事件系统统计信息的字典
|
||||
"""
|
||||
enabled_events = self.get_enabled_events()
|
||||
disabled_events = self.get_disabled_events()
|
||||
|
||||
|
||||
return {
|
||||
"total_events": len(self._events),
|
||||
"enabled_events": len(enabled_events),
|
||||
@@ -347,58 +351,58 @@ class EventManager:
|
||||
"total_handlers": len(self._event_handlers),
|
||||
"event_names": list(self._events.keys()),
|
||||
"handler_names": list(self._event_handlers.keys()),
|
||||
"pending_subscriptions": len(self._pending_subscriptions)
|
||||
"pending_subscriptions": len(self._pending_subscriptions),
|
||||
}
|
||||
|
||||
def _process_pending_subscriptions(self, event_name: Union[EventType, str]) -> None:
|
||||
"""处理指定事件的缓存订阅
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
"""
|
||||
handlers_to_remove = []
|
||||
|
||||
|
||||
for handler_name, pending_events in self._pending_subscriptions.items():
|
||||
if event_name in pending_events:
|
||||
if self.subscribe_handler_to_event(handler_name, event_name):
|
||||
pending_events.remove(event_name)
|
||||
logger.info(f"成功处理缓存订阅: {handler_name} -> {event_name}")
|
||||
|
||||
|
||||
# 如果该处理器没有更多待处理订阅,标记为移除
|
||||
if not pending_events:
|
||||
handlers_to_remove.append(handler_name)
|
||||
|
||||
|
||||
# 清理已完成的处理器缓存
|
||||
for handler_name in handlers_to_remove:
|
||||
del self._pending_subscriptions[handler_name]
|
||||
|
||||
def process_all_pending_subscriptions(self) -> int:
|
||||
"""处理所有缓存的订阅
|
||||
|
||||
|
||||
Returns:
|
||||
int: 成功处理的订阅数量
|
||||
"""
|
||||
processed_count = 0
|
||||
|
||||
|
||||
# 复制待处理订阅,避免在迭代时修改字典
|
||||
pending_copy = dict(self._pending_subscriptions)
|
||||
|
||||
|
||||
for handler_name, pending_events in pending_copy.items():
|
||||
for event_name in pending_events[:]: # 使用切片避免修改列表
|
||||
if self.subscribe_handler_to_event(handler_name, event_name):
|
||||
pending_events.remove(event_name)
|
||||
processed_count += 1
|
||||
|
||||
|
||||
# 清理已完成的处理器缓存
|
||||
handlers_to_remove = [name for name, events in self._pending_subscriptions.items() if not events]
|
||||
for handler_name in handlers_to_remove:
|
||||
del self._pending_subscriptions[handler_name]
|
||||
|
||||
|
||||
if processed_count > 0:
|
||||
logger.info(f"批量处理缓存订阅完成,共处理 {processed_count} 个订阅")
|
||||
|
||||
|
||||
return processed_count
|
||||
|
||||
|
||||
# 创建全局事件管理器实例
|
||||
event_manager = EventManager()
|
||||
event_manager = EventManager()
|
||||
|
||||
@@ -88,7 +88,7 @@ class GlobalAnnouncementManager:
|
||||
return False
|
||||
self._user_disabled_tools[chat_id].append(tool_name)
|
||||
return True
|
||||
|
||||
|
||||
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
|
||||
"""启用特定聊天的某个工具"""
|
||||
if chat_id in self._user_disabled_tools:
|
||||
@@ -111,7 +111,7 @@ class GlobalAnnouncementManager:
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有事件处理器"""
|
||||
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||
|
||||
|
||||
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有工具"""
|
||||
return self._user_disabled_tools.get(chat_id, []).copy()
|
||||
|
||||
@@ -19,14 +19,14 @@ logger = get_logger(__name__)
|
||||
|
||||
class PermissionManager(IPermissionManager):
|
||||
"""权限管理器实现类"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.engine = get_engine()
|
||||
self.SessionLocal = sessionmaker(bind=self.engine)
|
||||
self._master_users: Set[Tuple[str, str]] = set()
|
||||
self._load_master_users()
|
||||
logger.info("权限管理器初始化完成")
|
||||
|
||||
|
||||
def _load_master_users(self):
|
||||
"""从配置文件加载Master用户列表"""
|
||||
try:
|
||||
@@ -40,19 +40,19 @@ class PermissionManager(IPermissionManager):
|
||||
except Exception as e:
|
||||
logger.warning(f"加载Master用户配置失败: {e}")
|
||||
self._master_users = set()
|
||||
|
||||
|
||||
def reload_master_users(self):
|
||||
"""重新加载Master用户配置"""
|
||||
self._load_master_users()
|
||||
logger.info("Master用户配置已重新加载")
|
||||
|
||||
|
||||
def is_master(self, user: UserInfo) -> bool:
|
||||
"""
|
||||
检查用户是否为Master用户
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否为Master用户
|
||||
"""
|
||||
@@ -61,15 +61,15 @@ class PermissionManager(IPermissionManager):
|
||||
if is_master:
|
||||
logger.debug(f"用户 {user.platform}:{user.user_id} 是Master用户")
|
||||
return is_master
|
||||
|
||||
|
||||
def check_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||
"""
|
||||
检查用户是否拥有指定权限节点
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否拥有权限
|
||||
"""
|
||||
@@ -78,46 +78,50 @@ class PermissionManager(IPermissionManager):
|
||||
if self.is_master(user):
|
||||
logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}")
|
||||
return True
|
||||
|
||||
|
||||
with self.SessionLocal() as session:
|
||||
# 检查权限节点是否存在
|
||||
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first()
|
||||
if not node:
|
||||
logger.warning(f"权限节点 {permission_node} 不存在")
|
||||
return False
|
||||
|
||||
|
||||
# 检查用户是否有明确的权限设置
|
||||
user_perm = session.query(UserPermissions).filter_by(
|
||||
platform=user.platform,
|
||||
user_id=user.user_id,
|
||||
permission_node=permission_node
|
||||
).first()
|
||||
|
||||
user_perm = (
|
||||
session.query(UserPermissions)
|
||||
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_perm:
|
||||
# 有明确设置,返回设置的值
|
||||
result = user_perm.granted
|
||||
logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}")
|
||||
logger.debug(
|
||||
f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}"
|
||||
)
|
||||
return result
|
||||
else:
|
||||
# 没有明确设置,使用默认值
|
||||
result = node.default_granted
|
||||
logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}")
|
||||
logger.debug(
|
||||
f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"检查权限时数据库错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"检查权限时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def register_permission_node(self, node: PermissionNode) -> bool:
|
||||
"""
|
||||
注册权限节点
|
||||
|
||||
|
||||
Args:
|
||||
node: 权限节点
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 注册是否成功
|
||||
"""
|
||||
@@ -133,20 +137,20 @@ class PermissionManager(IPermissionManager):
|
||||
session.commit()
|
||||
logger.debug(f"更新权限节点: {node.node_name}")
|
||||
return True
|
||||
|
||||
|
||||
# 创建新节点
|
||||
new_node = PermissionNodes(
|
||||
node_name=node.node_name,
|
||||
description=node.description,
|
||||
plugin_name=node.plugin_name,
|
||||
default_granted=node.default_granted,
|
||||
created_at=datetime.utcnow()
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
session.add(new_node)
|
||||
session.commit()
|
||||
logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})")
|
||||
return True
|
||||
|
||||
|
||||
except IntegrityError as e:
|
||||
logger.error(f"注册权限节点时发生完整性错误: {e}")
|
||||
return False
|
||||
@@ -156,15 +160,15 @@ class PermissionManager(IPermissionManager):
|
||||
except Exception as e:
|
||||
logger.error(f"注册权限节点时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def grant_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||
"""
|
||||
授权用户权限节点
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 授权是否成功
|
||||
"""
|
||||
@@ -175,14 +179,14 @@ class PermissionManager(IPermissionManager):
|
||||
if not node:
|
||||
logger.error(f"尝试授权不存在的权限节点: {permission_node}")
|
||||
return False
|
||||
|
||||
|
||||
# 检查是否已有权限记录
|
||||
existing_perm = session.query(UserPermissions).filter_by(
|
||||
platform=user.platform,
|
||||
user_id=user.user_id,
|
||||
permission_node=permission_node
|
||||
).first()
|
||||
|
||||
existing_perm = (
|
||||
session.query(UserPermissions)
|
||||
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_perm:
|
||||
# 更新现有记录
|
||||
existing_perm.granted = True
|
||||
@@ -194,29 +198,29 @@ class PermissionManager(IPermissionManager):
|
||||
user_id=user.user_id,
|
||||
permission_node=permission_node,
|
||||
granted=True,
|
||||
granted_at=datetime.utcnow()
|
||||
granted_at=datetime.utcnow(),
|
||||
)
|
||||
session.add(new_perm)
|
||||
|
||||
|
||||
session.commit()
|
||||
logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}")
|
||||
return True
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"授权权限时数据库错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"授权权限时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||
"""
|
||||
撤销用户权限节点
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 撤销是否成功
|
||||
"""
|
||||
@@ -227,14 +231,14 @@ class PermissionManager(IPermissionManager):
|
||||
if not node:
|
||||
logger.error(f"尝试撤销不存在的权限节点: {permission_node}")
|
||||
return False
|
||||
|
||||
|
||||
# 检查是否已有权限记录
|
||||
existing_perm = session.query(UserPermissions).filter_by(
|
||||
platform=user.platform,
|
||||
user_id=user.user_id,
|
||||
permission_node=permission_node
|
||||
).first()
|
||||
|
||||
existing_perm = (
|
||||
session.query(UserPermissions)
|
||||
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_perm:
|
||||
# 更新现有记录
|
||||
existing_perm.granted = False
|
||||
@@ -246,28 +250,28 @@ class PermissionManager(IPermissionManager):
|
||||
user_id=user.user_id,
|
||||
permission_node=permission_node,
|
||||
granted=False,
|
||||
granted_at=datetime.utcnow()
|
||||
granted_at=datetime.utcnow(),
|
||||
)
|
||||
session.add(new_perm)
|
||||
|
||||
|
||||
session.commit()
|
||||
logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}")
|
||||
return True
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"撤销权限时数据库错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"撤销权限时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_user_permissions(self, user: UserInfo) -> List[str]:
|
||||
"""
|
||||
获取用户拥有的所有权限节点
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 权限节点列表
|
||||
"""
|
||||
@@ -277,21 +281,21 @@ class PermissionManager(IPermissionManager):
|
||||
with self.SessionLocal() as session:
|
||||
all_nodes = session.query(PermissionNodes.node_name).all()
|
||||
return [node.node_name for node in all_nodes]
|
||||
|
||||
|
||||
permissions = []
|
||||
|
||||
|
||||
with self.SessionLocal() as session:
|
||||
# 获取所有权限节点
|
||||
all_nodes = session.query(PermissionNodes).all()
|
||||
|
||||
|
||||
for node in all_nodes:
|
||||
# 检查用户是否有明确的权限设置
|
||||
user_perm = session.query(UserPermissions).filter_by(
|
||||
platform=user.platform,
|
||||
user_id=user.user_id,
|
||||
permission_node=node.node_name
|
||||
).first()
|
||||
|
||||
user_perm = (
|
||||
session.query(UserPermissions)
|
||||
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=node.node_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_perm:
|
||||
# 有明确设置,使用设置的值
|
||||
if user_perm.granted:
|
||||
@@ -300,20 +304,20 @@ class PermissionManager(IPermissionManager):
|
||||
# 没有明确设置,使用默认值
|
||||
if node.default_granted:
|
||||
permissions.append(node.node_name)
|
||||
|
||||
|
||||
return permissions
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"获取用户权限时数据库错误: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户权限时发生未知错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_all_permission_nodes(self) -> List[PermissionNode]:
|
||||
"""
|
||||
获取所有已注册的权限节点
|
||||
|
||||
|
||||
Returns:
|
||||
List[PermissionNode]: 权限节点列表
|
||||
"""
|
||||
@@ -325,25 +329,25 @@ class PermissionManager(IPermissionManager):
|
||||
node_name=node.node_name,
|
||||
description=node.description,
|
||||
plugin_name=node.plugin_name,
|
||||
default_granted=node.default_granted
|
||||
default_granted=node.default_granted,
|
||||
)
|
||||
for node in nodes
|
||||
]
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"获取所有权限节点时数据库错误: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"获取所有权限节点时发生未知错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
|
||||
"""
|
||||
获取指定插件的所有权限节点
|
||||
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
|
||||
Returns:
|
||||
List[PermissionNode]: 权限节点列表
|
||||
"""
|
||||
@@ -355,25 +359,25 @@ class PermissionManager(IPermissionManager):
|
||||
node_name=node.node_name,
|
||||
description=node.description,
|
||||
plugin_name=node.plugin_name,
|
||||
default_granted=node.default_granted
|
||||
default_granted=node.default_granted,
|
||||
)
|
||||
for node in nodes
|
||||
]
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"获取插件权限节点时数据库错误: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件权限节点时发生未知错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def delete_plugin_permissions(self, plugin_name: str) -> bool:
|
||||
"""
|
||||
删除指定插件的所有权限节点(用于插件卸载时清理)
|
||||
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
@@ -382,68 +386,71 @@ class PermissionManager(IPermissionManager):
|
||||
# 获取插件的所有权限节点
|
||||
plugin_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all()
|
||||
node_names = [node.node_name for node in plugin_nodes]
|
||||
|
||||
|
||||
if not node_names:
|
||||
logger.info(f"插件 {plugin_name} 没有注册任何权限节点")
|
||||
return True
|
||||
|
||||
|
||||
# 删除用户权限记录
|
||||
deleted_user_perms = session.query(UserPermissions).filter(
|
||||
UserPermissions.permission_node.in_(node_names)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
deleted_user_perms = (
|
||||
session.query(UserPermissions)
|
||||
.filter(UserPermissions.permission_node.in_(node_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
# 删除权限节点
|
||||
deleted_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).delete()
|
||||
|
||||
|
||||
session.commit()
|
||||
logger.info(f"已删除插件 {plugin_name} 的 {deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录")
|
||||
logger.info(
|
||||
f"已删除插件 {plugin_name} 的 {deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"删除插件权限时数据库错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"删除插件权限时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
获取拥有指定权限的所有用户
|
||||
|
||||
|
||||
Args:
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 用户列表,格式为 [(platform, user_id), ...]
|
||||
"""
|
||||
try:
|
||||
users = []
|
||||
|
||||
|
||||
with self.SessionLocal() as session:
|
||||
# 检查权限节点是否存在
|
||||
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first()
|
||||
if not node:
|
||||
logger.warning(f"权限节点 {permission_node} 不存在")
|
||||
return users
|
||||
|
||||
|
||||
# 获取明确授权的用户
|
||||
granted_users = session.query(UserPermissions).filter_by(
|
||||
permission_node=permission_node,
|
||||
granted=True
|
||||
).all()
|
||||
|
||||
granted_users = (
|
||||
session.query(UserPermissions).filter_by(permission_node=permission_node, granted=True).all()
|
||||
)
|
||||
|
||||
for user_perm in granted_users:
|
||||
users.append((user_perm.platform, user_perm.user_id))
|
||||
|
||||
|
||||
# 如果是默认授权的权限节点,还需要考虑没有明确设置的用户
|
||||
# 但这里我们只返回明确授权的用户,避免返回所有用户
|
||||
|
||||
|
||||
# 添加Master用户(他们拥有所有权限)
|
||||
users.extend(list(self._master_users))
|
||||
|
||||
|
||||
# 去重
|
||||
return list(set(users))
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"获取拥有权限的用户时数据库错误: {e}")
|
||||
return []
|
||||
|
||||
@@ -36,21 +36,21 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
"""文件修改事件"""
|
||||
if not event.is_directory:
|
||||
file_path = str(event.src_path)
|
||||
if file_path.endswith(('.py', '.toml')):
|
||||
if file_path.endswith((".py", ".toml")):
|
||||
self._handle_file_change(file_path, "modified")
|
||||
|
||||
def on_created(self, event):
|
||||
"""文件创建事件"""
|
||||
if not event.is_directory:
|
||||
file_path = str(event.src_path)
|
||||
if file_path.endswith(('.py', '.toml')):
|
||||
if file_path.endswith((".py", ".toml")):
|
||||
self._handle_file_change(file_path, "created")
|
||||
|
||||
def on_deleted(self, event):
|
||||
"""文件删除事件"""
|
||||
if not event.is_directory:
|
||||
file_path = str(event.src_path)
|
||||
if file_path.endswith(('.py', '.toml')):
|
||||
if file_path.endswith((".py", ".toml")):
|
||||
self._handle_file_change(file_path, "deleted")
|
||||
|
||||
def _handle_file_change(self, file_path: str, change_type: str):
|
||||
@@ -63,14 +63,14 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
|
||||
plugin_name, source_type = plugin_info
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# 文件变化缓存,避免重复处理同一文件的快速连续变化
|
||||
file_cache_key = f"{file_path}_{change_type}"
|
||||
last_file_time = self.file_change_cache.get(file_cache_key, 0)
|
||||
if current_time - last_file_time < 0.5: # 0.5秒内的重复文件变化忽略
|
||||
return
|
||||
self.file_change_cache[file_cache_key] = current_time
|
||||
|
||||
|
||||
# 插件级别的防抖处理
|
||||
last_plugin_time = self.last_reload_time.get(plugin_name, 0)
|
||||
if current_time - last_plugin_time < self.debounce_delay:
|
||||
@@ -85,20 +85,28 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
if change_type == "deleted":
|
||||
# 解析实际的插件名称
|
||||
actual_plugin_name = self.hot_reload_manager._resolve_plugin_name(plugin_name)
|
||||
|
||||
|
||||
if file_name == "plugin.py":
|
||||
if actual_plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(f"🗑️ 插件主文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]")
|
||||
logger.info(
|
||||
f"🗑️ 插件主文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
|
||||
)
|
||||
self.hot_reload_manager._unload_plugin(actual_plugin_name)
|
||||
else:
|
||||
logger.info(f"🗑️ 插件主文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]")
|
||||
logger.info(
|
||||
f"🗑️ 插件主文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
|
||||
)
|
||||
return
|
||||
elif file_name in ("manifest.toml", "_manifest.json"):
|
||||
if actual_plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]")
|
||||
logger.info(
|
||||
f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
|
||||
)
|
||||
self.hot_reload_manager._unload_plugin(actual_plugin_name)
|
||||
else:
|
||||
logger.info(f"🗑️ 插件配置文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]")
|
||||
logger.info(
|
||||
f"🗑️ 插件配置文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
|
||||
)
|
||||
return
|
||||
|
||||
# 对于修改和创建事件,都进行重载
|
||||
@@ -108,9 +116,7 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
|
||||
# 延迟重载,确保文件写入完成
|
||||
reload_thread = Thread(
|
||||
target=self._delayed_reload,
|
||||
args=(plugin_name, source_type, current_time),
|
||||
daemon=True
|
||||
target=self._delayed_reload, args=(plugin_name, source_type, current_time), daemon=True
|
||||
)
|
||||
reload_thread.start()
|
||||
|
||||
@@ -126,14 +132,14 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
# 检查是否还需要重载(可能在等待期间有更新的变化)
|
||||
if plugin_name not in self.pending_reloads:
|
||||
return
|
||||
|
||||
|
||||
# 检查是否有更新的重载请求
|
||||
if self.last_reload_time.get(plugin_name, 0) > trigger_time:
|
||||
return
|
||||
|
||||
self.pending_reloads.discard(plugin_name)
|
||||
logger.info(f"🔄 开始延迟重载插件: {plugin_name} [{source_type}]")
|
||||
|
||||
|
||||
# 执行深度重载
|
||||
success = self.hot_reload_manager._deep_reload_plugin(plugin_name)
|
||||
if success:
|
||||
@@ -146,7 +152,7 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
|
||||
def _get_plugin_info_from_path(self, file_path: str) -> Optional[Tuple[str, str]]:
|
||||
"""从文件路径获取插件信息
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[插件名称, 源类型] 或 None
|
||||
"""
|
||||
@@ -162,12 +168,12 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
source_type = "built-in"
|
||||
else:
|
||||
source_type = "external"
|
||||
|
||||
|
||||
# 获取插件目录名(插件名)
|
||||
relative_path = path.relative_to(plugin_root)
|
||||
if len(relative_path.parts) == 0:
|
||||
continue
|
||||
|
||||
|
||||
plugin_name = relative_path.parts[0]
|
||||
|
||||
# 确认这是一个有效的插件目录
|
||||
@@ -175,9 +181,10 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
if plugin_dir.is_dir():
|
||||
# 检查是否有插件主文件或配置文件
|
||||
has_plugin_py = (plugin_dir / "plugin.py").exists()
|
||||
has_manifest = ((plugin_dir / "manifest.toml").exists() or
|
||||
(plugin_dir / "_manifest.json").exists())
|
||||
|
||||
has_manifest = (plugin_dir / "manifest.toml").exists() or (
|
||||
plugin_dir / "_manifest.json"
|
||||
).exists()
|
||||
|
||||
if has_plugin_py or has_manifest:
|
||||
return plugin_name, source_type
|
||||
|
||||
@@ -195,11 +202,11 @@ class PluginHotReloadManager:
|
||||
# 默认监听两个目录:根目录下的 plugins 和 src 下的插件目录
|
||||
self.watch_directories = [
|
||||
os.path.join(os.getcwd(), "plugins"), # 外部插件目录
|
||||
os.path.join(os.getcwd(), "src", "plugins", "built_in") # 内置插件目录
|
||||
os.path.join(os.getcwd(), "src", "plugins", "built_in"), # 内置插件目录
|
||||
]
|
||||
else:
|
||||
self.watch_directories = watch_directories
|
||||
|
||||
|
||||
self.observers = []
|
||||
self.file_handlers = []
|
||||
self.is_running = False
|
||||
@@ -221,13 +228,9 @@ class PluginHotReloadManager:
|
||||
for watch_dir in self.watch_directories:
|
||||
observer = Observer()
|
||||
file_handler = PluginFileHandler(self)
|
||||
|
||||
observer.schedule(
|
||||
file_handler,
|
||||
watch_dir,
|
||||
recursive=True
|
||||
)
|
||||
|
||||
|
||||
observer.schedule(file_handler, watch_dir, recursive=True)
|
||||
|
||||
observer.start()
|
||||
self.observers.append(observer)
|
||||
self.file_handlers.append(file_handler)
|
||||
@@ -296,26 +299,26 @@ class PluginHotReloadManager:
|
||||
if folder_name in plugin_manager.plugin_classes:
|
||||
logger.debug(f"🔍 直接匹配插件名: {folder_name}")
|
||||
return folder_name
|
||||
|
||||
|
||||
# 如果没有直接匹配,搜索路径映射,并优先返回在插件类中存在的名称
|
||||
matched_plugins = []
|
||||
for plugin_name, plugin_path in plugin_manager.plugin_paths.items():
|
||||
# 检查路径是否包含该文件夹名
|
||||
if folder_name in plugin_path:
|
||||
matched_plugins.append((plugin_name, plugin_path))
|
||||
|
||||
|
||||
# 在匹配的插件中,优先选择在插件类中存在的
|
||||
for plugin_name, plugin_path in matched_plugins:
|
||||
if plugin_name in plugin_manager.plugin_classes:
|
||||
logger.debug(f"🔍 文件夹名 '{folder_name}' 映射到插件名 '{plugin_name}' (路径: {plugin_path})")
|
||||
return plugin_name
|
||||
|
||||
|
||||
# 如果还是没找到在插件类中存在的,返回第一个匹配项
|
||||
if matched_plugins:
|
||||
plugin_name, plugin_path = matched_plugins[0]
|
||||
logger.warning(f"⚠️ 文件夹 '{folder_name}' 映射到 '{plugin_name}',但该插件类不存在")
|
||||
return plugin_name
|
||||
|
||||
|
||||
# 如果还是没找到,返回原文件夹名
|
||||
logger.warning(f"⚠️ 无法找到文件夹 '{folder_name}' 对应的插件名,使用原名称")
|
||||
return folder_name
|
||||
@@ -326,13 +329,13 @@ class PluginHotReloadManager:
|
||||
# 解析实际的插件名称
|
||||
actual_plugin_name = self._resolve_plugin_name(plugin_name)
|
||||
logger.info(f"🔄 开始深度重载插件: {plugin_name} -> {actual_plugin_name}")
|
||||
|
||||
|
||||
# 强制清理相关模块缓存
|
||||
self._force_clear_plugin_modules(plugin_name)
|
||||
|
||||
|
||||
# 使用插件管理器的强制重载功能
|
||||
success = plugin_manager.force_reload_plugin(actual_plugin_name)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"✅ 插件深度重载成功: {actual_plugin_name}")
|
||||
return True
|
||||
@@ -348,15 +351,15 @@ class PluginHotReloadManager:
|
||||
|
||||
def _force_clear_plugin_modules(self, plugin_name: str):
|
||||
"""强制清理插件相关的模块缓存"""
|
||||
|
||||
|
||||
# 找到所有相关的模块名
|
||||
modules_to_remove = []
|
||||
plugin_module_prefix = f"src.plugins.built_in.{plugin_name}"
|
||||
|
||||
|
||||
for module_name in list(sys.modules.keys()):
|
||||
if plugin_module_prefix in module_name:
|
||||
modules_to_remove.append(module_name)
|
||||
|
||||
|
||||
# 删除模块缓存
|
||||
for module_name in modules_to_remove:
|
||||
if module_name in sys.modules:
|
||||
@@ -369,7 +372,7 @@ class PluginHotReloadManager:
|
||||
# 使用插件管理器的重载功能
|
||||
success = plugin_manager.reload_plugin(plugin_name)
|
||||
return success
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 强制重新导入插件 {plugin_name} 时发生错误: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -378,7 +381,7 @@ class PluginHotReloadManager:
|
||||
"""卸载指定插件"""
|
||||
try:
|
||||
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
|
||||
|
||||
|
||||
if plugin_manager.unload_plugin(plugin_name):
|
||||
logger.info(f"✅ 插件卸载成功: {plugin_name}")
|
||||
return True
|
||||
@@ -409,7 +412,7 @@ class PluginHotReloadManager:
|
||||
fail_count += 1
|
||||
|
||||
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count} 个")
|
||||
|
||||
|
||||
# 清理全局缓存
|
||||
importlib.invalidate_caches()
|
||||
|
||||
@@ -420,21 +423,21 @@ class PluginHotReloadManager:
|
||||
"""手动强制重载指定插件(委托给插件管理器)"""
|
||||
try:
|
||||
logger.info(f"🔄 手动强制重载插件: {plugin_name}")
|
||||
|
||||
|
||||
# 清理待重载列表中的该插件(避免重复重载)
|
||||
for handler in self.file_handlers:
|
||||
handler.pending_reloads.discard(plugin_name)
|
||||
|
||||
|
||||
# 使用插件管理器的强制重载功能
|
||||
success = plugin_manager.force_reload_plugin(plugin_name)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"✅ 手动强制重载成功: {plugin_name}")
|
||||
else:
|
||||
logger.error(f"❌ 手动强制重载失败: {plugin_name}")
|
||||
|
||||
|
||||
return success
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 手动强制重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -457,19 +460,15 @@ class PluginHotReloadManager:
|
||||
try:
|
||||
observer = Observer()
|
||||
file_handler = PluginFileHandler(self)
|
||||
|
||||
observer.schedule(
|
||||
file_handler,
|
||||
directory,
|
||||
recursive=True
|
||||
)
|
||||
|
||||
|
||||
observer.schedule(file_handler, directory, recursive=True)
|
||||
|
||||
observer.start()
|
||||
self.observers.append(observer)
|
||||
self.file_handlers.append(file_handler)
|
||||
|
||||
|
||||
logger.info(f"📂 已添加新的监听目录: {directory}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 添加监听目录 {directory} 失败: {e}")
|
||||
self.watch_directories.remove(directory)
|
||||
@@ -480,7 +479,7 @@ class PluginHotReloadManager:
|
||||
if self.file_handlers:
|
||||
for handler in self.file_handlers:
|
||||
pending_reloads.update(handler.pending_reloads)
|
||||
|
||||
|
||||
return {
|
||||
"is_running": self.is_running,
|
||||
"watch_directories": self.watch_directories,
|
||||
@@ -495,11 +494,11 @@ class PluginHotReloadManager:
|
||||
"""清理所有Python模块缓存"""
|
||||
try:
|
||||
logger.info("🧹 开始清理所有Python模块缓存...")
|
||||
|
||||
|
||||
# 重新扫描所有插件目录,这会重新加载模块
|
||||
plugin_manager.rescan_plugin_directory()
|
||||
logger.info("✅ 模块缓存清理完成")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True)
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ class PluginManager:
|
||||
return False # 目标文件不存在,视为不同
|
||||
|
||||
# 使用 'rb' 模式以二进制方式读取文件,确保哈希值计算的一致性
|
||||
with open(file1, 'rb') as f1, open(file2, 'rb') as f2:
|
||||
with open(file1, "rb") as f1, open(file2, "rb") as f2:
|
||||
return hashlib.md5(f1.read()).hexdigest() == hashlib.md5(f2.read()).hexdigest()
|
||||
|
||||
# === 插件目录管理 ===
|
||||
@@ -300,7 +300,7 @@ class PluginManager:
|
||||
list: 已注册的插件类名称列表。
|
||||
"""
|
||||
return list(self.plugin_classes.keys())
|
||||
|
||||
|
||||
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
|
||||
"""
|
||||
获取指定插件的路径。
|
||||
@@ -366,7 +366,7 @@ class PluginManager:
|
||||
# 生成模块名和插件信息
|
||||
plugin_path = Path(plugin_file)
|
||||
plugin_dir = plugin_path.parent # 插件目录
|
||||
plugin_name = plugin_dir.name # 插件名称
|
||||
plugin_name = plugin_dir.name # 插件名称
|
||||
module_name = ".".join(plugin_path.parent.parts)
|
||||
|
||||
try:
|
||||
@@ -386,7 +386,7 @@ class PluginManager:
|
||||
except Exception as e:
|
||||
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
|
||||
logger.error(error_msg)
|
||||
self.failed_plugins[plugin_name if 'plugin_name' in locals() else module_name] = error_msg
|
||||
self.failed_plugins[plugin_name if "plugin_name" in locals() else module_name] = error_msg
|
||||
return False
|
||||
|
||||
# == 兼容性检查 ==
|
||||
@@ -478,9 +478,7 @@ class PluginManager:
|
||||
command_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
|
||||
]
|
||||
tool_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
|
||||
]
|
||||
tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL]
|
||||
event_handler_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
|
||||
]
|
||||
@@ -591,7 +589,7 @@ class PluginManager:
|
||||
plugin_instance = self.loaded_plugins[plugin_name]
|
||||
|
||||
# 调用插件的清理方法(如果有的话)
|
||||
if hasattr(plugin_instance, 'on_unload'):
|
||||
if hasattr(plugin_instance, "on_unload"):
|
||||
plugin_instance.on_unload()
|
||||
|
||||
# 从组件注册表中移除插件的所有组件
|
||||
@@ -654,10 +652,10 @@ class PluginManager:
|
||||
|
||||
def force_reload_plugin(self, plugin_name: str) -> bool:
|
||||
"""强制重载插件(使用简化的方法)
|
||||
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 重载是否成功
|
||||
"""
|
||||
|
||||
@@ -129,17 +129,17 @@ class ToolExecutor:
|
||||
if not tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无需执行工具")
|
||||
return [], []
|
||||
|
||||
|
||||
# 提取tool_calls中的函数名称
|
||||
func_names = []
|
||||
for call in tool_calls:
|
||||
try:
|
||||
if hasattr(call, 'func_name'):
|
||||
if hasattr(call, "func_name"):
|
||||
func_names.append(call.func_name)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}获取工具名称失败: {e}")
|
||||
continue
|
||||
|
||||
|
||||
if func_names:
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
||||
else:
|
||||
@@ -185,9 +185,11 @@ class ToolExecutor:
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
|
||||
async def execute_tool_call(
|
||||
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""执行单个工具调用,并处理缓存"""
|
||||
|
||||
|
||||
function_args = tool_call.args or {}
|
||||
tool_instance = tool_instance or get_tool_instance(tool_call.func_name)
|
||||
|
||||
@@ -206,7 +208,7 @@ class ToolExecutor:
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=function_args,
|
||||
tool_file_path=tool_file_path,
|
||||
semantic_query=semantic_query
|
||||
semantic_query=semantic_query,
|
||||
)
|
||||
if cached_result:
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
@@ -223,14 +225,14 @@ class ToolExecutor:
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
|
||||
await tool_cache.set(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=function_args,
|
||||
tool_file_path=tool_file_path,
|
||||
data=result,
|
||||
ttl=tool_instance.cache_ttl,
|
||||
semantic_query=semantic_query
|
||||
semantic_query=semantic_query,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
|
||||
@@ -238,12 +240,16 @@ class ToolExecutor:
|
||||
|
||||
return result
|
||||
|
||||
async def _original_execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
|
||||
async def _original_execute_tool_call(
|
||||
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""执行单个工具调用的原始逻辑"""
|
||||
try:
|
||||
function_name = tool_call.func_name
|
||||
function_args = tool_call.args or {}
|
||||
logger.info(f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
|
||||
)
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
# 获取对应工具实例
|
||||
tool_instance = tool_instance or get_tool_instance(function_name)
|
||||
@@ -261,7 +267,7 @@ class ToolExecutor:
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": "function",
|
||||
"content": result.get("content", "")
|
||||
"content": result.get("content", ""),
|
||||
}
|
||||
logger.warning(f"{self.log_prefix}工具 {function_name} 返回空结果")
|
||||
return None
|
||||
@@ -308,7 +314,6 @@ class ToolExecutor:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
"""
|
||||
ToolExecutor使用示例:
|
||||
|
||||
|
||||
@@ -23,112 +23,105 @@
|
||||
|
||||
INSTALL_NAME_TO_IMPORT_NAME = {
|
||||
# ============== 数据科学与机器学习 (Data Science & Machine Learning) ==============
|
||||
"scikit-learn": "sklearn", # 机器学习库
|
||||
"scikit-image": "skimage", # 图像处理库
|
||||
"opencv-python": "cv2", # OpenCV 计算机视觉库
|
||||
"opencv-contrib-python": "cv2", # OpenCV 扩展模块
|
||||
"tensorflow-gpu": "tensorflow", # TensorFlow GPU版本
|
||||
"tensorboardx": "tensorboardX", # TensorBoard 的封装
|
||||
"torchvision": "torchvision", # PyTorch 视觉库 (通常与 torch 一起)
|
||||
"torchaudio": "torchaudio", # PyTorch 音频库
|
||||
"catboost": "catboost", # CatBoost 梯度提升库
|
||||
"lightgbm": "lightgbm", # LightGBM 梯度提升库
|
||||
"xgboost": "xgboost", # XGBoost 梯度提升库
|
||||
"imbalanced-learn": "imblearn", # 处理不平衡数据集
|
||||
"seqeval": "seqeval", # 序列标注评估
|
||||
"gensim": "gensim", # 主题建模和NLP
|
||||
"nltk": "nltk", # 自然语言工具包
|
||||
"spacy": "spacy", # 工业级自然语言处理
|
||||
"fuzzywuzzy": "fuzzywuzzy", # 模糊字符串匹配
|
||||
"python-levenshtein": "Levenshtein", # Levenshtein 距离计算
|
||||
|
||||
"scikit-learn": "sklearn", # 机器学习库
|
||||
"scikit-image": "skimage", # 图像处理库
|
||||
"opencv-python": "cv2", # OpenCV 计算机视觉库
|
||||
"opencv-contrib-python": "cv2", # OpenCV 扩展模块
|
||||
"tensorflow-gpu": "tensorflow", # TensorFlow GPU版本
|
||||
"tensorboardx": "tensorboardX", # TensorBoard 的封装
|
||||
"torchvision": "torchvision", # PyTorch 视觉库 (通常与 torch 一起)
|
||||
"torchaudio": "torchaudio", # PyTorch 音频库
|
||||
"catboost": "catboost", # CatBoost 梯度提升库
|
||||
"lightgbm": "lightgbm", # LightGBM 梯度提升库
|
||||
"xgboost": "xgboost", # XGBoost 梯度提升库
|
||||
"imbalanced-learn": "imblearn", # 处理不平衡数据集
|
||||
"seqeval": "seqeval", # 序列标注评估
|
||||
"gensim": "gensim", # 主题建模和NLP
|
||||
"nltk": "nltk", # 自然语言工具包
|
||||
"spacy": "spacy", # 工业级自然语言处理
|
||||
"fuzzywuzzy": "fuzzywuzzy", # 模糊字符串匹配
|
||||
"python-levenshtein": "Levenshtein", # Levenshtein 距离计算
|
||||
# ============== Web开发与API (Web Development & API) ==============
|
||||
"python-socketio": "socketio", # Socket.IO 服务器和客户端
|
||||
"python-engineio": "engineio", # Engine.IO 底层库
|
||||
"aiohttp": "aiohttp", # 异步HTTP客户端/服务器
|
||||
"python-multipart": "multipart", # 解析 multipart/form-data
|
||||
"uvloop": "uvloop", # 高性能asyncio事件循环
|
||||
"httptools": "httptools", # 高性能HTTP解析器
|
||||
"websockets": "websockets", # WebSocket实现
|
||||
"fastapi": "fastapi", # 高性能Web框架
|
||||
"starlette": "starlette", # ASGI框架
|
||||
"uvicorn": "uvicorn", # ASGI服务器
|
||||
"gunicorn": "gunicorn", # WSGI服务器
|
||||
"django-rest-framework": "rest_framework", # Django REST框架
|
||||
"django-cors-headers": "corsheaders", # Django CORS处理
|
||||
"flask-jwt-extended": "flask_jwt_extended", # Flask JWT扩展
|
||||
"flask-sqlalchemy": "flask_sqlalchemy", # Flask SQLAlchemy扩展
|
||||
"flask-migrate": "flask_migrate", # Flask Alembic迁移扩展
|
||||
"python-jose": "jose", # JOSE (JWT, JWS, JWE) 实现
|
||||
"passlib": "passlib", # 密码哈希库
|
||||
"bcrypt": "bcrypt", # Bcrypt密码哈希
|
||||
|
||||
"python-socketio": "socketio", # Socket.IO 服务器和客户端
|
||||
"python-engineio": "engineio", # Engine.IO 底层库
|
||||
"aiohttp": "aiohttp", # 异步HTTP客户端/服务器
|
||||
"python-multipart": "multipart", # 解析 multipart/form-data
|
||||
"uvloop": "uvloop", # 高性能asyncio事件循环
|
||||
"httptools": "httptools", # 高性能HTTP解析器
|
||||
"websockets": "websockets", # WebSocket实现
|
||||
"fastapi": "fastapi", # 高性能Web框架
|
||||
"starlette": "starlette", # ASGI框架
|
||||
"uvicorn": "uvicorn", # ASGI服务器
|
||||
"gunicorn": "gunicorn", # WSGI服务器
|
||||
"django-rest-framework": "rest_framework", # Django REST框架
|
||||
"django-cors-headers": "corsheaders", # Django CORS处理
|
||||
"flask-jwt-extended": "flask_jwt_extended", # Flask JWT扩展
|
||||
"flask-sqlalchemy": "flask_sqlalchemy", # Flask SQLAlchemy扩展
|
||||
"flask-migrate": "flask_migrate", # Flask Alembic迁移扩展
|
||||
"python-jose": "jose", # JOSE (JWT, JWS, JWE) 实现
|
||||
"passlib": "passlib", # 密码哈希库
|
||||
"bcrypt": "bcrypt", # Bcrypt密码哈希
|
||||
# ============== 数据库 (Database) ==============
|
||||
"mysql-connector-python": "mysql.connector", # MySQL官方驱动
|
||||
"psycopg2-binary": "psycopg2", # PostgreSQL驱动 (二进制)
|
||||
"pymongo": "pymongo", # MongoDB驱动
|
||||
"redis": "redis", # Redis客户端
|
||||
"aioredis": "aioredis", # 异步Redis客户端
|
||||
"sqlalchemy": "sqlalchemy", # SQL工具包和ORM
|
||||
"alembic": "alembic", # SQLAlchemy数据库迁移工具
|
||||
"tortoise-orm": "tortoise", # 异步ORM
|
||||
|
||||
"mysql-connector-python": "mysql.connector", # MySQL官方驱动
|
||||
"psycopg2-binary": "psycopg2", # PostgreSQL驱动 (二进制)
|
||||
"pymongo": "pymongo", # MongoDB驱动
|
||||
"redis": "redis", # Redis客户端
|
||||
"aioredis": "aioredis", # 异步Redis客户端
|
||||
"sqlalchemy": "sqlalchemy", # SQL工具包和ORM
|
||||
"alembic": "alembic", # SQLAlchemy数据库迁移工具
|
||||
"tortoise-orm": "tortoise", # 异步ORM
|
||||
# ============== 图像与多媒体 (Image & Multimedia) ==============
|
||||
"Pillow": "PIL", # Python图像处理库 (PIL Fork)
|
||||
"moviepy": "moviepy", # 视频编辑库
|
||||
"pydub": "pydub", # 音频处理库
|
||||
"pycairo": "cairo", # Cairo 2D图形库的Python绑定
|
||||
"wand": "wand", # ImageMagick的Python绑定
|
||||
|
||||
"Pillow": "PIL", # Python图像处理库 (PIL Fork)
|
||||
"moviepy": "moviepy", # 视频编辑库
|
||||
"pydub": "pydub", # 音频处理库
|
||||
"pycairo": "cairo", # Cairo 2D图形库的Python绑定
|
||||
"wand": "wand", # ImageMagick的Python绑定
|
||||
# ============== 解析与序列化 (Parsing & Serialization) ==============
|
||||
"beautifulsoup4": "bs4", # HTML/XML解析库
|
||||
"lxml": "lxml", # 高性能HTML/XML解析库
|
||||
"PyYAML": "yaml", # YAML解析库
|
||||
"python-dotenv": "dotenv", # .env文件解析
|
||||
"python-dateutil": "dateutil", # 强大的日期时间解析
|
||||
"protobuf": "google.protobuf", # Protocol Buffers
|
||||
"msgpack": "msgpack", # MessagePack序列化
|
||||
"orjson": "orjson", # 高性能JSON库
|
||||
"pydantic": "pydantic", # 数据验证和设置管理
|
||||
|
||||
"beautifulsoup4": "bs4", # HTML/XML解析库
|
||||
"lxml": "lxml", # 高性能HTML/XML解析库
|
||||
"PyYAML": "yaml", # YAML解析库
|
||||
"python-dotenv": "dotenv", # .env文件解析
|
||||
"python-dateutil": "dateutil", # 强大的日期时间解析
|
||||
"protobuf": "google.protobuf", # Protocol Buffers
|
||||
"msgpack": "msgpack", # MessagePack序列化
|
||||
"orjson": "orjson", # 高性能JSON库
|
||||
"pydantic": "pydantic", # 数据验证和设置管理
|
||||
# ============== 系统与硬件 (System & Hardware) ==============
|
||||
"pyserial": "serial", # 串口通信
|
||||
"pyusb": "usb", # USB访问
|
||||
"pybluez": "bluetooth", # 蓝牙通信 (可能因平台而异)
|
||||
"psutil": "psutil", # 系统信息和进程管理
|
||||
"watchdog": "watchdog", # 文件系统事件监控
|
||||
"python-gnupg": "gnupg", # GnuPG的Python接口
|
||||
|
||||
"pyserial": "serial", # 串口通信
|
||||
"pyusb": "usb", # USB访问
|
||||
"pybluez": "bluetooth", # 蓝牙通信 (可能因平台而异)
|
||||
"psutil": "psutil", # 系统信息和进程管理
|
||||
"watchdog": "watchdog", # 文件系统事件监控
|
||||
"python-gnupg": "gnupg", # GnuPG的Python接口
|
||||
# ============== 加密与安全 (Cryptography & Security) ==============
|
||||
"pycrypto": "Crypto", # 加密库 (较旧)
|
||||
"pycryptodome": "Crypto", # PyCrypto的现代分支
|
||||
"cryptography": "cryptography", # 现代加密库
|
||||
"pyopenssl": "OpenSSL", # OpenSSL的Python接口
|
||||
"service-identity": "service_identity", # 服务身份验证
|
||||
|
||||
"pycrypto": "Crypto", # 加密库 (较旧)
|
||||
"pycryptodome": "Crypto", # PyCrypto的现代分支
|
||||
"cryptography": "cryptography", # 现代加密库
|
||||
"pyopenssl": "OpenSSL", # OpenSSL的Python接口
|
||||
"service-identity": "service_identity", # 服务身份验证
|
||||
# ============== 工具与杂项 (Utilities & Miscellaneous) ==============
|
||||
"setuptools": "setuptools", # 打包工具
|
||||
"pip": "pip", # 包安装器
|
||||
"tqdm": "tqdm", # 进度条
|
||||
"regex": "regex", # 替代的正则表达式引擎
|
||||
"colorama": "colorama", # 跨平台彩色终端文本
|
||||
"termcolor": "termcolor", # 终端颜色格式化
|
||||
"requests-oauthlib": "requests_oauthlib", # OAuth for Requests
|
||||
"oauthlib": "oauthlib", # 通用OAuth库
|
||||
"authlib": "authlib", # OAuth和OpenID Connect客户端/服务器
|
||||
"pyjwt": "jwt", # JSON Web Token实现
|
||||
"python-editor": "editor", # 程序化地调用编辑器
|
||||
"prompt-toolkit": "prompt_toolkit", # 构建交互式命令行
|
||||
"pygments": "pygments", # 语法高亮
|
||||
"tabulate": "tabulate", # 生成漂亮的表格
|
||||
"nats-client": "nats", # NATS客户端
|
||||
"gitpython": "git", # Git的Python接口
|
||||
"pygithub": "github", # GitHub API v3的Python接口
|
||||
"python-gitlab": "gitlab", # GitLab API的Python接口
|
||||
"jira": "jira", # JIRA API的Python接口
|
||||
"python-jenkins": "jenkins", # Jenkins API的Python接口
|
||||
"huggingface-hub": "huggingface_hub", # Hugging Face Hub API
|
||||
"apache-airflow": "airflow", # Airflow工作流管理
|
||||
"pandas-stubs": "pandas-stubs", # Pandas的类型存根
|
||||
"data-science-types": "data_science_types", # 数据科学类型
|
||||
}
|
||||
"setuptools": "setuptools", # 打包工具
|
||||
"pip": "pip", # 包安装器
|
||||
"tqdm": "tqdm", # 进度条
|
||||
"regex": "regex", # 替代的正则表达式引擎
|
||||
"colorama": "colorama", # 跨平台彩色终端文本
|
||||
"termcolor": "termcolor", # 终端颜色格式化
|
||||
"requests-oauthlib": "requests_oauthlib", # OAuth for Requests
|
||||
"oauthlib": "oauthlib", # 通用OAuth库
|
||||
"authlib": "authlib", # OAuth和OpenID Connect客户端/服务器
|
||||
"pyjwt": "jwt", # JSON Web Token实现
|
||||
"python-editor": "editor", # 程序化地调用编辑器
|
||||
"prompt-toolkit": "prompt_toolkit", # 构建交互式命令行
|
||||
"pygments": "pygments", # 语法高亮
|
||||
"tabulate": "tabulate", # 生成漂亮的表格
|
||||
"nats-client": "nats", # NATS客户端
|
||||
"gitpython": "git", # Git的Python接口
|
||||
"pygithub": "github", # GitHub API v3的Python接口
|
||||
"python-gitlab": "gitlab", # GitLab API的Python接口
|
||||
"jira": "jira", # JIRA API的Python接口
|
||||
"python-jenkins": "jenkins", # Jenkins API的Python接口
|
||||
"huggingface-hub": "huggingface_hub", # Hugging Face Hub API
|
||||
"apache-airflow": "airflow", # Airflow工作流管理
|
||||
"pandas-stubs": "pandas-stubs", # Pandas的类型存根
|
||||
"data-science-types": "data_science_types", # 数据科学类型
|
||||
}
|
||||
|
||||
@@ -6,62 +6,61 @@ logger = get_logger("dependency_config")
|
||||
|
||||
class DependencyConfig:
|
||||
"""依赖管理配置类 - 现在使用全局配置"""
|
||||
|
||||
|
||||
def __init__(self, global_config=None):
|
||||
self._global_config = global_config
|
||||
|
||||
|
||||
def _get_config(self):
|
||||
"""获取全局配置对象"""
|
||||
if self._global_config is not None:
|
||||
return self._global_config
|
||||
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
return global_config
|
||||
except ImportError:
|
||||
logger.warning("无法导入全局配置,使用默认设置")
|
||||
return None
|
||||
|
||||
|
||||
@property
|
||||
def auto_install(self) -> bool:
|
||||
"""是否启用自动安装"""
|
||||
config = self._get_config()
|
||||
if config and hasattr(config, 'dependency_management'):
|
||||
if config and hasattr(config, "dependency_management"):
|
||||
return config.dependency_management.auto_install
|
||||
return True
|
||||
|
||||
|
||||
@property
|
||||
def use_mirror(self) -> bool:
|
||||
"""是否使用PyPI镜像源"""
|
||||
config = self._get_config()
|
||||
if config and hasattr(config, 'dependency_management'):
|
||||
if config and hasattr(config, "dependency_management"):
|
||||
return config.dependency_management.use_mirror
|
||||
return False
|
||||
|
||||
|
||||
@property
|
||||
def mirror_url(self) -> str:
|
||||
"""PyPI镜像源URL"""
|
||||
config = self._get_config()
|
||||
if config and hasattr(config, 'dependency_management'):
|
||||
if config and hasattr(config, "dependency_management"):
|
||||
return config.dependency_management.mirror_url
|
||||
return ""
|
||||
|
||||
|
||||
@property
|
||||
def install_timeout(self) -> int:
|
||||
"""安装超时时间(秒)"""
|
||||
config = self._get_config()
|
||||
if config and hasattr(config, 'dependency_management'):
|
||||
if config and hasattr(config, "dependency_management"):
|
||||
return config.dependency_management.auto_install_timeout
|
||||
return 300
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def prompt_before_install(self) -> bool:
|
||||
"""安装前是否提示用户"""
|
||||
config = self._get_config()
|
||||
if config and hasattr(config, 'dependency_management'):
|
||||
if config and hasattr(config, "dependency_management"):
|
||||
return config.dependency_management.prompt_before_install
|
||||
return False
|
||||
|
||||
@@ -82,4 +81,4 @@ def configure_dependency_settings(**kwargs) -> None:
|
||||
"""配置依赖管理设置 - 注意:这个函数现在仅用于兼容性,实际配置需要修改bot_config.toml"""
|
||||
logger.info("依赖管理设置现在通过 bot_config.toml 的 [dependency_management] 节进行配置")
|
||||
logger.info(f"请求的配置更改: {kwargs}")
|
||||
logger.warning("configure_dependency_settings 函数仅用于兼容性,配置更改不会持久化")
|
||||
logger.warning("configure_dependency_settings 函数仅用于兼容性,配置更改不会持久化")
|
||||
|
||||
@@ -15,13 +15,13 @@ logger = get_logger("dependency_manager")
|
||||
|
||||
class DependencyManager:
|
||||
"""Python包依赖管理器
|
||||
|
||||
|
||||
负责检查和自动安装插件的Python包依赖
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, auto_install: bool = True, use_mirror: bool = False, mirror_url: Optional[str] = None):
|
||||
"""初始化依赖管理器
|
||||
|
||||
|
||||
Args:
|
||||
auto_install: 是否自动安装缺失的依赖
|
||||
use_mirror: 是否使用PyPI镜像源
|
||||
@@ -30,38 +30,39 @@ class DependencyManager:
|
||||
# 延迟导入配置以避免循环依赖
|
||||
try:
|
||||
from src.plugin_system.utils.dependency_config import get_dependency_config
|
||||
|
||||
config = get_dependency_config()
|
||||
|
||||
|
||||
# 优先使用配置文件中的设置,参数作为覆盖
|
||||
self.auto_install = config.auto_install if auto_install is True else auto_install
|
||||
self.use_mirror = config.use_mirror if use_mirror is False else use_mirror
|
||||
self.mirror_url = config.mirror_url if mirror_url is None else mirror_url
|
||||
self.install_timeout = config.install_timeout
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"无法加载依赖配置,使用默认设置: {e}")
|
||||
self.auto_install = auto_install
|
||||
self.use_mirror = use_mirror or False
|
||||
self.mirror_url = mirror_url or ""
|
||||
self.install_timeout = 300
|
||||
|
||||
|
||||
def check_dependencies(self, dependencies: Any, plugin_name: str = "") -> Tuple[bool, List[str], List[str]]:
|
||||
"""检查依赖包是否满足要求
|
||||
|
||||
|
||||
Args:
|
||||
dependencies: 依赖列表,支持字符串或PythonDependency对象
|
||||
plugin_name: 插件名称,用于日志记录
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, List[str], List[str]]: (是否全部满足, 缺失的包, 错误信息)
|
||||
"""
|
||||
missing_packages = []
|
||||
error_messages = []
|
||||
log_prefix = f"[Plugin:{plugin_name}] " if plugin_name else ""
|
||||
|
||||
|
||||
# 标准化依赖格式
|
||||
normalized_deps = self._normalize_dependencies(dependencies)
|
||||
|
||||
|
||||
for dep in normalized_deps:
|
||||
try:
|
||||
if not self._check_single_dependency(dep):
|
||||
@@ -71,38 +72,40 @@ class DependencyManager:
|
||||
error_msg = f"检查依赖 {dep.package_name} 时发生错误: {str(e)}"
|
||||
error_messages.append(error_msg)
|
||||
logger.error(f"{log_prefix}{error_msg}")
|
||||
|
||||
|
||||
all_satisfied = len(missing_packages) == 0 and len(error_messages) == 0
|
||||
|
||||
|
||||
if all_satisfied:
|
||||
logger.debug(f"{log_prefix}所有Python依赖检查通过")
|
||||
else:
|
||||
logger.warning(f"{log_prefix}Python依赖检查失败: 缺失{len(missing_packages)}个包, {len(error_messages)}个错误")
|
||||
|
||||
logger.warning(
|
||||
f"{log_prefix}Python依赖检查失败: 缺失{len(missing_packages)}个包, {len(error_messages)}个错误"
|
||||
)
|
||||
|
||||
return all_satisfied, missing_packages, error_messages
|
||||
|
||||
|
||||
def install_dependencies(self, packages: List[str], plugin_name: str = "") -> Tuple[bool, List[str]]:
|
||||
"""自动安装缺失的依赖包
|
||||
|
||||
|
||||
Args:
|
||||
packages: 要安装的包列表
|
||||
plugin_name: 插件名称,用于日志记录
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, List[str]]: (是否全部安装成功, 失败的包列表)
|
||||
"""
|
||||
if not packages:
|
||||
return True, []
|
||||
|
||||
|
||||
if not self.auto_install:
|
||||
logger.info(f"[Plugin:{plugin_name}] 自动安装已禁用,跳过安装: {packages}")
|
||||
return False, packages
|
||||
|
||||
|
||||
log_prefix = f"[Plugin:{plugin_name}] " if plugin_name else ""
|
||||
logger.info(f"{log_prefix}开始自动安装Python依赖: {packages}")
|
||||
|
||||
|
||||
failed_packages = []
|
||||
|
||||
|
||||
for package in packages:
|
||||
try:
|
||||
if self._install_single_package(package, plugin_name):
|
||||
@@ -113,37 +116,37 @@ class DependencyManager:
|
||||
except Exception as e:
|
||||
failed_packages.append(package)
|
||||
logger.error(f"{log_prefix}❌ 安装 {package} 时发生异常: {str(e)}")
|
||||
|
||||
|
||||
success = len(failed_packages) == 0
|
||||
if success:
|
||||
logger.info(f"{log_prefix}🎉 所有依赖安装完成")
|
||||
else:
|
||||
logger.error(f"{log_prefix}⚠️ 部分依赖安装失败: {failed_packages}")
|
||||
|
||||
|
||||
return success, failed_packages
|
||||
|
||||
|
||||
def check_and_install_dependencies(self, dependencies: Any, plugin_name: str = "") -> Tuple[bool, List[str]]:
|
||||
"""检查并自动安装依赖(组合操作)
|
||||
|
||||
|
||||
Args:
|
||||
dependencies: 依赖列表
|
||||
plugin_name: 插件名称
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, List[str]]: (是否全部满足, 错误信息列表)
|
||||
"""
|
||||
# 第一步:检查依赖
|
||||
all_satisfied, missing_packages, check_errors = self.check_dependencies(dependencies, plugin_name)
|
||||
|
||||
|
||||
if all_satisfied:
|
||||
return True, []
|
||||
|
||||
|
||||
all_errors = check_errors.copy()
|
||||
|
||||
|
||||
# 第二步:尝试安装缺失的包
|
||||
if missing_packages and self.auto_install:
|
||||
install_success, failed_packages = self.install_dependencies(missing_packages, plugin_name)
|
||||
|
||||
|
||||
if not install_success:
|
||||
all_errors.extend([f"安装失败: {pkg}" for pkg in failed_packages])
|
||||
else:
|
||||
@@ -156,13 +159,13 @@ class DependencyManager:
|
||||
return True, []
|
||||
else:
|
||||
all_errors.extend([f"缺失依赖: {pkg}" for pkg in missing_packages])
|
||||
|
||||
|
||||
return False, all_errors
|
||||
|
||||
|
||||
def _normalize_dependencies(self, dependencies: Any) -> List[PythonDependency]:
|
||||
"""将依赖列表标准化为PythonDependency对象"""
|
||||
normalized = []
|
||||
|
||||
|
||||
for dep in dependencies:
|
||||
if isinstance(dep, str):
|
||||
# 解析字符串格式的依赖
|
||||
@@ -170,28 +173,27 @@ class DependencyManager:
|
||||
# 尝试解析为requirement格式 (如 "package>=1.0.0")
|
||||
req = Requirement(dep)
|
||||
version_spec = str(req.specifier) if req.specifier else ""
|
||||
|
||||
normalized.append(PythonDependency(
|
||||
package_name=req.name,
|
||||
version=version_spec,
|
||||
install_name=dep # 保持原始的安装名称
|
||||
))
|
||||
|
||||
normalized.append(
|
||||
PythonDependency(
|
||||
package_name=req.name,
|
||||
version=version_spec,
|
||||
install_name=dep, # 保持原始的安装名称
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 如果解析失败,作为简单包名处理
|
||||
normalized.append(PythonDependency(
|
||||
package_name=dep,
|
||||
install_name=dep
|
||||
))
|
||||
normalized.append(PythonDependency(package_name=dep, install_name=dep))
|
||||
elif isinstance(dep, PythonDependency):
|
||||
normalized.append(dep)
|
||||
else:
|
||||
logger.warning(f"未知的依赖格式: {dep}")
|
||||
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _check_single_dependency(self, dep: PythonDependency) -> bool:
|
||||
"""检查单个依赖是否满足要求"""
|
||||
|
||||
|
||||
def _try_check(import_name: str) -> bool:
|
||||
"""尝试使用给定的导入名进行检查"""
|
||||
try:
|
||||
@@ -206,11 +208,11 @@ class DependencyManager:
|
||||
# 检查版本要求
|
||||
try:
|
||||
module = importlib.import_module(import_name)
|
||||
installed_version = getattr(module, '__version__', None)
|
||||
installed_version = getattr(module, "__version__", None)
|
||||
|
||||
if installed_version is None:
|
||||
# 尝试其他常见的版本属性
|
||||
installed_version = getattr(module, 'VERSION', None)
|
||||
installed_version = getattr(module, "VERSION", None)
|
||||
if installed_version is None:
|
||||
logger.debug(f"无法获取包 {import_name} 的版本信息,假设满足要求")
|
||||
return True
|
||||
@@ -243,33 +245,27 @@ class DependencyManager:
|
||||
|
||||
# 3. 如果别名也失败了,或者没有别名,最终确认失败
|
||||
return False
|
||||
|
||||
|
||||
def _install_single_package(self, package: str, plugin_name: str = "") -> bool:
|
||||
"""安装单个包"""
|
||||
try:
|
||||
cmd = [sys.executable, "-m", "pip", "install", package]
|
||||
|
||||
|
||||
# 添加镜像源设置
|
||||
if self.use_mirror and self.mirror_url:
|
||||
cmd.extend(["-i", self.mirror_url])
|
||||
logger.debug(f"[Plugin:{plugin_name}] 使用PyPI镜像源: {self.mirror_url}")
|
||||
|
||||
|
||||
logger.debug(f"[Plugin:{plugin_name}] 执行安装命令: {' '.join(cmd)}")
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=self.install_timeout,
|
||||
check=False
|
||||
)
|
||||
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=self.install_timeout, check=False)
|
||||
|
||||
if result.returncode == 0:
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[Plugin:{plugin_name}] pip安装失败: {result.stderr}")
|
||||
return False
|
||||
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(f"[Plugin:{plugin_name}] 安装 {package} 超时")
|
||||
return False
|
||||
@@ -294,7 +290,5 @@ def configure_dependency_manager(auto_install: bool = True, use_mirror: bool = F
|
||||
"""配置全局依赖管理器"""
|
||||
global _global_dependency_manager
|
||||
_global_dependency_manager = DependencyManager(
|
||||
auto_install=auto_install,
|
||||
use_mirror=use_mirror,
|
||||
mirror_url=mirror_url
|
||||
)
|
||||
auto_install=auto_install, use_mirror=use_mirror, mirror_url=mirror_url
|
||||
)
|
||||
|
||||
@@ -19,65 +19,64 @@ logger = get_logger(__name__)
|
||||
def require_permission(permission_node: str, deny_message: Optional[str] = None):
|
||||
"""
|
||||
权限检查装饰器
|
||||
|
||||
|
||||
用于装饰需要特定权限才能执行的函数。如果用户没有权限,会发送拒绝消息并阻止函数执行。
|
||||
|
||||
|
||||
Args:
|
||||
permission_node: 所需的权限节点名称
|
||||
deny_message: 权限不足时的提示消息,如果为None则使用默认消息
|
||||
|
||||
|
||||
Example:
|
||||
@require_permission("plugin.example.admin")
|
||||
async def admin_command(message: Message, chat_stream: ChatStream):
|
||||
# 只有拥有 plugin.example.admin 权限的用户才能执行
|
||||
pass
|
||||
"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
# 尝试从参数中提取 ChatStream 对象
|
||||
chat_stream = None
|
||||
|
||||
|
||||
# 首先检查位置参数中的 ChatStream
|
||||
for arg in args:
|
||||
if isinstance(arg, ChatStream):
|
||||
chat_stream = arg
|
||||
break
|
||||
|
||||
|
||||
# 如果在位置参数中没找到,尝试从关键字参数中查找
|
||||
if chat_stream is None:
|
||||
chat_stream = kwargs.get('chat_stream')
|
||||
|
||||
chat_stream = kwargs.get("chat_stream")
|
||||
|
||||
# 如果还没找到,检查是否是 PlusCommand 方法调用
|
||||
if chat_stream is None and args:
|
||||
# 检查第一个参数是否有 message.chat_stream 属性(PlusCommand 实例)
|
||||
instance = args[0]
|
||||
if hasattr(instance, 'message') and hasattr(instance.message, 'chat_stream'):
|
||||
if hasattr(instance, "message") and hasattr(instance.message, "chat_stream"):
|
||||
chat_stream = instance.message.chat_stream
|
||||
|
||||
|
||||
if chat_stream is None:
|
||||
logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
|
||||
return
|
||||
|
||||
|
||||
# 检查权限
|
||||
has_permission = permission_api.check_permission(
|
||||
chat_stream.platform,
|
||||
chat_stream.user_info.user_id,
|
||||
permission_node
|
||||
chat_stream.platform, chat_stream.user_info.user_id, permission_node
|
||||
)
|
||||
|
||||
|
||||
if not has_permission:
|
||||
# 权限不足,发送拒绝消息
|
||||
message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}"
|
||||
await text_to_stream(message, chat_stream.stream_id)
|
||||
# 对于PlusCommand的execute方法,需要返回适当的元组
|
||||
if func.__name__ == 'execute' and hasattr(args[0], 'send_text'):
|
||||
if func.__name__ == "execute" and hasattr(args[0], "send_text"):
|
||||
return False, "权限不足", True
|
||||
return
|
||||
|
||||
|
||||
# 权限检查通过,执行原函数
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
# 对于同步函数,我们不能发送异步消息,只能记录日志
|
||||
chat_stream = None
|
||||
@@ -85,95 +84,93 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
|
||||
if isinstance(arg, ChatStream):
|
||||
chat_stream = arg
|
||||
break
|
||||
|
||||
|
||||
if chat_stream is None:
|
||||
chat_stream = kwargs.get('chat_stream')
|
||||
|
||||
chat_stream = kwargs.get("chat_stream")
|
||||
|
||||
if chat_stream is None:
|
||||
logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
|
||||
return
|
||||
|
||||
|
||||
# 检查权限
|
||||
has_permission = permission_api.check_permission(
|
||||
chat_stream.platform,
|
||||
chat_stream.user_info.user_id,
|
||||
permission_node
|
||||
chat_stream.platform, chat_stream.user_info.user_id, permission_node
|
||||
)
|
||||
|
||||
|
||||
if not has_permission:
|
||||
logger.warning(f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 没有权限 {permission_node}")
|
||||
logger.warning(
|
||||
f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 没有权限 {permission_node}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# 权限检查通过,执行原函数
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
# 根据函数类型选择包装器
|
||||
if iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_master(deny_message: Optional[str] = None):
|
||||
"""
|
||||
Master权限检查装饰器
|
||||
|
||||
|
||||
用于装饰只有Master用户才能执行的函数。
|
||||
|
||||
|
||||
Args:
|
||||
deny_message: 权限不足时的提示消息,如果为None则使用默认消息
|
||||
|
||||
|
||||
Example:
|
||||
@require_master()
|
||||
async def master_only_command(message: Message, chat_stream: ChatStream):
|
||||
# 只有Master用户才能执行
|
||||
pass
|
||||
"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
# 尝试从参数中提取 ChatStream 对象
|
||||
chat_stream = None
|
||||
|
||||
|
||||
# 首先检查位置参数中的 ChatStream
|
||||
for arg in args:
|
||||
if isinstance(arg, ChatStream):
|
||||
chat_stream = arg
|
||||
break
|
||||
|
||||
|
||||
# 如果在位置参数中没找到,尝试从关键字参数中查找
|
||||
if chat_stream is None:
|
||||
chat_stream = kwargs.get('chat_stream')
|
||||
|
||||
chat_stream = kwargs.get("chat_stream")
|
||||
|
||||
# 如果还没找到,检查是否是 PlusCommand 方法调用
|
||||
if chat_stream is None and args:
|
||||
# 检查第一个参数是否有 message.chat_stream 属性(PlusCommand 实例)
|
||||
instance = args[0]
|
||||
if hasattr(instance, 'message') and hasattr(instance.message, 'chat_stream'):
|
||||
if hasattr(instance, "message") and hasattr(instance.message, "chat_stream"):
|
||||
chat_stream = instance.message.chat_stream
|
||||
|
||||
|
||||
if chat_stream is None:
|
||||
logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
|
||||
return
|
||||
|
||||
|
||||
# 检查是否为Master用户
|
||||
is_master = permission_api.is_master(
|
||||
chat_stream.platform,
|
||||
chat_stream.user_info.user_id
|
||||
)
|
||||
|
||||
is_master = permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id)
|
||||
|
||||
if not is_master:
|
||||
message = deny_message or "❌ 此操作仅限Master用户执行"
|
||||
await text_to_stream(message, chat_stream.stream_id)
|
||||
if func.__name__ == 'execute' and hasattr(args[0], 'send_text'):
|
||||
if func.__name__ == "execute" and hasattr(args[0], "send_text"):
|
||||
return False, "需要Master权限", True
|
||||
return
|
||||
|
||||
|
||||
# 权限检查通过,执行原函数
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
# 对于同步函数,我们不能发送异步消息,只能记录日志
|
||||
chat_stream = None
|
||||
@@ -181,116 +178,106 @@ def require_master(deny_message: Optional[str] = None):
|
||||
if isinstance(arg, ChatStream):
|
||||
chat_stream = arg
|
||||
break
|
||||
|
||||
|
||||
if chat_stream is None:
|
||||
chat_stream = kwargs.get('chat_stream')
|
||||
|
||||
chat_stream = kwargs.get("chat_stream")
|
||||
|
||||
if chat_stream is None:
|
||||
logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}")
|
||||
return
|
||||
|
||||
|
||||
# 检查是否为Master用户
|
||||
is_master = permission_api.is_master(
|
||||
chat_stream.platform,
|
||||
chat_stream.user_info.user_id
|
||||
)
|
||||
|
||||
is_master = permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id)
|
||||
|
||||
if not is_master:
|
||||
logger.warning(f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 不是Master用户")
|
||||
return
|
||||
|
||||
|
||||
# 权限检查通过,执行原函数
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
# 根据函数类型选择包装器
|
||||
if iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class PermissionChecker:
|
||||
"""
|
||||
权限检查工具类
|
||||
|
||||
|
||||
提供一些便捷的权限检查方法,用于在代码中进行权限验证。
|
||||
"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def check_permission(chat_stream: ChatStream, permission_node: str) -> bool:
|
||||
"""
|
||||
检查用户是否拥有指定权限
|
||||
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否拥有权限
|
||||
"""
|
||||
return permission_api.check_permission(
|
||||
chat_stream.platform,
|
||||
chat_stream.user_info.user_id,
|
||||
permission_node
|
||||
)
|
||||
|
||||
return permission_api.check_permission(chat_stream.platform, chat_stream.user_info.user_id, permission_node)
|
||||
|
||||
@staticmethod
|
||||
def is_master(chat_stream: ChatStream) -> bool:
|
||||
"""
|
||||
检查用户是否为Master用户
|
||||
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否为Master用户
|
||||
"""
|
||||
return permission_api.is_master(
|
||||
chat_stream.platform,
|
||||
chat_stream.user_info.user_id
|
||||
)
|
||||
|
||||
return permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id)
|
||||
|
||||
@staticmethod
|
||||
async def ensure_permission(chat_stream: ChatStream, permission_node: str,
|
||||
deny_message: Optional[str] = None) -> bool:
|
||||
async def ensure_permission(
|
||||
chat_stream: ChatStream, permission_node: str, deny_message: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
确保用户拥有指定权限,如果没有权限会发送消息并返回False
|
||||
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
permission_node: 权限节点名称
|
||||
deny_message: 权限不足时的提示消息
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否拥有权限
|
||||
"""
|
||||
has_permission = PermissionChecker.check_permission(chat_stream, permission_node)
|
||||
|
||||
|
||||
if not has_permission:
|
||||
message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}"
|
||||
await text_to_stream(message, chat_stream.stream_id)
|
||||
|
||||
|
||||
return has_permission
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def ensure_master(chat_stream: ChatStream,
|
||||
deny_message: Optional[str] = None) -> bool:
|
||||
async def ensure_master(chat_stream: ChatStream, deny_message: Optional[str] = None) -> bool:
|
||||
"""
|
||||
确保用户为Master用户,如果不是会发送消息并返回False
|
||||
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
deny_message: 权限不足时的提示消息
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否为Master用户
|
||||
"""
|
||||
is_master = PermissionChecker.is_master(chat_stream)
|
||||
|
||||
|
||||
if not is_master:
|
||||
message = deny_message or "❌ 此操作仅限Master用户执行"
|
||||
await text_to_stream(message, chat_stream.stream_id)
|
||||
|
||||
|
||||
return is_master
|
||||
|
||||
Reference in New Issue
Block a user