修复代码格式和文件名大小写问题
This commit is contained in:
@@ -147,7 +147,7 @@ class BaseAction(ABC):
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
||||
)
|
||||
|
||||
|
||||
# 验证聊天类型限制
|
||||
if not self._validate_chat_type():
|
||||
logger.warning(
|
||||
@@ -157,7 +157,7 @@ class BaseAction(ABC):
|
||||
|
||||
def _validate_chat_type(self) -> bool:
|
||||
"""验证当前聊天类型是否允许执行此Action
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
@@ -172,9 +172,9 @@ class BaseAction(ABC):
|
||||
|
||||
def is_chat_type_allowed(self) -> bool:
|
||||
"""检查当前聊天类型是否允许执行此Action
|
||||
|
||||
|
||||
这是一个公开的方法,供外部调用检查聊天类型限制
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
@@ -240,9 +240,7 @@ class BaseAction(ABC):
|
||||
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
||||
return False, f"等待新消息失败: {str(e)}"
|
||||
|
||||
async def send_text(
|
||||
self, content: str, reply_to: str = "", typing: bool = False
|
||||
) -> bool:
|
||||
async def send_text(self, content: str, reply_to: str = "", typing: bool = False) -> bool:
|
||||
"""发送文本消息
|
||||
|
||||
Args:
|
||||
|
||||
@@ -46,10 +46,10 @@ class BaseCommand(ABC):
|
||||
self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
||||
|
||||
logger.debug(f"{self.log_prefix} Command组件初始化完成")
|
||||
|
||||
|
||||
# 验证聊天类型限制
|
||||
if not self._validate_chat_type():
|
||||
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
|
||||
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
|
||||
logger.warning(
|
||||
f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: "
|
||||
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
||||
@@ -65,16 +65,16 @@ class BaseCommand(ABC):
|
||||
|
||||
def _validate_chat_type(self) -> bool:
|
||||
"""验证当前聊天类型是否允许执行此Command
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
if self.chat_type_allow == ChatType.ALL:
|
||||
return True
|
||||
|
||||
|
||||
# 检查是否为群聊消息
|
||||
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
|
||||
|
||||
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
|
||||
|
||||
if self.chat_type_allow == ChatType.GROUP and is_group:
|
||||
return True
|
||||
elif self.chat_type_allow == ChatType.PRIVATE and not is_group:
|
||||
@@ -84,9 +84,9 @@ class BaseCommand(ABC):
|
||||
|
||||
def is_chat_type_allowed(self) -> bool:
|
||||
"""检查当前聊天类型是否允许执行此Command
|
||||
|
||||
|
||||
这是一个公开的方法,供外部调用检查聊天类型限制
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
|
||||
@@ -3,12 +3,14 @@ from typing import List, Dict, Any, Optional
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("base_event")
|
||||
|
||||
|
||||
|
||||
class HandlerResult:
|
||||
"""事件处理器执行结果
|
||||
|
||||
|
||||
所有事件处理器必须返回此类的实例
|
||||
"""
|
||||
|
||||
def __init__(self, success: bool, continue_process: bool, message: Any = None, handler_name: str = ""):
|
||||
self.success = success
|
||||
self.continue_process = continue_process
|
||||
@@ -18,31 +20,32 @@ class HandlerResult:
|
||||
def __repr__(self):
|
||||
return f"HandlerResult(success={self.success}, continue_process={self.continue_process}, message='{self.message}', handler_name='{self.handler_name}')"
|
||||
|
||||
|
||||
class HandlerResultsCollection:
|
||||
"""HandlerResult集合,提供便捷的查询方法"""
|
||||
|
||||
|
||||
def __init__(self, results: List[HandlerResult]):
|
||||
self.results = results
|
||||
|
||||
|
||||
def all_continue_process(self) -> bool:
|
||||
"""检查是否所有handler的continue_process都为True"""
|
||||
return all(result.continue_process for result in self.results)
|
||||
|
||||
|
||||
def get_all_results(self) -> List[HandlerResult]:
|
||||
"""获取所有HandlerResult"""
|
||||
return self.results
|
||||
|
||||
|
||||
def get_failed_handlers(self) -> List[HandlerResult]:
|
||||
"""获取执行失败的handler结果"""
|
||||
return [result for result in self.results if not result.success]
|
||||
|
||||
|
||||
def get_stopped_handlers(self) -> List[HandlerResult]:
|
||||
"""获取continue_process为False的handler结果"""
|
||||
return [result for result in self.results if not result.continue_process]
|
||||
|
||||
|
||||
def get_message_result(self) -> Any:
|
||||
"""获取handler的message
|
||||
|
||||
|
||||
当只有一个handler的结果时,直接返回那个handler结果中的message字段
|
||||
否则用字典的形式{handler_name:message}返回
|
||||
"""
|
||||
@@ -52,22 +55,22 @@ class HandlerResultsCollection:
|
||||
return self.results[0].message
|
||||
else:
|
||||
return {result.handler_name: result.message for result in self.results}
|
||||
|
||||
|
||||
def get_handler_result(self, handler_name: str) -> Optional[HandlerResult]:
|
||||
"""获取指定handler的结果"""
|
||||
for result in self.results:
|
||||
if result.handler_name == handler_name:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def get_success_count(self) -> int:
|
||||
"""获取成功执行的handler数量"""
|
||||
return sum(1 for result in self.results if result.success)
|
||||
|
||||
|
||||
def get_failure_count(self) -> int:
|
||||
"""获取执行失败的handler数量"""
|
||||
return sum(1 for result in self.results if not result.success)
|
||||
|
||||
|
||||
def get_summary(self) -> Dict[str, Any]:
|
||||
"""获取执行摘要"""
|
||||
return {
|
||||
@@ -76,62 +79,63 @@ class HandlerResultsCollection:
|
||||
"failure_count": self.get_failure_count(),
|
||||
"continue_process": self.all_continue_process(),
|
||||
"failed_handlers": [r.handler_name for r in self.get_failed_handlers()],
|
||||
"stopped_handlers": [r.handler_name for r in self.get_stopped_handlers()]
|
||||
"stopped_handlers": [r.handler_name for r in self.get_stopped_handlers()],
|
||||
}
|
||||
|
||||
|
||||
class BaseEvent:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
allowed_subscribers: List[str] = None,
|
||||
allowed_triggers: List[str] = None
|
||||
):
|
||||
def __init__(self, name: str, allowed_subscribers: List[str] = None, allowed_triggers: List[str] = None):
|
||||
self.name = name
|
||||
self.enabled = True
|
||||
self.allowed_subscribers = allowed_subscribers # 记录事件处理器名
|
||||
self.allowed_triggers = allowed_triggers # 记录插件名
|
||||
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
self.subscribers: List["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表
|
||||
|
||||
self.subscribers: List["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表
|
||||
|
||||
self.event_handle_lock = asyncio.Lock()
|
||||
|
||||
def __name__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
async def activate(self, params: dict) -> HandlerResultsCollection:
|
||||
"""激活事件,执行所有订阅的处理器
|
||||
|
||||
|
||||
Args:
|
||||
params: 传递给处理器的参数
|
||||
|
||||
|
||||
Returns:
|
||||
HandlerResultsCollection: 所有处理器的执行结果集合
|
||||
"""
|
||||
if not self.enabled:
|
||||
return HandlerResultsCollection([])
|
||||
|
||||
|
||||
# 使用锁确保同一个事件不能同时激活多次
|
||||
async with self.event_handle_lock:
|
||||
# 按权重从高到低排序订阅者
|
||||
# 使用直接属性访问,-1代表自动权重
|
||||
sorted_subscribers = sorted(self.subscribers, key=lambda h: h.weight if hasattr(h, 'weight') and h.weight != -1 else 0, reverse=True)
|
||||
|
||||
sorted_subscribers = sorted(
|
||||
self.subscribers, key=lambda h: h.weight if hasattr(h, "weight") and h.weight != -1 else 0, reverse=True
|
||||
)
|
||||
|
||||
# 并行执行所有订阅者
|
||||
tasks = []
|
||||
for subscriber in sorted_subscribers:
|
||||
# 为每个订阅者创建执行任务
|
||||
task = self._execute_subscriber(subscriber, params)
|
||||
tasks.append(task)
|
||||
|
||||
|
||||
# 等待所有任务完成
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 处理执行结果
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
subscriber = sorted_subscribers[i]
|
||||
handler_name = subscriber.handler_name if hasattr(subscriber, 'handler_name') else subscriber.__class__.__name__
|
||||
handler_name = (
|
||||
subscriber.handler_name if hasattr(subscriber, "handler_name") else subscriber.__class__.__name__
|
||||
)
|
||||
if result:
|
||||
if isinstance(result, Exception):
|
||||
# 处理执行异常
|
||||
@@ -143,13 +147,13 @@ class BaseEvent:
|
||||
# 补充handler_name
|
||||
result.handler_name = handler_name
|
||||
processed_results.append(result)
|
||||
|
||||
|
||||
return HandlerResultsCollection(processed_results)
|
||||
|
||||
|
||||
async def _execute_subscriber(self, subscriber, params: dict) -> HandlerResult:
|
||||
"""执行单个订阅者处理器"""
|
||||
try:
|
||||
return await subscriber.execute(params)
|
||||
except Exception as e:
|
||||
# 异常会在 gather 中捕获,这里直接抛出让 gather 处理
|
||||
raise e
|
||||
raise e
|
||||
|
||||
@@ -51,11 +51,11 @@ class BaseEventHandler(ABC):
|
||||
event_name (str): 要订阅的事件名称
|
||||
"""
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
|
||||
if not event_manager.subscribe_handler_to_event(self.handler_name, event_name):
|
||||
logger.error(f"事件处理器 {self.handler_name} 订阅事件 {event_name} 失败")
|
||||
return
|
||||
|
||||
|
||||
logger.debug(f"{self.log_prefix} 订阅事件 {event_name}")
|
||||
self.subscribed_events.append(event_name)
|
||||
|
||||
@@ -66,7 +66,7 @@ class BaseEventHandler(ABC):
|
||||
event_name (str): 要取消订阅的事件名称
|
||||
"""
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
|
||||
if event_manager.unsubscribe_handler_from_event(self.handler_name, event_name):
|
||||
logger.debug(f"{self.log_prefix} 取消订阅事件 {event_name}")
|
||||
if event_name in self.subscribed_events:
|
||||
|
||||
@@ -9,32 +9,32 @@ import shlex
|
||||
|
||||
class CommandArgs:
|
||||
"""命令参数解析类
|
||||
|
||||
|
||||
提供方便的方法来处理命令参数
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, raw_args: str = ""):
|
||||
"""初始化命令参数
|
||||
|
||||
|
||||
Args:
|
||||
raw_args: 原始参数字符串
|
||||
"""
|
||||
self._raw_args = raw_args.strip()
|
||||
self._parsed_args: Optional[List[str]] = None
|
||||
|
||||
|
||||
def get_raw(self) -> str:
|
||||
"""获取完整的参数字符串
|
||||
|
||||
|
||||
Returns:
|
||||
str: 原始参数字符串
|
||||
"""
|
||||
return self._raw_args
|
||||
|
||||
|
||||
def get_args(self) -> List[str]:
|
||||
"""获取解析后的参数列表
|
||||
|
||||
|
||||
将参数按空格分割,支持引号包围的参数
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 参数列表
|
||||
"""
|
||||
@@ -48,25 +48,25 @@ class CommandArgs:
|
||||
except ValueError:
|
||||
# 如果shlex解析失败,fallback到简单的split
|
||||
self._parsed_args = self._raw_args.split()
|
||||
|
||||
|
||||
return self._parsed_args
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""检查参数是否为空
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果没有参数返回True
|
||||
"""
|
||||
return len(self.get_args()) == 0
|
||||
|
||||
|
||||
def get_arg(self, index: int, default: str = "") -> str:
|
||||
"""获取指定索引的参数
|
||||
|
||||
|
||||
Args:
|
||||
index: 参数索引(从0开始)
|
||||
default: 默认值
|
||||
|
||||
|
||||
Returns:
|
||||
str: 参数值或默认值
|
||||
"""
|
||||
@@ -78,21 +78,21 @@ class CommandArgs:
|
||||
@property
|
||||
def get_first(self, default: str = "") -> str:
|
||||
"""获取第一个参数
|
||||
|
||||
|
||||
Args:
|
||||
default: 默认值
|
||||
|
||||
|
||||
Returns:
|
||||
str: 第一个参数或默认值
|
||||
"""
|
||||
return self.get_arg(0, default)
|
||||
|
||||
|
||||
def get_remaining(self, start_index: int = 0) -> str:
|
||||
"""获取从指定索引开始的剩余参数字符串
|
||||
|
||||
|
||||
Args:
|
||||
start_index: 起始索引
|
||||
|
||||
|
||||
Returns:
|
||||
str: 剩余参数组成的字符串
|
||||
"""
|
||||
@@ -100,45 +100,45 @@ class CommandArgs:
|
||||
if start_index < len(args):
|
||||
return " ".join(args[start_index:])
|
||||
return ""
|
||||
|
||||
|
||||
def count(self) -> int:
|
||||
"""获取参数数量
|
||||
|
||||
|
||||
Returns:
|
||||
int: 参数数量
|
||||
"""
|
||||
return len(self.get_args())
|
||||
|
||||
|
||||
def has_flag(self, flag: str) -> bool:
|
||||
"""检查是否包含指定的标志参数
|
||||
|
||||
|
||||
Args:
|
||||
flag: 标志名(如 "--verbose" 或 "-v")
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果包含该标志返回True
|
||||
"""
|
||||
return flag in self.get_args()
|
||||
|
||||
|
||||
def get_flag_value(self, flag: str, default: str = "") -> str:
|
||||
"""获取标志参数的值
|
||||
|
||||
|
||||
查找 --key=value 或 --key value 形式的参数
|
||||
|
||||
|
||||
Args:
|
||||
flag: 标志名(如 "--output")
|
||||
default: 默认值
|
||||
|
||||
|
||||
Returns:
|
||||
str: 标志的值或默认值
|
||||
"""
|
||||
args = self.get_args()
|
||||
|
||||
|
||||
# 查找 --key=value 形式
|
||||
for arg in args:
|
||||
if arg.startswith(f"{flag}="):
|
||||
return arg[len(flag) + 1:]
|
||||
|
||||
return arg[len(flag) + 1 :]
|
||||
|
||||
# 查找 --key value 形式
|
||||
try:
|
||||
flag_index = args.index(flag)
|
||||
@@ -146,13 +146,13 @@ class CommandArgs:
|
||||
return args[flag_index + 1]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示"""
|
||||
return self._raw_args
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""调试表示"""
|
||||
return f"CommandArgs(raw='{self._raw_args}', parsed={self.get_args()})"
|
||||
|
||||
@@ -6,6 +6,7 @@ from maim_message import Seg
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
||||
|
||||
|
||||
# 组件类型枚举
|
||||
class ComponentType(Enum):
|
||||
"""组件类型枚举"""
|
||||
@@ -185,7 +186,9 @@ class PlusCommandInfo(ComponentInfo):
|
||||
class ToolInfo(ComponentInfo):
|
||||
"""工具组件信息"""
|
||||
|
||||
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
|
||||
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(
|
||||
default_factory=list
|
||||
) # 工具参数定义
|
||||
tool_description: str = "" # 工具描述
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -205,6 +208,7 @@ class EventHandlerInfo(ComponentInfo):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.EVENT_HANDLER
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventInfo(ComponentInfo):
|
||||
"""事件组件信息"""
|
||||
@@ -213,6 +217,7 @@ class EventInfo(ComponentInfo):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.EVENT
|
||||
|
||||
|
||||
# 事件类型枚举
|
||||
class EventType(Enum):
|
||||
"""
|
||||
@@ -232,6 +237,7 @@ class EventType(Enum):
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginInfo:
|
||||
"""插件信息"""
|
||||
@@ -320,16 +326,16 @@ class MaiMessages:
|
||||
|
||||
llm_response_content: Optional[str] = None
|
||||
"""LLM响应内容"""
|
||||
|
||||
|
||||
llm_response_reasoning: Optional[str] = None
|
||||
"""LLM响应推理内容"""
|
||||
|
||||
|
||||
llm_response_model: Optional[str] = None
|
||||
"""LLM响应模型名称"""
|
||||
|
||||
|
||||
llm_response_tool_call: Optional[List[ToolCall]] = None
|
||||
"""LLM使用的工具调用"""
|
||||
|
||||
|
||||
action_usage: Optional[List[str]] = None
|
||||
"""使用的Action"""
|
||||
|
||||
|
||||
@@ -90,10 +90,10 @@ class PluginBase(ABC):
|
||||
|
||||
# 标准化Python依赖为PythonDependency对象
|
||||
normalized_python_deps = self._normalize_python_dependencies(self.python_dependencies)
|
||||
|
||||
|
||||
# 检查Python依赖
|
||||
self._check_python_dependencies(normalized_python_deps)
|
||||
|
||||
|
||||
# 创建插件信息对象
|
||||
self.plugin_info = PluginInfo(
|
||||
name=self.plugin_name,
|
||||
@@ -560,7 +560,7 @@ class PluginBase(ABC):
|
||||
def _normalize_python_dependencies(self, dependencies: Any) -> List[PythonDependency]:
|
||||
"""将依赖列表标准化为PythonDependency对象"""
|
||||
from packaging.requirements import Requirement
|
||||
|
||||
|
||||
normalized = []
|
||||
for dep in dependencies:
|
||||
if isinstance(dep, str):
|
||||
@@ -568,23 +568,22 @@ class PluginBase(ABC):
|
||||
# 尝试解析为requirement格式 (如 "package>=1.0.0")
|
||||
req = Requirement(dep)
|
||||
version_spec = str(req.specifier) if req.specifier else ""
|
||||
|
||||
normalized.append(PythonDependency(
|
||||
package_name=req.name,
|
||||
version=version_spec,
|
||||
install_name=dep # 保持原始的安装名称
|
||||
))
|
||||
|
||||
normalized.append(
|
||||
PythonDependency(
|
||||
package_name=req.name,
|
||||
version=version_spec,
|
||||
install_name=dep, # 保持原始的安装名称
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 如果解析失败,作为简单包名处理
|
||||
normalized.append(PythonDependency(
|
||||
package_name=dep,
|
||||
install_name=dep
|
||||
))
|
||||
normalized.append(PythonDependency(package_name=dep, install_name=dep))
|
||||
elif isinstance(dep, PythonDependency):
|
||||
normalized.append(dep)
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 未知的依赖格式: {dep}")
|
||||
|
||||
|
||||
return normalized
|
||||
|
||||
def _check_python_dependencies(self, dependencies: List[PythonDependency]) -> bool:
|
||||
@@ -596,10 +595,10 @@ class PluginBase(ABC):
|
||||
try:
|
||||
# 延迟导入以避免循环依赖
|
||||
from src.plugin_system.utils.dependency_manager import get_dependency_manager
|
||||
|
||||
|
||||
dependency_manager = get_dependency_manager()
|
||||
success, errors = dependency_manager.check_and_install_dependencies(dependencies, self.plugin_name)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} Python依赖检查通过")
|
||||
return True
|
||||
@@ -608,7 +607,7 @@ class PluginBase(ABC):
|
||||
for error in errors:
|
||||
logger.error(f"{self.log_prefix} - {error}")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} Python依赖检查时发生异常: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@@ -20,12 +20,12 @@ logger = get_logger("plus_command")
|
||||
|
||||
class PlusCommand(ABC):
|
||||
"""增强版命令基类
|
||||
|
||||
|
||||
提供更简单的命令定义方式,无需手写正则表达式
|
||||
|
||||
|
||||
子类只需要定义:
|
||||
- command_name: 命令名称
|
||||
- command_description: 命令描述
|
||||
- command_description: 命令描述
|
||||
- command_aliases: 命令别名列表(可选)
|
||||
- priority: 优先级(可选,数字越大优先级越高)
|
||||
- chat_type_allow: 允许的聊天类型(可选)
|
||||
@@ -35,19 +35,19 @@ class PlusCommand(ABC):
|
||||
# 子类需要定义的属性
|
||||
command_name: str = ""
|
||||
"""命令名称,如 'echo'"""
|
||||
|
||||
|
||||
command_description: str = ""
|
||||
"""命令描述"""
|
||||
|
||||
|
||||
command_aliases: List[str] = []
|
||||
"""命令别名列表,如 ['say', 'repeat']"""
|
||||
|
||||
|
||||
priority: int = 0
|
||||
"""命令优先级,数字越大优先级越高"""
|
||||
|
||||
|
||||
chat_type_allow: ChatType = ChatType.ALL
|
||||
"""允许的聊天类型"""
|
||||
|
||||
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,不进行后续处理"""
|
||||
|
||||
@@ -61,13 +61,13 @@ class PlusCommand(ABC):
|
||||
self.message = message
|
||||
self.plugin_config = plugin_config or {}
|
||||
self.log_prefix = "[PlusCommand]"
|
||||
|
||||
|
||||
# 解析命令参数
|
||||
self._parse_command()
|
||||
|
||||
|
||||
# 验证聊天类型限制
|
||||
if not self._validate_chat_type():
|
||||
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
|
||||
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
|
||||
logger.warning(
|
||||
f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: "
|
||||
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
||||
@@ -75,59 +75,59 @@ class PlusCommand(ABC):
|
||||
|
||||
def _parse_command(self) -> None:
|
||||
"""解析命令和参数"""
|
||||
if not hasattr(self.message, 'plain_text') or not self.message.plain_text:
|
||||
if not hasattr(self.message, "plain_text") or not self.message.plain_text:
|
||||
self.args = CommandArgs("")
|
||||
return
|
||||
|
||||
|
||||
plain_text = self.message.plain_text.strip()
|
||||
|
||||
|
||||
# 获取配置的命令前缀
|
||||
prefixes = global_config.command.command_prefixes
|
||||
|
||||
|
||||
# 检查是否以任何前缀开头
|
||||
matched_prefix = None
|
||||
for prefix in prefixes:
|
||||
if plain_text.startswith(prefix):
|
||||
matched_prefix = prefix
|
||||
break
|
||||
|
||||
|
||||
if not matched_prefix:
|
||||
self.args = CommandArgs("")
|
||||
return
|
||||
|
||||
|
||||
# 移除前缀
|
||||
command_part = plain_text[len(matched_prefix):].strip()
|
||||
|
||||
command_part = plain_text[len(matched_prefix) :].strip()
|
||||
|
||||
# 分离命令名和参数
|
||||
parts = command_part.split(None, 1)
|
||||
if not parts:
|
||||
self.args = CommandArgs("")
|
||||
return
|
||||
|
||||
|
||||
command_word = parts[0].lower()
|
||||
args_text = parts[1] if len(parts) > 1 else ""
|
||||
|
||||
|
||||
# 检查命令名是否匹配
|
||||
all_commands = [self.command_name.lower()] + [alias.lower() for alias in self.command_aliases]
|
||||
if command_word not in all_commands:
|
||||
self.args = CommandArgs("")
|
||||
return
|
||||
|
||||
|
||||
# 创建参数对象
|
||||
self.args = CommandArgs(args_text)
|
||||
|
||||
def _validate_chat_type(self) -> bool:
|
||||
"""验证当前聊天类型是否允许执行此命令
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
if self.chat_type_allow == ChatType.ALL:
|
||||
return True
|
||||
|
||||
|
||||
# 检查是否为群聊消息
|
||||
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
|
||||
|
||||
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
|
||||
|
||||
if self.chat_type_allow == ChatType.GROUP and is_group:
|
||||
return True
|
||||
elif self.chat_type_allow == ChatType.PRIVATE and not is_group:
|
||||
@@ -137,7 +137,7 @@ class PlusCommand(ABC):
|
||||
|
||||
def is_chat_type_allowed(self) -> bool:
|
||||
"""检查当前聊天类型是否允许执行此命令
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
@@ -145,30 +145,30 @@ class PlusCommand(ABC):
|
||||
|
||||
def is_command_match(self) -> bool:
|
||||
"""检查当前消息是否匹配此命令
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果匹配返回True
|
||||
"""
|
||||
return not self.args.is_empty() or self._is_exact_command_call()
|
||||
|
||||
|
||||
def _is_exact_command_call(self) -> bool:
|
||||
"""检查是否是精确的命令调用(无参数)"""
|
||||
if not hasattr(self.message, 'plain_text') or not self.message.plain_text:
|
||||
if not hasattr(self.message, "plain_text") or not self.message.plain_text:
|
||||
return False
|
||||
|
||||
|
||||
plain_text = self.message.plain_text.strip()
|
||||
|
||||
|
||||
# 获取配置的命令前缀
|
||||
prefixes = global_config.command.command_prefixes
|
||||
|
||||
|
||||
# 检查每个前缀
|
||||
for prefix in prefixes:
|
||||
if plain_text.startswith(prefix):
|
||||
command_part = plain_text[len(prefix):].strip()
|
||||
command_part = plain_text[len(prefix) :].strip()
|
||||
all_commands = [self.command_name.lower()] + [alias.lower() for alias in self.command_aliases]
|
||||
if command_part.lower() in all_commands:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
@@ -298,10 +298,10 @@ class PlusCommand(ABC):
|
||||
if "." in cls.command_name:
|
||||
logger.error(f"命令名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"命令名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
|
||||
|
||||
# 生成正则表达式模式来匹配命令
|
||||
command_pattern = cls._generate_command_pattern()
|
||||
|
||||
|
||||
return CommandInfo(
|
||||
name=cls.command_name,
|
||||
component_type=ComponentType.COMMAND,
|
||||
@@ -320,7 +320,7 @@ class PlusCommand(ABC):
|
||||
if "." in cls.command_name:
|
||||
logger.error(f"命令名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"命令名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
|
||||
|
||||
return PlusCommandInfo(
|
||||
name=cls.command_name,
|
||||
component_type=ComponentType.PLUS_COMMAND,
|
||||
@@ -334,38 +334,38 @@ class PlusCommand(ABC):
|
||||
@classmethod
|
||||
def _generate_command_pattern(cls) -> str:
|
||||
"""生成命令匹配的正则表达式
|
||||
|
||||
|
||||
Returns:
|
||||
str: 正则表达式字符串
|
||||
"""
|
||||
# 获取所有可能的命令名(主命令名 + 别名)
|
||||
all_commands = [cls.command_name] + getattr(cls, 'command_aliases', [])
|
||||
|
||||
all_commands = [cls.command_name] + getattr(cls, "command_aliases", [])
|
||||
|
||||
# 转义特殊字符并创建选择组
|
||||
escaped_commands = [re.escape(cmd) for cmd in all_commands]
|
||||
commands_pattern = "|".join(escaped_commands)
|
||||
|
||||
|
||||
# 获取默认前缀列表(这里先用硬编码,后续可以优化为动态获取)
|
||||
default_prefixes = ["/", "!", ".", "#"]
|
||||
escaped_prefixes = [re.escape(prefix) for prefix in default_prefixes]
|
||||
prefixes_pattern = "|".join(escaped_prefixes)
|
||||
|
||||
|
||||
# 生成完整的正则表达式
|
||||
# 匹配: [前缀][命令名][可选空白][任意参数]
|
||||
pattern = f"^(?P<prefix>{prefixes_pattern})(?P<command>{commands_pattern})(?P<args>\\s.*)?$"
|
||||
|
||||
|
||||
return pattern
|
||||
|
||||
|
||||
class PlusCommandAdapter(BaseCommand):
|
||||
"""PlusCommand适配器
|
||||
|
||||
|
||||
将PlusCommand适配到现有的插件系统,继承BaseCommand
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, plus_command_class, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
"""初始化适配器
|
||||
|
||||
|
||||
Args:
|
||||
plus_command_class: PlusCommand子类
|
||||
message: 消息对象
|
||||
@@ -378,27 +378,27 @@ class PlusCommandAdapter(BaseCommand):
|
||||
self.chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
|
||||
self.priority = getattr(plus_command_class, "priority", 0)
|
||||
self.intercept_message = getattr(plus_command_class, "intercept_message", False)
|
||||
|
||||
|
||||
# 调用父类初始化
|
||||
super().__init__(message, plugin_config)
|
||||
|
||||
|
||||
# 创建PlusCommand实例
|
||||
self.plus_command = plus_command_class(message, plugin_config)
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
"""执行命令
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str], bool]: 执行结果
|
||||
"""
|
||||
# 检查命令是否匹配
|
||||
if not self.plus_command.is_command_match():
|
||||
return False, "命令不匹配", False
|
||||
|
||||
|
||||
# 检查聊天类型权限
|
||||
if not self.plus_command.is_chat_type_allowed():
|
||||
return False, "不支持当前聊天类型", self.intercept_message
|
||||
|
||||
|
||||
# 执行命令
|
||||
try:
|
||||
return await self.plus_command.execute(self.plus_command.args)
|
||||
@@ -409,49 +409,50 @@ class PlusCommandAdapter(BaseCommand):
|
||||
|
||||
def create_plus_command_adapter(plus_command_class):
|
||||
"""创建PlusCommand适配器的工厂函数
|
||||
|
||||
|
||||
Args:
|
||||
plus_command_class: PlusCommand子类
|
||||
|
||||
|
||||
Returns:
|
||||
适配器类
|
||||
"""
|
||||
|
||||
class AdapterClass(BaseCommand):
|
||||
command_name = plus_command_class.command_name
|
||||
command_description = plus_command_class.command_description
|
||||
command_pattern = plus_command_class._generate_command_pattern()
|
||||
chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
|
||||
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
super().__init__(message, plugin_config)
|
||||
self.plus_command = plus_command_class(message, plugin_config)
|
||||
self.priority = getattr(plus_command_class, "priority", 0)
|
||||
self.intercept_message = getattr(plus_command_class, "intercept_message", False)
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
"""执行命令"""
|
||||
# 从BaseCommand的正则匹配结果中提取参数
|
||||
args_text = ""
|
||||
if hasattr(self, 'matched_groups') and self.matched_groups:
|
||||
if hasattr(self, "matched_groups") and self.matched_groups:
|
||||
# 从正则匹配组中获取参数部分
|
||||
args_match = self.matched_groups.get('args', '')
|
||||
args_match = self.matched_groups.get("args", "")
|
||||
if args_match:
|
||||
args_text = args_match.strip()
|
||||
|
||||
|
||||
# 创建CommandArgs对象
|
||||
command_args = CommandArgs(args_text)
|
||||
|
||||
|
||||
# 检查聊天类型权限
|
||||
if not self.plus_command.is_chat_type_allowed():
|
||||
return False, "不支持当前聊天类型", self.intercept_message
|
||||
|
||||
|
||||
# 执行命令,传递正确解析的参数
|
||||
try:
|
||||
return await self.plus_command.execute(command_args)
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {e}", exc_info=True)
|
||||
return False, f"命令执行出错: {str(e)}", self.intercept_message
|
||||
|
||||
|
||||
return AdapterClass
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user