re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -5,33 +5,49 @@ MaiBot 插件系统
"""
# 导出主要的公共接口
from .apis import (
chat_api,
component_manage_api,
config_api,
database_api,
emoji_api,
generator_api,
get_logger,
llm_api,
message_api,
person_api,
plugin_manage_api,
register_plugin,
send_api,
tool_api,
)
from .base import (
BasePlugin,
ActionActivationType,
ActionInfo,
BaseAction,
BaseCommand,
BaseTool,
ConfigField,
ComponentType,
ActionActivationType,
ChatMode,
ComponentInfo,
ActionInfo,
CommandInfo,
PlusCommandInfo,
PluginInfo,
ToolInfo,
PythonDependency,
BaseEventHandler,
BasePlugin,
BaseTool,
ChatMode,
ChatType,
CommandArgs,
CommandInfo,
ComponentInfo,
ComponentType,
ConfigField,
EventHandlerInfo,
EventType,
MaiMessages,
ToolParamType,
PluginInfo,
# 新增的增强命令系统
PlusCommand,
CommandArgs,
PlusCommandAdapter,
PlusCommandInfo,
PythonDependency,
ToolInfo,
ToolParamType,
create_plus_command_adapter,
ChatType,
)
# 导入工具模块
@@ -41,28 +57,10 @@ from .utils import (
# validate_plugin_manifest,
# generate_plugin_manifest,
)
from .utils.dependency_config import configure_dependency_settings, get_dependency_config
# 导入依赖管理模块
from .utils.dependency_manager import get_dependency_manager, configure_dependency_manager
from .utils.dependency_config import get_dependency_config, configure_dependency_settings
from .apis import (
chat_api,
tool_api,
component_manage_api,
config_api,
database_api,
emoji_api,
generator_api,
llm_api,
message_api,
person_api,
plugin_manage_api,
send_api,
register_plugin,
get_logger,
)
from .utils.dependency_manager import configure_dependency_manager, get_dependency_manager
__version__ = "2.0.0"

View File

@@ -14,14 +14,15 @@ from src.plugin_system.apis import (
generator_api,
llm_api,
message_api,
permission_api,
person_api,
plugin_manage_api,
schedule_api,
send_api,
tool_api,
permission_api,
schedule_api,
)
from src.plugin_system.apis.chat_api import ChatManager as context_api
from .logging_api import get_logger
from .plugin_register_api import register_plugin
@@ -30,18 +31,18 @@ __all__ = [
"chat_api",
"component_manage_api",
"config_api",
"context_api",
"database_api",
"emoji_api",
"generator_api",
"get_logger",
"llm_api",
"message_api",
"permission_api",
"person_api",
"plugin_manage_api",
"send_api",
"get_logger",
"register_plugin",
"tool_api",
"permission_api",
"context_api",
"schedule_api",
"send_api",
"tool_api",
]

View File

@@ -12,11 +12,11 @@
streams = chat.get_all_group_streams()
"""
from typing import List, Dict, Any, Optional
from enum import Enum
from typing import Any
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.common.logger import get_logger
logger = get_logger("chat_api")
@@ -31,7 +31,7 @@ class ChatManager:
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
# sourcery skip: for-append-to-extend
"""获取所有聊天流
@@ -57,7 +57,7 @@ class ChatManager:
return streams
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
# sourcery skip: for-append-to-extend
"""获取所有群聊聊天流
@@ -80,7 +80,7 @@ class ChatManager:
return streams
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
# sourcery skip: for-append-to-extend
"""获取所有私聊聊天流
@@ -107,8 +107,8 @@ class ChatManager:
@staticmethod
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
group_id: str, platform: str | None | SpecialTypes = "qq"
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast
"""根据群ID获取聊天流
Args:
@@ -144,8 +144,8 @@ class ChatManager:
@staticmethod
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
user_id: str, platform: str | None | SpecialTypes = "qq"
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast
"""根据用户ID获取私聊流
Args:
@@ -203,7 +203,7 @@ class ChatManager:
return "unknown"
@staticmethod
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]:
"""获取聊天流详细信息
Args:
@@ -222,7 +222,7 @@ class ChatManager:
raise TypeError("chat_stream 必须是 ChatStream 类型")
try:
info: Dict[str, Any] = {
info: dict[str, Any] = {
"stream_id": chat_stream.stream_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
@@ -250,7 +250,7 @@ class ChatManager:
return {}
@staticmethod
def get_streams_summary() -> Dict[str, int]:
def get_streams_summary() -> dict[str, int]:
"""获取聊天流统计摘要
Returns:
@@ -285,27 +285,27 @@ class ChatManager:
# =============================================================================
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
"""获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform)
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
"""获取群聊聊天流的便捷函数"""
return ChatManager.get_group_streams(platform)
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
"""获取私聊聊天流的便捷函数"""
return ChatManager.get_private_streams(platform)
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None:
"""根据群ID获取聊天流的便捷函数"""
return ChatManager.get_group_stream_by_group_id(group_id, platform)
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None:
"""根据用户ID获取私聊流的便捷函数"""
return ChatManager.get_private_stream_by_user_id(user_id, platform)
@@ -315,11 +315,11 @@ def get_stream_type(chat_stream: ChatStream) -> str:
return ChatManager.get_stream_type(chat_stream)
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]:
"""获取聊天流信息的便捷函数"""
return ChatManager.get_stream_info(chat_stream)
def get_streams_summary() -> Dict[str, int]:
def get_streams_summary() -> dict[str, int]:
"""获取聊天流统计摘要的便捷函数"""
return ChatManager.get_streams_summary()

View File

@@ -1,16 +1,15 @@
from typing import Optional, Union, Dict
from src.plugin_system.base.component_types import (
CommandInfo,
ActionInfo,
CommandInfo,
ComponentType,
EventHandlerInfo,
PluginInfo,
ComponentType,
ToolInfo,
)
# === 插件信息查询 ===
def get_all_plugin_info() -> Dict[str, PluginInfo]:
def get_all_plugin_info() -> dict[str, PluginInfo]:
"""
获取所有插件的信息。
@@ -22,7 +21,7 @@ def get_all_plugin_info() -> Dict[str, PluginInfo]:
return component_registry.get_all_plugins()
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
def get_plugin_info(plugin_name: str) -> PluginInfo | None:
"""
获取指定插件的信息。
@@ -40,7 +39,7 @@ def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
# === 组件查询方法 ===
def get_component_info(
component_name: str, component_type: ComponentType
) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
) -> CommandInfo | ActionInfo | EventHandlerInfo | None:
"""
获取指定组件的信息。
@@ -57,7 +56,7 @@ def get_component_info(
def get_components_info_by_type(
component_type: ComponentType,
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
) -> dict[str, CommandInfo | ActionInfo | EventHandlerInfo]:
"""
获取指定类型的所有组件信息。
@@ -74,7 +73,7 @@ def get_components_info_by_type(
def get_enabled_components_info_by_type(
component_type: ComponentType,
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
) -> dict[str, CommandInfo | ActionInfo | EventHandlerInfo]:
"""
获取指定类型的所有启用的组件信息。
@@ -90,7 +89,7 @@ def get_enabled_components_info_by_type(
# === Action 查询方法 ===
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
def get_registered_action_info(action_name: str) -> ActionInfo | None:
"""
获取指定 Action 的注册信息。
@@ -105,7 +104,7 @@ def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
return component_registry.get_registered_action_info(action_name)
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
def get_registered_command_info(command_name: str) -> CommandInfo | None:
"""
获取指定 Command 的注册信息。
@@ -120,7 +119,7 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
return component_registry.get_registered_command_info(command_name)
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
def get_registered_tool_info(tool_name: str) -> ToolInfo | None:
"""
获取指定 Tool 的注册信息。
@@ -138,7 +137,7 @@ def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
# === EventHandler 特定查询方法 ===
def get_registered_event_handler_info(
event_handler_name: str,
) -> Optional[EventHandlerInfo]:
) -> EventHandlerInfo | None:
"""
获取指定 EventHandler 的注册信息。

View File

@@ -8,6 +8,7 @@
"""
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config

View File

@@ -3,20 +3,20 @@
"""
import time
from typing import Dict, Any, Optional, List
from typing import Any
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
)
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
build_readable_messages_with_id,
)
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
logger = get_logger("cross_context_api")
def get_context_groups(chat_id: str) -> Optional[List[List[str]]]:
def get_context_groups(chat_id: str) -> list[list[str]] | None:
"""
获取当前聊天所在的共享组的其他聊天ID
"""
@@ -41,7 +41,7 @@ def get_context_groups(chat_id: str) -> Optional[List[List[str]]]:
return None
async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: List[List[str]]) -> str:
async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: list[list[str]]) -> str:
"""
构建跨群聊/私聊上下文 (Normal模式)
"""
@@ -74,8 +74,8 @@ async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos:
async def build_cross_context_s4u(
chat_stream: ChatStream,
other_chat_infos: List[List[str]],
target_user_info: Optional[Dict[str, Any]],
other_chat_infos: list[list[str]],
target_user_info: dict[str, Any] | None,
) -> str:
"""
构建跨群聊/私聊上下文 (S4U模式)

View File

@@ -9,7 +9,7 @@
注意此模块现在使用SQLAlchemy实现提供更好的连接管理和错误处理
"""
from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get, store_action_info, MODEL_MAPPING
from src.common.database.sqlalchemy_database_api import MODEL_MAPPING, db_get, db_query, db_save, store_action_info
# 保持向后兼容性
__all__ = ["db_query", "db_save", "db_get", "store_action_info", "MODEL_MAPPING"]
__all__ = ["MODEL_MAPPING", "db_get", "db_query", "db_save", "store_action_info"]

View File

@@ -10,10 +10,9 @@
import random
from typing import Optional, Tuple, List
from src.common.logger import get_logger
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.utils.utils_image import image_path_to_base64
from src.common.logger import get_logger
logger = get_logger("emoji_api")
@@ -23,7 +22,7 @@ logger = get_logger("emoji_api")
# =============================================================================
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
async def get_by_description(description: str) -> tuple[str, str, str] | None:
"""根据描述选择表情包
Args:
@@ -65,7 +64,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
return None
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
async def get_random(count: int | None = 1) -> list[tuple[str, str, str]]:
"""随机获取指定数量的表情包
Args:
@@ -137,7 +136,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
return []
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
async def get_by_emotion(emotion: str) -> tuple[str, str, str] | None:
"""根据情感标签获取表情包
Args:
@@ -227,7 +226,7 @@ def get_info():
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
def get_emotions() -> List[str]:
def get_emotions() -> list[str]:
"""获取所有可用的情感标签
Returns:
@@ -247,7 +246,7 @@ def get_emotions() -> List[str]:
return []
def get_descriptions() -> List[str]:
def get_descriptions() -> list[str]:
"""获取所有表情包描述
Returns:

View File

@@ -9,13 +9,15 @@
"""
import traceback
from typing import Tuple, Any, Dict, List, Optional
from typing import Any
from rich.traceback import install
from src.common.logger import get_logger
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils import process_llm_response
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.replyer.replyer_manager import replyer_manager
from src.chat.utils.utils import process_llm_response
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ActionInfo
install(extra_lines=3)
@@ -30,10 +32,10 @@ logger = get_logger("generator_api")
def get_replyer(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
chat_stream: ChatStream | None = None,
chat_id: str | None = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
) -> DefaultReplyer | None:
"""获取回复器对象
优先使用chat_stream如果没有则使用chat_id直接查找。
@@ -71,13 +73,13 @@ def get_replyer(
async def generate_reply(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None,
chat_stream: ChatStream | None = None,
chat_id: str | None = None,
action_data: dict[str, Any] | None = None,
reply_to: str = "",
reply_message: Optional[Dict[str, Any]] = None,
reply_message: dict[str, Any] | None = None,
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
available_actions: dict[str, ActionInfo] | None = None,
enable_tool: bool = False,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
@@ -85,7 +87,7 @@ async def generate_reply(
request_type: str = "generator_api",
from_plugin: bool = True,
read_mark: float = 0.0,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
) -> tuple[bool, list[tuple[str, Any]], str | None]:
"""生成回复
Args:
@@ -168,9 +170,9 @@ async def generate_reply(
async def rewrite_reply(
chat_stream: Optional[ChatStream] = None,
reply_data: Optional[Dict[str, Any]] = None,
chat_id: Optional[str] = None,
chat_stream: ChatStream | None = None,
reply_data: dict[str, Any] | None = None,
chat_id: str | None = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
raw_reply: str = "",
@@ -178,7 +180,7 @@ async def rewrite_reply(
reply_to: str = "",
return_prompt: bool = False,
request_type: str = "generator_api",
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
) -> tuple[bool, list[tuple[str, Any]], str | None]:
"""重写回复
Args:
@@ -237,7 +239,7 @@ async def rewrite_reply(
return False, [], None
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> list[tuple[str, Any]]:
"""将文本处理为更拟人化的文本
Args:
@@ -266,11 +268,11 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
async def generate_response_custom(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
chat_stream: ChatStream | None = None,
chat_id: str | None = None,
request_type: str = "generator_api",
prompt: str = "",
) -> Optional[str]:
) -> str | None:
"""
使用自定义提示生成回复

View File

@@ -7,12 +7,13 @@
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
"""
from typing import Tuple, Dict, List, Any, Optional
from typing import Any
from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.config.config import model_config
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.config.api_ada_configs import TaskConfig
logger = get_logger("llm_api")
@@ -21,7 +22,7 @@ logger = get_logger("llm_api")
# =============================================================================
def get_available_models() -> Dict[str, TaskConfig]:
def get_available_models() -> dict[str, TaskConfig]:
"""获取所有可用的模型配置
Returns:
@@ -31,7 +32,7 @@ def get_available_models() -> Dict[str, TaskConfig]:
# 自动获取所有属性并转换为字典形式
models = model_config.model_task_config
attrs = dir(models)
rets: Dict[str, TaskConfig] = {}
rets: dict[str, TaskConfig] = {}
for attr in attrs:
if not attr.startswith("__"):
try:
@@ -52,9 +53,9 @@ async def generate_with_model(
prompt: str,
model_config: TaskConfig,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str]:
temperature: float | None = None,
max_tokens: int | None = None,
) -> tuple[bool, str, str, str]:
"""使用指定模型生成内容
Args:
@@ -78,7 +79,7 @@ async def generate_with_model(
return True, response, reasoning_content, model_name
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
error_msg = f"生成内容时出错: {e!s}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", ""
@@ -86,11 +87,11 @@ async def generate_with_model(
async def generate_with_model_with_tools(
prompt: str,
model_config: TaskConfig,
tool_options: List[Dict[str, Any]] | None = None,
tool_options: list[dict[str, Any]] | None = None,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
temperature: float | None = None,
max_tokens: int | None = None,
) -> tuple[bool, str, str, str, list[ToolCall] | None]:
"""使用指定模型和工具生成内容
Args:
@@ -117,6 +118,6 @@ async def generate_with_model_with_tools(
return True, response, reasoning_content, model_name, tool_call
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
error_msg = f"生成内容时出错: {e!s}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", "", None

View File

@@ -8,26 +8,26 @@
readable_text = message_api.build_readable_messages(messages)
"""
from typing import List, Dict, Any, Tuple, Optional
from src.config.config import global_config
import time
from typing import Any
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_users,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
build_readable_messages,
build_readable_messages_with_list,
get_person_id_list,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
)
from src.config.config import global_config
# =============================================================================
# 消息查询API函数
@@ -36,7 +36,7 @@ from src.chat.utils.chat_message_builder import (
async def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
获取指定时间范围内的消息
@@ -70,7 +70,7 @@ async def get_messages_by_time_in_chat(
limit_mode: str = "latest",
filter_mai: bool = False,
filter_command: bool = False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
获取指定聊天中指定时间范围内的消息
@@ -111,7 +111,7 @@ async def get_messages_by_time_in_chat_inclusive(
limit_mode: str = "latest",
filter_mai: bool = False,
filter_command: bool = False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
获取指定聊天中指定时间范围内的消息(包含边界)
@@ -152,10 +152,10 @@ async def get_messages_by_time_in_chat_for_users(
chat_id: str,
start_time: float,
end_time: float,
person_ids: List[str],
person_ids: list[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
获取指定聊天中指定用户在指定时间范围内的消息
@@ -186,7 +186,7 @@ async def get_messages_by_time_in_chat_for_users(
async def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
随机选择一个聊天,返回该聊天在指定时间范围内的消息
@@ -213,8 +213,8 @@ async def get_random_chat_messages(
async def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
start_time: float, end_time: float, person_ids: list[str], limit: int = 0, limit_mode: str = "latest"
) -> list[dict[str, Any]]:
"""
获取指定用户在所有聊天中指定时间范围内的消息
@@ -238,7 +238,7 @@ async def get_messages_by_time_for_users(
return await get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]:
async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> list[dict[str, Any]]:
"""
获取指定时间戳之前的消息
@@ -294,8 +294,8 @@ async def get_messages_before_time_in_chat(
async def get_messages_before_time_for_users(
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[Dict[str, Any]]:
timestamp: float, person_ids: list[str], limit: int = 0
) -> list[dict[str, Any]]:
"""
获取指定用户在指定时间戳之前的消息
@@ -319,7 +319,7 @@ async def get_messages_before_time_for_users(
async def get_recent_messages(
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
获取指定聊天中最近一段时间的消息
@@ -358,7 +358,7 @@ async def get_recent_messages(
# =============================================================================
async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: float | None = None) -> int:
"""
计算指定聊天中从开始时间到结束时间的新消息数量
@@ -382,7 +382,7 @@ async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Op
return await num_new_messages_since(chat_id, start_time, end_time)
async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: list[str]) -> int:
"""
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
@@ -413,7 +413,7 @@ async def count_new_messages_for_users(chat_id: str, start_time: float, end_time
async def build_readable_messages_to_str(
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
@@ -442,12 +442,12 @@ async def build_readable_messages_to_str(
async def build_readable_messages_with_details(
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
) -> tuple[str, list[tuple[float, str, str]]]:
"""
将消息列表构建成可读的字符串,并返回详细信息
@@ -464,7 +464,7 @@ async def build_readable_messages_with_details(
return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate)
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
async def get_person_ids_from_messages(messages: list[dict[str, Any]]) -> list[str]:
"""
从消息列表中提取不重复的用户ID列表
@@ -482,7 +482,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
# =============================================================================
async def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
async def filter_mai_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
从消息列表中移除麦麦的消息
Args:

View File

@@ -1,9 +1,9 @@
"""纯异步权限API定义。所有外部调用方必须使用 await。"""
from typing import Optional, List, Dict, Any
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from abc import ABC, abstractmethod
from typing import Any
from src.common.logger import get_logger
@@ -48,18 +48,18 @@ class IPermissionManager(ABC):
async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: ...
@abstractmethod
async def get_user_permissions(self, user: UserInfo) -> List[str]: ...
async def get_user_permissions(self, user: UserInfo) -> list[str]: ...
@abstractmethod
async def get_all_permission_nodes(self) -> List[PermissionNode]: ...
async def get_all_permission_nodes(self) -> list[PermissionNode]: ...
@abstractmethod
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: ...
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]: ...
class PermissionAPI:
def __init__(self):
self._permission_manager: Optional[IPermissionManager] = None
self._permission_manager: IPermissionManager | None = None
# 需要保留的前缀(视为绝对节点名,不再自动加 plugins.<plugin>. 前缀)
self.RESERVED_PREFIXES: tuple[str, ...] = "system."
# 系统节点列表 (name, description, default_granted)
@@ -147,11 +147,11 @@ class PermissionAPI:
self._ensure_manager()
return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node)
async def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
async def get_user_permissions(self, platform: str, user_id: str) -> list[str]:
self._ensure_manager()
return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id))
async def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
async def get_all_permission_nodes(self) -> list[dict[str, Any]]:
self._ensure_manager()
nodes = await self._permission_manager.get_all_permission_nodes()
return [
@@ -164,7 +164,7 @@ class PermissionAPI:
for n in nodes
]
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[dict[str, Any]]:
self._ensure_manager()
nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name)
return [

View File

@@ -7,9 +7,10 @@
value = await person_api.get_person_value(person_id, "nickname")
"""
from typing import Any, Optional
from typing import Any
from src.common.logger import get_logger
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
logger = get_logger("person_api")
@@ -63,7 +64,7 @@ async def get_person_value(person_id: str, field_name: str, default: Any = None)
return default
async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict:
async def get_person_values(person_id: str, field_names: list, default_dict: dict | None = None) -> dict:
"""批量获取用户信息字段值
Args:

View File

@@ -1,7 +1,4 @@
from typing import Tuple, List
def list_loaded_plugins() -> List[str]:
def list_loaded_plugins() -> list[str]:
"""
列出所有当前加载的插件。
@@ -13,7 +10,7 @@ def list_loaded_plugins() -> List[str]:
return plugin_manager.list_loaded_plugins()
def list_registered_plugins() -> List[str]:
def list_registered_plugins() -> list[str]:
"""
列出所有已注册的插件。
@@ -80,7 +77,7 @@ async def reload_plugin(plugin_name: str) -> bool:
return await plugin_manager.reload_registered_plugin(plugin_name)
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
def load_plugin(plugin_name: str) -> tuple[bool, int]:
"""
加载指定的插件。
@@ -109,7 +106,7 @@ def add_plugin_directory(plugin_directory: str) -> bool:
return plugin_manager.add_plugin_directory(plugin_directory)
def rescan_plugin_directory() -> Tuple[int, int]:
def rescan_plugin_directory() -> tuple[int, int]:
"""
重新扫描插件目录,加载新插件。
Returns:

View File

@@ -6,8 +6,8 @@ logger = get_logger("plugin_manager") # 复用plugin_manager名称
def register_plugin(cls):
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.core.plugin_manager import plugin_manager
"""插件注册装饰器

View File

@@ -30,7 +30,7 @@
"""
from datetime import datetime
from typing import List, Dict, Any, Optional
from typing import Any
from src.common.database.sqlalchemy_models import MonthlyPlan
from src.common.logger import get_logger
@@ -44,7 +44,7 @@ class ScheduleAPI:
"""日程表与月度计划API - 负责日程和计划信息的查询与管理"""
@staticmethod
async def get_today_schedule() -> Optional[List[Dict[str, Any]]]:
async def get_today_schedule() -> list[dict[str, Any]] | None:
"""(异步) 获取今天的日程安排
Returns:
@@ -58,7 +58,7 @@ class ScheduleAPI:
return None
@staticmethod
async def get_current_activity() -> Optional[str]:
async def get_current_activity() -> str | None:
"""(异步) 获取当前正在进行的活动
Returns:
@@ -87,7 +87,7 @@ class ScheduleAPI:
return False
@staticmethod
async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]:
async def get_monthly_plans(target_month: str | None = None) -> list[MonthlyPlan]:
"""(异步) 获取指定月份的有效月度计划
Args:
@@ -106,7 +106,7 @@ class ScheduleAPI:
return []
@staticmethod
async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool:
async def ensure_monthly_plans(target_month: str | None = None) -> bool:
"""(异步) 确保指定月份存在月度计划,如果不存在则触发生成
Args:
@@ -125,7 +125,7 @@ class ScheduleAPI:
return False
@staticmethod
async def archive_monthly_plans(target_month: Optional[str] = None) -> bool:
async def archive_monthly_plans(target_month: str | None = None) -> bool:
"""(异步) 归档指定月份的月度计划
Args:
@@ -150,12 +150,12 @@ class ScheduleAPI:
# =============================================================================
async def get_today_schedule() -> Optional[List[Dict[str, Any]]]:
async def get_today_schedule() -> list[dict[str, Any]] | None:
"""(异步) 获取今天的日程安排的便捷函数"""
return await ScheduleAPI.get_today_schedule()
async def get_current_activity() -> Optional[str]:
async def get_current_activity() -> str | None:
"""(异步) 获取当前正在进行的活动的便捷函数"""
return await ScheduleAPI.get_current_activity()
@@ -165,16 +165,16 @@ async def regenerate_schedule() -> bool:
return await ScheduleAPI.regenerate_schedule()
async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]:
async def get_monthly_plans(target_month: str | None = None) -> list[MonthlyPlan]:
"""(异步) 获取指定月份的有效月度计划的便捷函数"""
return await ScheduleAPI.get_monthly_plans(target_month)
async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool:
async def ensure_monthly_plans(target_month: str | None = None) -> bool:
"""(异步) 确保指定月份存在月度计划的便捷函数"""
return await ScheduleAPI.ensure_monthly_plans(target_month)
async def archive_monthly_plans(target_month: Optional[str] = None) -> bool:
async def archive_monthly_plans(target_month: str | None = None) -> bool:
"""(异步) 归档指定月份的月度计划的便捷函数"""
return await ScheduleAPI.archive_monthly_plans(target_month)

View File

@@ -28,29 +28,28 @@
"""
import traceback
import time
import asyncio
from typing import Optional, Union, Dict, Any
from src.common.logger import get_logger
import time
import traceback
from typing import Any
from maim_message import Seg, UserInfo
# 导入依赖
from src.chat.message_receive.chat_stream import get_chat_manager
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageSending
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending, MessageRecv
from maim_message import Seg
from src.common.logger import get_logger
from src.config.config import global_config
# 日志记录器
logger = get_logger("send_api")
# 适配器命令响应等待池
_adapter_response_pool: Dict[str, asyncio.Future] = {}
_adapter_response_pool: dict[str, asyncio.Future] = {}
def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[MessageRecv]:
def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None:
"""查找要回复的消息
Args:
@@ -134,13 +133,13 @@ async def wait_adapter_response(request_id: str, timeout: float = 30.0) -> dict:
async def _send_to_target(
message_type: str,
content: Union[str, dict],
content: str | dict,
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_to: str = "",
set_reply: bool = False,
reply_to_message: Optional[Dict[str, Any]] = None,
reply_to_message: dict[str, Any] | None = None,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
@@ -247,7 +246,7 @@ async def text_to_stream(
stream_id: str,
typing: bool = False,
reply_to: str = "",
reply_to_message: Optional[Dict[str, Any]] = None,
reply_to_message: dict[str, Any] | None = None,
set_reply: bool = True,
storage_message: bool = True,
) -> bool:
@@ -313,7 +312,7 @@ async def image_to_stream(
async def command_to_stream(
command: Union[str, dict],
command: str | dict,
stream_id: str,
storage_message: bool = True,
display_message: str = "",
@@ -341,7 +340,7 @@ async def custom_to_stream(
display_message: str = "",
typing: bool = False,
reply_to: str = "",
reply_to_message: Optional[Dict[str, Any]] = None,
reply_to_message: dict[str, Any] | None = None,
set_reply: bool = True,
storage_message: bool = True,
show_log: bool = True,
@@ -377,8 +376,8 @@ async def custom_to_stream(
async def adapter_command_to_stream(
action: str,
params: dict,
platform: Optional[str] = "qq",
stream_id: Optional[str] = None,
platform: str | None = "qq",
stream_id: str | None = None,
timeout: float = 30.0,
storage_message: bool = False,
) -> dict:
@@ -497,4 +496,4 @@ async def adapter_command_to_stream(
except Exception as e:
logger.error(f"[SendAPI] 发送适配器命令时出错: {e}")
traceback.print_exc()
return {"status": "error", "message": f"发送适配器命令时出错: {str(e)}"}
return {"status": "error", "message": f"发送适配器命令时出错: {e!s}"}

View File

@@ -1,13 +1,11 @@
from typing import Optional, Type
from src.common.logger import get_logger
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
from src.common.logger import get_logger
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
def get_tool_instance(tool_name: str) -> BaseTool | None:
"""获取公开工具实例"""
from src.plugin_system.core import component_registry
@@ -18,7 +16,7 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
else:
plugin_config = None
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
tool_class: type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
return tool_class(plugin_config) if tool_class else None

View File

@@ -4,31 +4,31 @@
提供插件开发的基础类和类型定义
"""
from .base_plugin import BasePlugin
from .base_action import BaseAction
from .base_tool import BaseTool
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .base_plugin import BasePlugin
from .base_tool import BaseTool
from .command_args import CommandArgs
from .component_types import (
ComponentType,
ActionActivationType,
ActionInfo,
ChatMode,
ChatType,
ComponentInfo,
ActionInfo,
CommandInfo,
PlusCommandInfo,
ToolInfo,
PluginInfo,
PythonDependency,
ComponentInfo,
ComponentType,
EventHandlerInfo,
EventType,
MaiMessages,
PluginInfo,
PlusCommandInfo,
PythonDependency,
ToolInfo,
ToolParamType,
)
from .config_types import ConfigField
from .plus_command import PlusCommand, PlusCommandAdapter, create_plus_command_adapter
from .command_args import CommandArgs
__all__ = [
"BasePlugin",

View File

@@ -1,14 +1,11 @@
import time
import asyncio
import time
from abc import ABC, abstractmethod
from typing import Tuple, Optional, List, Dict
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType, ChatType
from src.plugin_system.apis import send_api, database_api, message_api
from src.common.logger import get_logger
from src.plugin_system.apis import database_api, message_api, send_api
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ChatMode, ChatType, ComponentType
logger = get_logger("base_action")
@@ -39,7 +36,7 @@ class BaseAction(ABC):
"""是否为二步Action。如果为TrueAction将分两步执行第一步选择操作第二步执行具体操作"""
step_one_description: str = ""
"""第一步的描述用于向LLM展示Action的基本功能"""
sub_actions: List[Tuple[str, str, Dict[str, str]]] = []
sub_actions: list[tuple[str, str, dict[str, str]]] = []
"""子Action列表格式为[(子Action名, 子Action描述, 子Action参数)]。仅在二步Action中使用"""
def __init__(
@@ -50,8 +47,8 @@ class BaseAction(ABC):
thinking_id: str,
chat_stream: ChatStream,
log_prefix: str = "",
plugin_config: Optional[dict] = None,
action_message: Optional[dict] = None,
plugin_config: dict | None = None,
action_message: dict | None = None,
**kwargs,
):
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
@@ -109,8 +106,8 @@ class BaseAction(ABC):
# 二步Action相关实例属性
self.is_two_step_action: bool = getattr(self.__class__, "is_two_step_action", False)
self.step_one_description: str = getattr(self.__class__, "step_one_description", "")
self.sub_actions: List[Tuple[str, str, Dict[str, str]]] = getattr(self.__class__, "sub_actions", []).copy()
self._selected_sub_action: Optional[str] = None
self.sub_actions: list[tuple[str, str, dict[str, str]]] = getattr(self.__class__, "sub_actions", []).copy()
self._selected_sub_action: str | None = None
"""当前选择的子Action名称用于二步Action的状态管理"""
# =============================================================================
@@ -200,7 +197,7 @@ class BaseAction(ABC):
"""
return self._validate_chat_type()
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
async def wait_for_new_message(self, timeout: int = 1200) -> tuple[bool, str]:
"""等待新消息或超时
在loop_start_time之后等待新消息如果没有新消息且没有超时就一直等待。
@@ -232,7 +229,7 @@ class BaseAction(ABC):
# 检查新消息
current_time = time.time()
new_message_count = message_api.count_new_messages(
new_message_count = await message_api.count_new_messages(
chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time
)
@@ -258,7 +255,7 @@ class BaseAction(ABC):
return False, ""
except Exception as e:
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
return False, f"等待新消息失败: {str(e)}"
return False, f"等待新消息失败: {e!s}"
async def send_text(self, content: str, reply_to: str = "", typing: bool = False) -> bool:
"""发送文本消息
@@ -359,7 +356,7 @@ class BaseAction(ABC):
)
async def send_command(
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
self, command_name: str, args: dict | None = None, display_message: str = "", storage_message: bool = True
) -> bool:
"""发送命令消息
@@ -400,7 +397,7 @@ class BaseAction(ABC):
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
return False
async def call_action(self, action_name: str, action_data: Optional[dict] = None) -> Tuple[bool, str]:
async def call_action(self, action_name: str, action_data: dict | None = None) -> tuple[bool, str]:
"""
在当前Action中调用另一个Action。
@@ -514,7 +511,7 @@ class BaseAction(ABC):
sub_actions=getattr(cls, "sub_actions", []).copy(),
)
async def handle_step_one(self) -> Tuple[bool, str]:
async def handle_step_one(self) -> tuple[bool, str]:
"""处理二步Action的第一步
Returns:
@@ -546,7 +543,7 @@ class BaseAction(ABC):
# 调用第二步执行
return await self.execute_step_two(selected_action)
async def execute_step_two(self, sub_action_name: str) -> Tuple[bool, str]:
async def execute_step_two(self, sub_action_name: str) -> tuple[bool, str]:
"""执行二步Action的第二步
Args:
@@ -562,7 +559,7 @@ class BaseAction(ABC):
return False, f"二步Action必须实现execute_step_two方法来处理操作: {sub_action_name}"
@abstractmethod
async def execute(self) -> Tuple[bool, str]:
async def execute(self) -> tuple[bool, str]:
"""执行Action的抽象方法子类必须实现
对于二步Action会自动处理第一步逻辑
@@ -577,7 +574,7 @@ class BaseAction(ABC):
# 普通Action由子类实现
pass
async def handle_action(self) -> Tuple[bool, str]:
async def handle_action(self) -> tuple[bool, str]:
"""兼容旧系统的handle_action接口委托给execute方法
为了保持向后兼容性旧系统的代码可能会调用handle_action方法。

View File

@@ -1,9 +1,11 @@
from abc import ABC, abstractmethod
from typing import List, TYPE_CHECKING
from typing import TYPE_CHECKING
from src.common.data_models.message_manager_data_model import StreamContext
from .component_types import ChatType
from src.plugin_system.base.component_types import ChatterInfo, ComponentType
from .component_types import ChatType
if TYPE_CHECKING:
from src.chat.planner_actions.action_manager import ChatterActionManager
@@ -13,7 +15,7 @@ class BaseChatter(ABC):
"""Chatter组件的名称"""
chatter_description: str = ""
"""Chatter组件的描述"""
chat_types: List[ChatType] = [ChatType.PRIVATE, ChatType.GROUP]
chat_types: list[ChatType] = [ChatType.PRIVATE, ChatType.GROUP]
def __init__(self, stream_id: str, action_manager: "ChatterActionManager"):
"""

View File

@@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional
from src.common.logger import get_logger
from src.plugin_system.base.component_types import CommandInfo, ComponentType, ChatType
from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
from src.plugin_system.apis import send_api
from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType
logger = get_logger("base_command")
@@ -29,7 +29,7 @@ class BaseCommand(ABC):
chat_type_allow: ChatType = ChatType.ALL
"""允许的聊天类型,默认为所有类型"""
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
"""初始化Command组件
Args:
@@ -37,7 +37,7 @@ class BaseCommand(ABC):
plugin_config: 插件配置字典
"""
self.message = message
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
self.matched_groups: dict[str, str] = {} # 存储正则表达式匹配的命名组
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
self.log_prefix = "[Command]"
@@ -55,7 +55,7 @@ class BaseCommand(ABC):
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
)
def set_matched_groups(self, groups: Dict[str, str]) -> None:
def set_matched_groups(self, groups: dict[str, str]) -> None:
"""设置正则表达式匹配的命名组
Args:
@@ -93,7 +93,7 @@ class BaseCommand(ABC):
return self._validate_chat_type()
@abstractmethod
async def execute(self) -> Tuple[bool, Optional[str], bool]:
async def execute(self) -> tuple[bool, str | None, bool]:
"""执行Command的抽象方法子类必须实现
Returns:
@@ -175,7 +175,7 @@ class BaseCommand(ABC):
)
async def send_command(
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
self, command_name: str, args: dict | None = None, display_message: str = "", storage_message: bool = True
) -> bool:
"""发送命令消息

View File

@@ -1,5 +1,5 @@
import asyncio
from typing import List, Dict, Any, Optional
from typing import Any
from src.common.logger import get_logger
@@ -25,22 +25,22 @@ class HandlerResult:
class HandlerResultsCollection:
"""HandlerResult集合提供便捷的查询方法"""
def __init__(self, results: List[HandlerResult]):
def __init__(self, results: list[HandlerResult]):
self.results = results
def all_continue_process(self) -> bool:
"""检查是否所有handler的continue_process都为True"""
return all(result.continue_process for result in self.results)
def get_all_results(self) -> List[HandlerResult]:
def get_all_results(self) -> list[HandlerResult]:
"""获取所有HandlerResult"""
return self.results
def get_failed_handlers(self) -> List[HandlerResult]:
def get_failed_handlers(self) -> list[HandlerResult]:
"""获取执行失败的handler结果"""
return [result for result in self.results if not result.success]
def get_stopped_handlers(self) -> List[HandlerResult]:
def get_stopped_handlers(self) -> list[HandlerResult]:
"""获取continue_process为False的handler结果"""
return [result for result in self.results if not result.continue_process]
@@ -57,7 +57,7 @@ class HandlerResultsCollection:
else:
return {result.handler_name: result.message for result in self.results}
def get_handler_result(self, handler_name: str) -> Optional[HandlerResult]:
def get_handler_result(self, handler_name: str) -> HandlerResult | None:
"""获取指定handler的结果"""
for result in self.results:
if result.handler_name == handler_name:
@@ -72,7 +72,7 @@ class HandlerResultsCollection:
"""获取执行失败的handler数量"""
return sum(1 for result in self.results if not result.success)
def get_summary(self) -> Dict[str, Any]:
def get_summary(self) -> dict[str, Any]:
"""获取执行摘要"""
return {
"total_handlers": len(self.results),
@@ -85,13 +85,13 @@ class HandlerResultsCollection:
class BaseEvent:
def __init__(self, name: str, allowed_subscribers: List[str] = None, allowed_triggers: List[str] = None):
def __init__(self, name: str, allowed_subscribers: list[str] = None, allowed_triggers: list[str] = None):
self.name = name
self.enabled = True
self.allowed_subscribers = allowed_subscribers # 记录事件处理器名
self.allowed_triggers = allowed_triggers # 记录插件名
self.subscribers: List["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表
self.subscribers: list["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表
self.event_handle_lock = asyncio.Lock()

View File

@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from typing import Tuple, Optional, List, Union
from src.common.logger import get_logger
from .component_types import EventType, EventHandlerInfo, ComponentType
from .component_types import ComponentType, EventHandlerInfo, EventType
logger = get_logger("base_event_handler")
@@ -21,7 +21,7 @@ class BaseEventHandler(ABC):
"""处理器权重,越大权重越高"""
intercept_message: bool = False
"""是否拦截消息,默认为否"""
init_subscribe: List[Union[EventType, str]] = [EventType.UNKNOWN]
init_subscribe: list[EventType | str] = [EventType.UNKNOWN]
"""初始化时订阅的事件名称"""
plugin_name = None
@@ -44,7 +44,7 @@ class BaseEventHandler(ABC):
self.plugin_config = getattr(self.__class__, "plugin_config", {})
@abstractmethod
async def execute(self, kwargs: dict | None) -> Tuple[bool, bool, Optional[str]]:
async def execute(self, kwargs: dict | None) -> tuple[bool, bool, str | None]:
"""执行事件处理的抽象方法,子类必须实现
Args:
kwargs (dict | None): 事件消息对象当你注册的事件为ON_START和ON_STOP时message为None

View File

@@ -1,13 +1,13 @@
from abc import abstractmethod
from typing import List, Type, Tuple, Union
from .plugin_base import PluginBase
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, PlusCommandInfo, EventHandlerInfo, ToolInfo
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, PlusCommandInfo, ToolInfo
from .base_action import BaseAction
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .base_tool import BaseTool
from .plugin_base import PluginBase
from .plus_command import PlusCommand
logger = get_logger("base_plugin")
@@ -28,14 +28,12 @@ class BasePlugin(PluginBase):
@abstractmethod
def get_plugin_components(
self,
) -> List[
Union[
Tuple[ActionInfo, Type[BaseAction]],
Tuple[CommandInfo, Type[BaseCommand]],
Tuple[PlusCommandInfo, Type[PlusCommand]],
Tuple[EventHandlerInfo, Type[BaseEventHandler]],
Tuple[ToolInfo, Type[BaseTool]],
]
) -> list[
tuple[ActionInfo, type[BaseAction]]
| tuple[CommandInfo, type[BaseCommand]]
| tuple[PlusCommandInfo, type[PlusCommand]]
| tuple[EventHandlerInfo, type[BaseEventHandler]]
| tuple[ToolInfo, type[BaseTool]]
]:
"""获取插件包含的组件列表

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple
from typing import Any
from rich.traceback import install
from src.common.logger import get_logger
@@ -17,7 +18,7 @@ class BaseTool(ABC):
"""工具的名称"""
description: str = ""
"""工具的描述"""
parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = []
parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
"""工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
param_name: 参数名称
param_type: 参数类型
@@ -35,7 +36,7 @@ class BaseTool(ABC):
"""是否为该工具启用缓存"""
cache_ttl: int = 3600
"""缓存的TTL值默认为3600秒1小时"""
semantic_cache_query_key: Optional[str] = None
semantic_cache_query_key: str | None = None
"""用于语义缓存的查询参数键名。如果设置,将使用此参数的值进行语义相似度搜索"""
# 二步工具调用相关属性
@@ -43,10 +44,10 @@ class BaseTool(ABC):
"""是否为二步工具。如果为True工具将分两步调用第一步展示工具信息第二步执行具体操作"""
step_one_description: str = ""
"""第一步的描述用于向LLM展示工具的基本功能"""
sub_tools: List[Tuple[str, str, List[Tuple[str, ToolParamType, str, bool, List[str] | None]]]] = []
sub_tools: list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] = []
"""子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用"""
def __init__(self, plugin_config: Optional[dict] = None):
def __init__(self, plugin_config: dict | None = None):
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
@classmethod
@@ -101,7 +102,7 @@ class BaseTool(ABC):
raise ValueError(f"未找到子工具: {sub_tool_name}")
@classmethod
def get_all_sub_tool_definitions(cls) -> List[dict[str, Any]]:
def get_all_sub_tool_definitions(cls) -> list[dict[str, Any]]:
"""获取所有子工具的定义
Returns:

View File

@@ -3,7 +3,6 @@
提供简单易用的命令参数解析功能
"""
from typing import List, Optional
import shlex
@@ -20,7 +19,7 @@ class CommandArgs:
raw_args: 原始参数字符串
"""
self._raw_args = raw_args.strip()
self._parsed_args: Optional[List[str]] = None
self._parsed_args: list[str] | None = None
def get_raw(self) -> str:
"""获取完整的参数字符串
@@ -30,7 +29,7 @@ class CommandArgs:
"""
return self._raw_args
def get_args(self) -> List[str]:
def get_args(self) -> list[str]:
"""获取解析后的参数列表
将参数按空格分割,支持引号包围的参数

View File

@@ -1,10 +1,11 @@
from enum import Enum
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from maim_message import Seg
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
# 组件类型枚举
@@ -114,7 +115,7 @@ class ComponentInfo:
enabled: bool = True # 是否启用
plugin_name: str = "" # 所属插件名称
is_built_in: bool = False # 是否为内置组件
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
metadata: dict[str, Any] = field(default_factory=dict) # 额外元数据
def __post_init__(self):
if self.metadata is None:
@@ -125,18 +126,18 @@ class ComponentInfo:
class ActionInfo(ComponentInfo):
"""动作组件信息"""
action_parameters: Dict[str, str] = field(
action_parameters: dict[str, str] = field(
default_factory=dict
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
action_require: List[str] = field(default_factory=list) # 动作需求说明
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
action_require: list[str] = field(default_factory=list) # 动作需求说明
associated_types: list[str] = field(default_factory=list) # 关联的消息类型
# 激活类型相关
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS
activation_type: ActionActivationType = ActionActivationType.ALWAYS
random_activation_probability: float = 0.0
llm_judge_prompt: str = ""
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
activation_keywords: list[str] = field(default_factory=list) # 激活关键词列表
keyword_case_sensitive: bool = False
# 模式和并行设置
mode_enable: ChatMode = ChatMode.ALL
@@ -145,7 +146,7 @@ class ActionInfo(ComponentInfo):
# 二步Action相关属性
is_two_step_action: bool = False # 是否为二步Action
step_one_description: str = "" # 第一步的描述
sub_actions: List[Tuple[str, str, Dict[str, str]]] = field(default_factory=list) # 子Action列表
sub_actions: list[tuple[str, str, dict[str, str]]] = field(default_factory=list) # 子Action列表
def __post_init__(self):
super().__post_init__()
@@ -178,7 +179,7 @@ class CommandInfo(ComponentInfo):
class PlusCommandInfo(ComponentInfo):
"""增强命令组件信息"""
command_aliases: List[str] = field(default_factory=list) # 命令别名列表
command_aliases: list[str] = field(default_factory=list) # 命令别名列表
priority: int = 0 # 命令优先级
chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型
intercept_message: bool = False # 是否拦截消息
@@ -194,7 +195,7 @@ class PlusCommandInfo(ComponentInfo):
class ToolInfo(ComponentInfo):
"""工具组件信息"""
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(
tool_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = field(
default_factory=list
) # 工具参数定义
tool_description: str = "" # 工具描述
@@ -248,18 +249,18 @@ class PluginInfo:
author: str = "" # 插件作者
enabled: bool = True # 是否启用
is_built_in: bool = False # 是否为内置插件
components: List[ComponentInfo] = field(default_factory=list) # 包含的组件列表
dependencies: List[str] = field(default_factory=list) # 依赖的其他插件
python_dependencies: List[PythonDependency] = field(default_factory=list) # Python包依赖
components: list[ComponentInfo] = field(default_factory=list) # 包含的组件列表
dependencies: list[str] = field(default_factory=list) # 依赖的其他插件
python_dependencies: list[PythonDependency] = field(default_factory=list) # Python包依赖
config_file: str = "" # 配置文件路径
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
metadata: dict[str, Any] = field(default_factory=dict) # 额外元数据
# 新增manifest相关信息
manifest_data: Dict[str, Any] = field(default_factory=dict) # manifest文件数据
manifest_data: dict[str, Any] = field(default_factory=dict) # manifest文件数据
license: str = "" # 插件许可证
homepage_url: str = "" # 插件主页
repository_url: str = "" # 插件仓库地址
keywords: List[str] = field(default_factory=list) # 插件关键词
categories: List[str] = field(default_factory=list) # 插件分类
keywords: list[str] = field(default_factory=list) # 插件关键词
categories: list[str] = field(default_factory=list) # 插件分类
min_host_version: str = "" # 最低主机版本要求
max_host_version: str = "" # 最高主机版本要求
@@ -279,7 +280,7 @@ class PluginInfo:
if self.categories is None:
self.categories = []
def get_missing_packages(self) -> List[PythonDependency]:
def get_missing_packages(self) -> list[PythonDependency]:
"""检查缺失的Python包"""
missing = []
for dep in self.python_dependencies:
@@ -290,7 +291,7 @@ class PluginInfo:
missing.append(dep)
return missing
def get_pip_requirements(self) -> List[str]:
def get_pip_requirements(self) -> list[str]:
"""获取所有pip安装格式的依赖"""
return [dep.get_pip_requirement() for dep in self.python_dependencies]
@@ -299,16 +300,16 @@ class PluginInfo:
class MaiMessages:
"""MaiM插件消息"""
message_segments: List[Seg] = field(default_factory=list)
message_segments: list[Seg] = field(default_factory=list)
"""消息段列表,支持多段消息"""
message_base_info: Dict[str, Any] = field(default_factory=dict)
message_base_info: dict[str, Any] = field(default_factory=dict)
"""消息基本信息,包含平台,用户信息等数据"""
plain_text: str = ""
"""纯文本消息内容"""
raw_message: Optional[str] = None
raw_message: str | None = None
"""原始消息内容"""
is_group_message: bool = False
@@ -317,28 +318,28 @@ class MaiMessages:
is_private_message: bool = False
"""是否为私聊消息"""
stream_id: Optional[str] = None
stream_id: str | None = None
"""流ID用于标识消息流"""
llm_prompt: Optional[str] = None
llm_prompt: str | None = None
"""LLM提示词"""
llm_response_content: Optional[str] = None
llm_response_content: str | None = None
"""LLM响应内容"""
llm_response_reasoning: Optional[str] = None
llm_response_reasoning: str | None = None
"""LLM响应推理内容"""
llm_response_model: Optional[str] = None
llm_response_model: str | None = None
"""LLM响应模型名称"""
llm_response_tool_call: Optional[List[ToolCall]] = None
llm_response_tool_call: list[ToolCall] | None = None
"""LLM使用的工具调用"""
action_usage: Optional[List[str]] = None
action_usage: list[str] | None = None
"""使用的Action"""
additional_data: Dict[Any, Any] = field(default_factory=dict)
additional_data: dict[Any, Any] = field(default_factory=dict)
"""附加数据,可以存储额外信息"""
def __post_init__(self):

View File

@@ -2,8 +2,8 @@
插件系统配置类型定义
"""
from typing import Any, Optional, List
from dataclasses import dataclass, field
from typing import Any
@dataclass
@@ -13,6 +13,6 @@ class ConfigField:
type: type # 字段类型
default: Any # 默认值
description: str # 字段描述
example: Optional[str] = None # 示例值
example: str | None = None # 示例值
required: bool = False # 是否必需
choices: Optional[List[Any]] = field(default_factory=list) # 可选值列表
choices: list[Any] | None = field(default_factory=list) # 可选值列表

View File

@@ -1,11 +1,12 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Union
import os
import toml
import orjson
import shutil
import datetime
import os
import shutil
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
import orjson
import toml
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
@@ -38,12 +39,12 @@ class PluginBase(ABC):
@property
@abstractmethod
def dependencies(self) -> List[str]:
def dependencies(self) -> list[str]:
return [] # 依赖的其他插件
@property
@abstractmethod
def python_dependencies(self) -> List[Union[str, PythonDependency]]:
def python_dependencies(self) -> list[str | PythonDependency]:
return [] # Python包依赖支持字符串列表或PythonDependency对象列表
@property
@@ -53,15 +54,15 @@ class PluginBase(ABC):
# manifest文件相关
manifest_file_name: str = "_manifest.json" # manifest文件名
manifest_data: Dict[str, Any] = {} # manifest数据
manifest_data: dict[str, Any] = {} # manifest数据
# 配置定义
@property
@abstractmethod
def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]:
def config_schema(self) -> dict[str, dict[str, ConfigField] | str]:
return {}
config_section_descriptions: Dict[str, str] = {}
config_section_descriptions: dict[str, str] = {}
def __init__(self, plugin_dir: str):
"""初始化插件
@@ -69,7 +70,7 @@ class PluginBase(ABC):
Args:
plugin_dir: 插件目录路径,由插件管理器传递
"""
self.config: Dict[str, Any] = {} # 插件配置
self.config: dict[str, Any] = {} # 插件配置
self.plugin_dir = plugin_dir # 插件目录路径
self.log_prefix = f"[Plugin:{self.plugin_name}]"
self._is_enabled = self.enable_plugin # 从插件定义中获取默认启用状态
@@ -144,7 +145,7 @@ class PluginBase(ABC):
raise FileNotFoundError(error_msg)
try:
with open(manifest_path, "r", encoding="utf-8") as f:
with open(manifest_path, encoding="utf-8") as f:
self.manifest_data = orjson.loads(f.read())
logger.debug(f"{self.log_prefix} 成功加载manifest文件: {manifest_path}")
@@ -155,8 +156,8 @@ class PluginBase(ABC):
except orjson.JSONDecodeError as e:
error_msg = f"{self.log_prefix} manifest文件格式错误: {e}"
logger.error(error_msg)
raise ValueError(error_msg) # noqa
except IOError as e:
raise ValueError(error_msg)
except OSError as e:
error_msg = f"{self.log_prefix} 读取manifest文件失败: {e}"
logger.error(error_msg)
raise IOError(error_msg) # noqa
@@ -266,7 +267,7 @@ class PluginBase(ABC):
with open(config_file_path, "w", encoding="utf-8") as f:
f.write(toml_str)
logger.info(f"{self.log_prefix} 已生成默认配置文件: {config_file_path}")
except IOError as e:
except OSError as e:
logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True)
def _backup_config_file(self, config_file_path: str) -> str:
@@ -288,13 +289,13 @@ class PluginBase(ABC):
return ""
def _synchronize_config(
self, schema_config: Dict[str, Any], user_config: Dict[str, Any]
) -> tuple[Dict[str, Any], bool]:
self, schema_config: dict[str, Any], user_config: dict[str, Any]
) -> tuple[dict[str, Any], bool]:
"""递归地将用户配置与 schema 同步,返回同步后的配置和是否发生变化的标志"""
changed = False
# 内部递归函数
def _sync_dicts(schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, Any]:
def _sync_dicts(schema_dict: dict[str, Any], user_dict: dict[str, Any], parent_key: str = "") -> dict[str, Any]:
nonlocal changed
synced_dict = schema_dict.copy()
@@ -326,7 +327,7 @@ class PluginBase(ABC):
final_config = _sync_dicts(schema_config, user_config)
return final_config, changed
def _generate_config_from_schema(self) -> Dict[str, Any]:
def _generate_config_from_schema(self) -> dict[str, Any]:
# sourcery skip: dict-comprehension
"""根据schema生成配置数据结构不写入文件"""
if not self.config_schema:
@@ -348,7 +349,7 @@ class PluginBase(ABC):
return config_data
def _save_config_to_file(self, config_data: Dict[str, Any], config_file_path: str):
def _save_config_to_file(self, config_data: dict[str, Any], config_file_path: str):
"""将配置数据保存为TOML文件包含注释"""
if not self.config_schema:
logger.debug(f"{self.log_prefix} 插件未定义config_schema不生成配置文件")
@@ -410,7 +411,7 @@ class PluginBase(ABC):
with open(config_file_path, "w", encoding="utf-8") as f:
f.write(toml_str)
logger.info(f"{self.log_prefix} 配置文件已保存: {config_file_path}")
except IOError as e:
except OSError as e:
logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True)
def _load_plugin_config(self): # sourcery skip: extract-method
@@ -456,7 +457,7 @@ class PluginBase(ABC):
return
try:
with open(user_config_path, "r", encoding="utf-8") as f:
with open(user_config_path, encoding="utf-8") as f:
user_config = toml.load(f) or {}
except Exception as e:
logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True)
@@ -520,7 +521,7 @@ class PluginBase(ABC):
return current
def _normalize_python_dependencies(self, dependencies: Any) -> List[PythonDependency]:
def _normalize_python_dependencies(self, dependencies: Any) -> list[PythonDependency]:
"""将依赖列表标准化为PythonDependency对象"""
from packaging.requirements import Requirement
@@ -549,7 +550,7 @@ class PluginBase(ABC):
return normalized
def _check_python_dependencies(self, dependencies: List[PythonDependency]) -> bool:
def _check_python_dependencies(self, dependencies: list[PythonDependency]) -> bool:
"""检查Python依赖并尝试自动安装"""
if not dependencies:
logger.info(f"{self.log_prefix} 无Python依赖需要检查")

View File

@@ -3,17 +3,16 @@
提供更简单易用的命令处理方式,无需手写正则表达式
"""
from abc import ABC, abstractmethod
from typing import Tuple, Optional, List
import re
from abc import ABC, abstractmethod
from src.common.logger import get_logger
from src.plugin_system.base.component_types import PlusCommandInfo, ComponentType, ChatType
from src.chat.message_receive.message import MessageRecv
from src.plugin_system.apis import send_api
from src.plugin_system.base.command_args import CommandArgs
from src.plugin_system.base.base_command import BaseCommand
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis import send_api
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.command_args import CommandArgs
from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo
logger = get_logger("plus_command")
@@ -39,7 +38,7 @@ class PlusCommand(ABC):
command_description: str = ""
"""命令描述"""
command_aliases: List[str] = []
command_aliases: list[str] = []
"""命令别名列表,如 ['say', 'repeat']"""
priority: int = 0
@@ -51,7 +50,7 @@ class PlusCommand(ABC):
intercept_message: bool = False
"""是否拦截消息,不进行后续处理"""
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
"""初始化命令组件
Args:
@@ -172,7 +171,7 @@ class PlusCommand(ABC):
return False
@abstractmethod
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
"""执行命令的抽象方法,子类必须实现
Args:
@@ -341,7 +340,7 @@ class PlusCommandAdapter(BaseCommand):
将PlusCommand适配到现有的插件系统继承BaseCommand
"""
def __init__(self, plus_command_class, message: MessageRecv, plugin_config: Optional[dict] = None):
def __init__(self, plus_command_class, message: MessageRecv, plugin_config: dict | None = None):
"""初始化适配器
Args:
@@ -363,7 +362,7 @@ class PlusCommandAdapter(BaseCommand):
# 创建PlusCommand实例
self.plus_command = plus_command_class(message, plugin_config)
async def execute(self) -> Tuple[bool, Optional[str], bool]:
async def execute(self) -> tuple[bool, str | None, bool]:
"""执行命令
Returns:
@@ -382,7 +381,7 @@ class PlusCommandAdapter(BaseCommand):
return await self.plus_command.execute(self.plus_command.args)
except Exception as e:
logger.error(f"执行命令时出错: {e}", exc_info=True)
return False, f"命令执行出错: {str(e)}", self.intercept_message
return False, f"命令执行出错: {e!s}", self.intercept_message
def create_plus_command_adapter(plus_command_class):
@@ -401,13 +400,13 @@ def create_plus_command_adapter(plus_command_class):
command_pattern = plus_command_class._generate_command_pattern()
chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
super().__init__(message, plugin_config)
self.plus_command = plus_command_class(message, plugin_config)
self.priority = getattr(plus_command_class, "priority", 0)
self.intercept_message = getattr(plus_command_class, "intercept_message", False)
async def execute(self) -> Tuple[bool, Optional[str], bool]:
async def execute(self) -> tuple[bool, str | None, bool]:
"""执行命令"""
# 从BaseCommand的正则匹配结果中提取参数
args_text = ""
@@ -429,7 +428,7 @@ def create_plus_command_adapter(plus_command_class):
return await self.plus_command.execute(command_args)
except Exception as e:
logger.error(f"执行命令时出错: {e}", exc_info=True)
return False, f"命令执行出错: {str(e)}", self.intercept_message
return False, f"命令执行出错: {e!s}", self.intercept_message
return AdapterClass

View File

@@ -4,14 +4,14 @@
提供插件的加载、注册和管理功能
"""
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.plugin_manager import plugin_manager
__all__ = [
"plugin_manager",
"component_registry",
"event_manager",
"global_announcement_manager",
"plugin_manager",
]

View File

@@ -1,27 +1,26 @@
from pathlib import Path
import re
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
from pathlib import Path
from re import Pattern
from typing import Any, Optional, Union
from src.common.logger import get_logger
from src.plugin_system.base.component_types import (
ComponentInfo,
ActionInfo,
ToolInfo,
CommandInfo,
PlusCommandInfo,
EventHandlerInfo,
ChatterInfo,
PluginInfo,
ComponentType,
)
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.plus_command import PlusCommand
from src.plugin_system.base.base_chatter import BaseChatter
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import (
ActionInfo,
ChatterInfo,
CommandInfo,
ComponentInfo,
ComponentType,
EventHandlerInfo,
PluginInfo,
PlusCommandInfo,
ToolInfo,
)
from src.plugin_system.base.plus_command import PlusCommand
logger = get_logger("component_registry")
@@ -34,46 +33,46 @@ class ComponentRegistry:
def __init__(self):
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
self._components: Dict[str, "ComponentInfo"] = {}
self._components: dict[str, "ComponentInfo"] = {}
"""组件注册表 命名空间式组件名 -> 组件信息"""
self._components_by_type: Dict["ComponentType", Dict[str, "ComponentInfo"]] = {
self._components_by_type: dict["ComponentType", dict[str, "ComponentInfo"]] = {
types: {} for types in ComponentType
}
"""类型 -> 组件原名称 -> 组件信息"""
self._components_classes: Dict[
str, Type[Union["BaseCommand", "BaseAction", "BaseTool", "BaseEventHandler", "PlusCommand", "BaseChatter"]]
self._components_classes: dict[
str, type["BaseCommand" | "BaseAction" | "BaseTool" | "BaseEventHandler" | "PlusCommand" | "BaseChatter"]
] = {}
"""命名空间式组件名 -> 组件类"""
# 插件注册表
self._plugins: Dict[str, "PluginInfo"] = {}
self._plugins: dict[str, "PluginInfo"] = {}
"""插件名 -> 插件信息"""
# Action特定注册表
self._action_registry: Dict[str, Type["BaseAction"]] = {}
self._action_registry: dict[str, type["BaseAction"]] = {}
"""Action注册表 action名 -> action类"""
self._default_actions: Dict[str, "ActionInfo"] = {}
self._default_actions: dict[str, "ActionInfo"] = {}
"""默认动作集即启用的Action集用于重置ActionManager状态"""
# Command特定注册表
self._command_registry: Dict[str, Type["BaseCommand"]] = {}
self._command_registry: dict[str, type["BaseCommand"]] = {}
"""Command类注册表 command名 -> command类"""
self._command_patterns: Dict[Pattern, str] = {}
self._command_patterns: dict[Pattern, str] = {}
"""编译后的正则 -> command名"""
# 工具特定注册表
self._tool_registry: Dict[str, Type["BaseTool"]] = {} # 工具名 -> 工具类
self._llm_available_tools: Dict[str, Type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
self._tool_registry: dict[str, type["BaseTool"]] = {} # 工具名 -> 工具类
self._llm_available_tools: dict[str, type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
# EventHandler特定注册表
self._event_handler_registry: Dict[str, Type["BaseEventHandler"]] = {}
self._event_handler_registry: dict[str, type["BaseEventHandler"]] = {}
"""event_handler名 -> event_handler类"""
self._enabled_event_handlers: Dict[str, Type["BaseEventHandler"]] = {}
self._enabled_event_handlers: dict[str, type["BaseEventHandler"]] = {}
"""启用的事件处理器 event_handler名 -> event_handler类"""
self._chatter_registry: Dict[str, Type["BaseChatter"]] = {}
self._chatter_registry: dict[str, type["BaseChatter"]] = {}
"""chatter名 -> chatter类"""
self._enabled_chatter_registry: Dict[str, Type["BaseChatter"]] = {}
self._enabled_chatter_registry: dict[str, type["BaseChatter"]] = {}
"""启用的chatter名 -> chatter类"""
logger.info("组件注册中心初始化完成")
@@ -101,7 +100,7 @@ class ComponentRegistry:
def register_component(
self,
component_info: ComponentInfo,
component_class: Type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]],
component_class: type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]],
) -> bool:
"""注册组件
@@ -174,7 +173,7 @@ class ComponentRegistry:
)
return True
def _register_action_component(self, action_info: "ActionInfo", action_class: Type["BaseAction"]) -> bool:
def _register_action_component(self, action_info: "ActionInfo", action_class: type["BaseAction"]) -> bool:
"""注册Action组件到Action特定注册表"""
if not (action_name := action_info.name):
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
@@ -194,7 +193,7 @@ class ComponentRegistry:
return True
def _register_command_component(self, command_info: "CommandInfo", command_class: Type["BaseCommand"]) -> bool:
def _register_command_component(self, command_info: "CommandInfo", command_class: type["BaseCommand"]) -> bool:
"""注册Command组件到Command特定注册表"""
if not (command_name := command_info.name):
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
@@ -221,7 +220,7 @@ class ComponentRegistry:
return True
def _register_plus_command_component(
self, plus_command_info: "PlusCommandInfo", plus_command_class: Type["PlusCommand"]
self, plus_command_info: "PlusCommandInfo", plus_command_class: type["PlusCommand"]
) -> bool:
"""注册PlusCommand组件到特定注册表"""
plus_command_name = plus_command_info.name
@@ -235,7 +234,7 @@ class ComponentRegistry:
# 创建专门的PlusCommand注册表如果还没有
if not hasattr(self, "_plus_command_registry"):
self._plus_command_registry: Dict[str, Type["PlusCommand"]] = {}
self._plus_command_registry: dict[str, type["PlusCommand"]] = {}
plus_command_class.plugin_name = plus_command_info.plugin_name
# 设置插件配置
@@ -245,7 +244,7 @@ class ComponentRegistry:
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
return True
def _register_tool_component(self, tool_info: "ToolInfo", tool_class: Type["BaseTool"]) -> bool:
def _register_tool_component(self, tool_info: "ToolInfo", tool_class: type["BaseTool"]) -> bool:
"""注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name
@@ -261,7 +260,7 @@ class ComponentRegistry:
return True
def _register_event_handler_component(
self, handler_info: "EventHandlerInfo", handler_class: Type["BaseEventHandler"]
self, handler_info: "EventHandlerInfo", handler_class: type["BaseEventHandler"]
) -> bool:
if not (handler_name := handler_info.name):
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
@@ -287,7 +286,7 @@ class ComponentRegistry:
handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
)
def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: Type["BaseChatter"]) -> bool:
def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: type["BaseChatter"]) -> bool:
"""注册Chatter组件到Chatter特定注册表"""
chatter_name = chatter_info.name
@@ -532,7 +531,7 @@ class ComponentRegistry:
self,
component_name: str,
component_type: Optional["ComponentType"] = None,
) -> Optional[Union[Type["BaseCommand"], Type["BaseAction"], Type["BaseEventHandler"], Type["BaseTool"]]]:
) -> type["BaseCommand"] | type["BaseAction"] | type["BaseEventHandler"] | type["BaseTool"] | None:
"""获取组件类,支持自动命名空间解析
Args:
@@ -574,18 +573,18 @@ class ComponentRegistry:
# 4. 都没找到
return None
def get_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]:
def get_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]:
"""获取指定类型的所有组件"""
return self._components_by_type.get(component_type, {}).copy()
def get_enabled_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]:
def get_enabled_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]:
"""获取指定类型的所有启用组件"""
components = self.get_components_by_type(component_type)
return {name: info for name, info in components.items() if info.enabled}
# === Action特定查询方法 ===
def get_action_registry(self) -> Dict[str, Type["BaseAction"]]:
def get_action_registry(self) -> dict[str, type["BaseAction"]]:
"""获取Action注册表"""
return self._action_registry.copy()
@@ -594,13 +593,13 @@ class ComponentRegistry:
info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None
def get_default_actions(self) -> Dict[str, ActionInfo]:
def get_default_actions(self) -> dict[str, ActionInfo]:
"""获取默认动作集"""
return self._default_actions.copy()
# === Command特定查询方法 ===
def get_command_registry(self) -> Dict[str, Type["BaseCommand"]]:
def get_command_registry(self) -> dict[str, type["BaseCommand"]]:
"""获取Command注册表"""
return self._command_registry.copy()
@@ -609,11 +608,11 @@ class ComponentRegistry:
info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None
def get_command_patterns(self) -> Dict[Pattern, str]:
def get_command_patterns(self) -> dict[Pattern, str]:
"""获取Command模式注册表"""
return self._command_patterns.copy()
def find_command_by_text(self, text: str) -> Optional[Tuple[Type["BaseCommand"], dict, "CommandInfo"]]:
def find_command_by_text(self, text: str) -> tuple[type["BaseCommand"], dict, "CommandInfo"] | None:
# sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令
@@ -640,11 +639,11 @@ class ComponentRegistry:
return None
# === Tool 特定查询方法 ===
def get_tool_registry(self) -> Dict[str, Type["BaseTool"]]:
def get_tool_registry(self) -> dict[str, type["BaseTool"]]:
"""获取Tool注册表"""
return self._tool_registry.copy()
def get_llm_available_tools(self) -> Dict[str, Type["BaseTool"]]:
def get_llm_available_tools(self) -> dict[str, type["BaseTool"]]:
"""获取LLM可用的Tool列表"""
return self._llm_available_tools.copy()
@@ -661,10 +660,10 @@ class ComponentRegistry:
return info if isinstance(info, ToolInfo) else None
# === PlusCommand 特定查询方法 ===
def get_plus_command_registry(self) -> Dict[str, Type["PlusCommand"]]:
def get_plus_command_registry(self) -> dict[str, type["PlusCommand"]]:
"""获取PlusCommand注册表"""
if not hasattr(self, "_plus_command_registry"):
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
self._plus_command_registry: dict[str, type[PlusCommand]] = {}
return self._plus_command_registry.copy()
def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]:
@@ -681,7 +680,7 @@ class ComponentRegistry:
# === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> Dict[str, Type["BaseEventHandler"]]:
def get_event_handler_registry(self) -> dict[str, type["BaseEventHandler"]]:
"""获取事件处理器注册表"""
return self._event_handler_registry.copy()
@@ -690,21 +689,21 @@ class ComponentRegistry:
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
return info if isinstance(info, EventHandlerInfo) else None
def get_enabled_event_handlers(self) -> Dict[str, Type["BaseEventHandler"]]:
def get_enabled_event_handlers(self) -> dict[str, type["BaseEventHandler"]]:
"""获取启用的事件处理器"""
return self._enabled_event_handlers.copy()
# === Chatter 特定查询方法 ===
def get_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]:
def get_chatter_registry(self) -> dict[str, type["BaseChatter"]]:
"""获取Chatter注册表"""
if not hasattr(self, "_chatter_registry"):
self._chatter_registry: Dict[str, Type[BaseChatter]] = {}
self._chatter_registry: dict[str, type[BaseChatter]] = {}
return self._chatter_registry.copy()
def get_enabled_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]:
def get_enabled_chatter_registry(self) -> dict[str, type["BaseChatter"]]:
"""获取启用的Chatter注册表"""
if not hasattr(self, "_enabled_chatter_registry"):
self._enabled_chatter_registry: Dict[str, Type[BaseChatter]] = {}
self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {}
return self._enabled_chatter_registry.copy()
def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]:
@@ -718,7 +717,7 @@ class ComponentRegistry:
"""获取插件信息"""
return self._plugins.get(plugin_name)
def get_all_plugins(self) -> Dict[str, "PluginInfo"]:
def get_all_plugins(self) -> dict[str, "PluginInfo"]:
"""获取所有插件"""
return self._plugins.copy()
@@ -726,7 +725,7 @@ class ComponentRegistry:
# """获取所有启用的插件"""
# return {name: info for name, info in self._plugins.items() if info.enabled}
def get_plugin_components(self, plugin_name: str) -> List["ComponentInfo"]:
def get_plugin_components(self, plugin_name: str) -> list["ComponentInfo"]:
"""获取插件的所有组件"""
plugin_info = self.get_plugin_info(plugin_name)
return plugin_info.components if plugin_info else []
@@ -753,7 +752,7 @@ class ComponentRegistry:
config_path = Path("config") / "plugins" / plugin_name / "config.toml"
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
with open(config_path, encoding="utf-8") as f:
config_data = toml.load(f)
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
return config_data
@@ -762,7 +761,7 @@ class ComponentRegistry:
return {}
def get_registry_stats(self) -> Dict[str, Any]:
def get_registry_stats(self) -> dict[str, Any]:
"""获取注册中心统计信息"""
action_components: int = 0
command_components: int = 0

View File

@@ -3,8 +3,8 @@
提供统一的事件注册、管理和触发接口
"""
from typing import Dict, Type, List, Optional, Any, Union
from threading import Lock
from typing import Any, Optional
from src.common.logger import get_logger
from src.plugin_system import BaseEventHandler
@@ -37,17 +37,17 @@ class EventManager:
if self._initialized:
return
self._events: Dict[str, BaseEvent] = {}
self._event_handlers: Dict[str, Type[BaseEventHandler]] = {}
self._pending_subscriptions: Dict[str, List[str]] = {} # 缓存失败的订阅
self._events: dict[str, BaseEvent] = {}
self._event_handlers: dict[str, type[BaseEventHandler]] = {}
self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅
self._initialized = True
logger.info("EventManager 单例初始化完成")
def register_event(
self,
event_name: Union[EventType, str],
allowed_subscribers: List[str] = None,
allowed_triggers: List[str] = None,
event_name: EventType | str,
allowed_subscribers: list[str] = None,
allowed_triggers: list[str] = None,
) -> bool:
"""注册一个新的事件
@@ -75,7 +75,7 @@ class EventManager:
return True
def get_event(self, event_name: Union[EventType, str]) -> Optional[BaseEvent]:
def get_event(self, event_name: EventType | str) -> BaseEvent | None:
"""获取指定事件实例
Args:
@@ -86,7 +86,7 @@ class EventManager:
"""
return self._events.get(event_name)
def get_all_events(self) -> Dict[str, BaseEvent]:
def get_all_events(self) -> dict[str, BaseEvent]:
"""获取所有已注册的事件
Returns:
@@ -94,7 +94,7 @@ class EventManager:
"""
return self._events.copy()
def get_enabled_events(self) -> Dict[str, BaseEvent]:
def get_enabled_events(self) -> dict[str, BaseEvent]:
"""获取所有已启用的事件
Returns:
@@ -102,7 +102,7 @@ class EventManager:
"""
return {name: event for name, event in self._events.items() if event.enabled}
def get_disabled_events(self) -> Dict[str, BaseEvent]:
def get_disabled_events(self) -> dict[str, BaseEvent]:
"""获取所有已禁用的事件
Returns:
@@ -110,7 +110,7 @@ class EventManager:
"""
return {name: event for name, event in self._events.items() if not event.enabled}
def enable_event(self, event_name: Union[EventType, str]) -> bool:
def enable_event(self, event_name: EventType | str) -> bool:
"""启用指定事件
Args:
@@ -128,7 +128,7 @@ class EventManager:
logger.info(f"事件 {event_name} 已启用")
return True
def disable_event(self, event_name: Union[EventType, str]) -> bool:
def disable_event(self, event_name: EventType | str) -> bool:
"""禁用指定事件
Args:
@@ -146,9 +146,7 @@ class EventManager:
logger.info(f"事件 {event_name} 已禁用")
return True
def register_event_handler(
self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None
) -> bool:
def register_event_handler(self, handler_class: type[BaseEventHandler], plugin_config: dict | None = None) -> bool:
"""注册事件处理器
Args:
@@ -190,7 +188,7 @@ class EventManager:
logger.info(f"事件处理器 {handler_name} 注册成功")
return True
def get_event_handler(self, handler_name: str) -> Optional[Type[BaseEventHandler]]:
def get_event_handler(self, handler_name: str) -> type[BaseEventHandler] | None:
"""获取指定事件处理器实例
Args:
@@ -209,7 +207,7 @@ class EventManager:
"""
return self._event_handlers.copy()
def subscribe_handler_to_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
def subscribe_handler_to_event(self, handler_name: str, event_name: EventType | str) -> bool:
"""订阅事件处理器到指定事件
Args:
@@ -246,7 +244,7 @@ class EventManager:
logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成")
return True
def unsubscribe_handler_from_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
def unsubscribe_handler_from_event(self, handler_name: str, event_name: EventType | str) -> bool:
"""从指定事件取消订阅事件处理器
Args:
@@ -276,7 +274,7 @@ class EventManager:
return removed
def get_event_subscribers(self, event_name: Union[EventType, str]) -> Dict[str, BaseEventHandler]:
def get_event_subscribers(self, event_name: EventType | str) -> dict[str, BaseEventHandler]:
"""获取订阅指定事件的所有事件处理器
Args:
@@ -292,8 +290,8 @@ class EventManager:
return {handler.handler_name: handler for handler in event.subscribers}
async def trigger_event(
self, event_name: Union[EventType, str], permission_group: Optional[str] = "", **kwargs
) -> Optional[HandlerResultsCollection]:
self, event_name: EventType | str, permission_group: str | None = "", **kwargs
) -> HandlerResultsCollection | None:
"""触发指定事件
Args:
@@ -345,7 +343,7 @@ class EventManager:
self._event_handlers.clear()
logger.info("所有事件和处理器已清除")
def get_event_summary(self) -> Dict[str, Any]:
def get_event_summary(self) -> dict[str, Any]:
"""获取事件系统摘要
Returns:
@@ -364,7 +362,7 @@ class EventManager:
"pending_subscriptions": len(self._pending_subscriptions),
}
def _process_pending_subscriptions(self, event_name: Union[EventType, str]) -> None:
def _process_pending_subscriptions(self, event_name: EventType | str) -> None:
"""处理指定事件的缓存订阅
Args:

View File

@@ -1,5 +1,3 @@
from typing import List, Dict
from src.common.logger import get_logger
logger = get_logger("global_announcement_manager")
@@ -8,13 +6,13 @@ logger = get_logger("global_announcement_manager")
class GlobalAnnouncementManager:
def __init__(self) -> None:
# 用户禁用的动作chat_id -> [action_name]
self._user_disabled_actions: Dict[str, List[str]] = {}
self._user_disabled_actions: dict[str, list[str]] = {}
# 用户禁用的命令chat_id -> [command_name]
self._user_disabled_commands: Dict[str, List[str]] = {}
self._user_disabled_commands: dict[str, list[str]] = {}
# 用户禁用的事件处理器chat_id -> [handler_name]
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
self._user_disabled_event_handlers: dict[str, list[str]] = {}
# 用户禁用的工具chat_id -> [tool_name]
self._user_disabled_tools: Dict[str, List[str]] = {}
self._user_disabled_tools: dict[str, list[str]] = {}
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""禁用特定聊天的某个动作"""
@@ -100,19 +98,19 @@ class GlobalAnnouncementManager:
return False
return False
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
def get_disabled_chat_actions(self, chat_id: str) -> list[str]:
"""获取特定聊天禁用的所有动作"""
return self._user_disabled_actions.get(chat_id, []).copy()
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
def get_disabled_chat_commands(self, chat_id: str) -> list[str]:
"""获取特定聊天禁用的所有命令"""
return self._user_disabled_commands.get(chat_id, []).copy()
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
def get_disabled_chat_event_handlers(self, chat_id: str) -> list[str]:
"""获取特定聊天禁用的所有事件处理器"""
return self._user_disabled_event_handlers.get(chat_id, []).copy()
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
def get_disabled_chat_tools(self, chat_id: str) -> list[str]:
"""获取特定聊天禁用的所有工具"""
return self._user_disabled_tools.get(chat_id, []).copy()

View File

@@ -4,16 +4,16 @@
这个模块提供了权限系统的核心实现,包括权限检查、权限节点管理、用户权限管理等功能。
"""
from typing import List, Set, Tuple
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from datetime import datetime
from sqlalchemy import select, delete
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker
from src.common.database.sqlalchemy_models import PermissionNodes, UserPermissions, get_engine
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import get_engine, PermissionNodes, UserPermissions
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
from src.config.config import global_config
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
logger = get_logger(__name__)
@@ -24,7 +24,7 @@ class PermissionManager(IPermissionManager):
def __init__(self):
self.engine = None
self.SessionLocal = None
self._master_users: Set[Tuple[str, str]] = set()
self._master_users: set[tuple[str, str]] = set()
self._load_master_users()
async def initialize(self):
@@ -276,7 +276,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"撤销权限时发生未知错误: {e}")
return False
async def get_user_permissions(self, user: UserInfo) -> List[str]:
async def get_user_permissions(self, user: UserInfo) -> list[str]:
"""
获取用户拥有的所有权限节点
@@ -328,7 +328,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"获取用户权限时发生未知错误: {e}")
return []
async def get_all_permission_nodes(self) -> List[PermissionNode]:
async def get_all_permission_nodes(self) -> list[PermissionNode]:
"""
获取所有已注册的权限节点
@@ -356,7 +356,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"获取所有权限节点时发生未知错误: {e}")
return []
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]:
"""
获取指定插件的所有权限节点
@@ -431,7 +431,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"删除插件权限时发生未知错误: {e}")
return False
async def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]:
async def get_users_with_permission(self, permission_node: str) -> list[tuple[str, str]]:
"""
获取拥有指定权限的所有用户

View File

@@ -1,19 +1,17 @@
import asyncio
import importlib
import os
import traceback
import importlib
from typing import Dict, List, Optional, Tuple, Type, Any
from importlib.util import spec_from_file_location, module_from_spec
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from typing import Any, Optional
from src.common.logger import get_logger
from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.base.component_types import ComponentType
from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.utils.manifest_utils import VersionComparator
from .component_registry import component_registry
from .component_registry import component_registry
logger = get_logger("plugin_manager")
@@ -26,12 +24,12 @@ class PluginManager:
"""
def __init__(self):
self.plugin_directories: List[str] = [] # 插件根目录列表
self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
self.plugin_directories: list[str] = [] # 插件根目录列表
self.plugin_classes: dict[str, type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
self.plugin_paths: dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
self.loaded_plugins: dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
self.failed_plugins: dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
# 确保插件目录存在
self._ensure_plugin_directories()
@@ -54,7 +52,7 @@ class PluginManager:
# === 插件加载管理 ===
def load_all_plugins(self) -> Tuple[int, int]:
def load_all_plugins(self) -> tuple[int, int]:
"""加载所有插件
Returns:
@@ -87,7 +85,7 @@ class PluginManager:
return total_registered, total_failed_registration
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
def load_registered_plugin_classes(self, plugin_name: str) -> tuple[bool, int]:
# sourcery skip: extract-duplicate-method, extract-method
"""
加载已经注册的插件类
@@ -142,7 +140,7 @@ class PluginManager:
except FileNotFoundError as e:
# manifest文件缺失
error_msg = f"缺少manifest文件: {str(e)}"
error_msg = f"缺少manifest文件: {e!s}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
@@ -150,14 +148,14 @@ class PluginManager:
except ValueError as e:
# manifest文件格式错误或验证失败
traceback.print_exc()
error_msg = f"manifest验证失败: {str(e)}"
error_msg = f"manifest验证失败: {e!s}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except Exception as e:
# 其他错误
error_msg = f"未知错误: {str(e)}"
error_msg = f"未知错误: {e!s}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
logger.debug("详细错误信息: ", exc_info=True)
@@ -192,7 +190,7 @@ class PluginManager:
logger.debug(f"插件 {plugin_name} 重载成功")
return True
def rescan_plugin_directory(self) -> Tuple[int, int]:
def rescan_plugin_directory(self) -> tuple[int, int]:
"""
重新扫描插件根目录
"""
@@ -220,7 +218,7 @@ class PluginManager:
return self.loaded_plugins.get(plugin_name)
# === 查询方法 ===
def list_loaded_plugins(self) -> List[str]:
def list_loaded_plugins(self) -> list[str]:
"""
列出所有当前加载的插件。
@@ -229,7 +227,7 @@ class PluginManager:
"""
return list(self.loaded_plugins.keys())
def list_registered_plugins(self) -> List[str]:
def list_registered_plugins(self) -> list[str]:
"""
列出所有已注册的插件类。
@@ -238,7 +236,7 @@ class PluginManager:
"""
return list(self.plugin_classes.keys())
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
def get_plugin_path(self, plugin_name: str) -> str | None:
"""
获取指定插件的路径。
@@ -329,7 +327,7 @@ class PluginManager:
# == 兼容性检查 ==
@staticmethod
def _check_plugin_version_compatibility(plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
def _check_plugin_version_compatibility(plugin_name: str, manifest_data: dict[str, Any]) -> tuple[bool, str]:
"""检查插件版本兼容性
Args:
@@ -569,7 +567,7 @@ class PluginManager:
return True
except Exception as e:
logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}", exc_info=True)
logger.error(f"❌ 插件卸载失败: {plugin_name} - {e!s}", exc_info=True)
return False
def reload_plugin(self, plugin_name: str) -> bool:
@@ -606,7 +604,7 @@ class PluginManager:
return False
except Exception as e:
logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}", exc_info=True)
logger.error(f"❌ 插件重载失败: {plugin_name} - {e!s}", exc_info=True)
return False
def force_reload_plugin(self, plugin_name: str) -> bool:

View File

@@ -1,16 +1,17 @@
import inspect
import time
from typing import List, Dict, Tuple, Optional, Any
from typing import Any
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.payload_content import ToolCall
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.chat.utils.prompt import Prompt, global_prompt_manager
import inspect
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache
logger = get_logger("tool_use")
@@ -56,14 +57,14 @@ class ToolExecutor:
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
# 二步工具调用状态管理
self._pending_step_two_tools: Dict[str, Dict[str, Any]] = {}
self._pending_step_two_tools: dict[str, dict[str, Any]] = {}
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
logger.info(f"{self.log_prefix}工具执行器初始化完成")
async def execute_from_chat_message(
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
) -> Tuple[List[Dict[str, Any]], List[str], str]:
) -> tuple[list[dict[str, Any]], list[str], str]:
"""从聊天消息执行工具
Args:
@@ -113,7 +114,7 @@ class ToolExecutor:
else:
return tool_results, [], ""
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
def _get_tool_definitions(self) -> list[dict[str, Any]]:
all_tools = get_llm_available_tool_definitions()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
@@ -129,7 +130,7 @@ class ToolExecutor:
return tool_definitions
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
"""执行工具调用
Args:
@@ -138,7 +139,7 @@ class ToolExecutor:
Returns:
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
"""
tool_results: List[Dict[str, Any]] = []
tool_results: list[dict[str, Any]] = []
used_tools = []
if not tool_calls:
@@ -192,7 +193,7 @@ class ToolExecutor:
error_info = {
"type": "tool_error",
"id": f"tool_error_{time.time()}",
"content": f"工具{tool_name}执行失败: {str(e)}",
"content": f"工具{tool_name}执行失败: {e!s}",
"tool_name": tool_name,
"timestamp": time.time(),
}
@@ -201,8 +202,8 @@ class ToolExecutor:
return tool_results, used_tools
async def execute_tool_call(
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
) -> Optional[Dict[str, Any]]:
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
) -> dict[str, Any] | None:
"""执行单个工具调用,并处理缓存"""
function_args = tool_call.args or {}
@@ -256,8 +257,8 @@ class ToolExecutor:
return result
async def _original_execute_tool_call(
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
) -> Optional[Dict[str, Any]]:
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
) -> dict[str, Any] | None:
"""执行单个工具调用的原始逻辑"""
try:
function_name = tool_call.func_name
@@ -323,10 +324,10 @@ class ToolExecutor:
logger.warning(f"{self.log_prefix}工具 {function_name} 返回空结果")
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
logger.error(f"执行工具调用时发生错误: {e!s}")
raise e
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
async def execute_specific_tool_simple(self, tool_name: str, tool_args: dict) -> dict | None:
"""直接执行指定工具
Args:

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
本模块包含一个从Python包的“安装名”到其“导入名”的映射。

View File

@@ -1,4 +1,3 @@
from typing import Optional
from src.common.logger import get_logger
logger = get_logger("dependency_config")
@@ -66,7 +65,7 @@ class DependencyConfig:
# 全局配置实例
_global_dependency_config: Optional[DependencyConfig] = None
_global_dependency_config: DependencyConfig | None = None
def get_dependency_config() -> DependencyConfig:

View File

@@ -1,8 +1,9 @@
import subprocess
import sys
import importlib
import importlib.util
from typing import List, Tuple, Optional, Any
import subprocess
import sys
from typing import Any
from packaging import version
from packaging.requirements import Requirement
@@ -19,7 +20,7 @@ class DependencyManager:
负责检查和自动安装插件的Python包依赖
"""
def __init__(self, auto_install: bool = True, use_mirror: bool = False, mirror_url: Optional[str] = None):
def __init__(self, auto_install: bool = True, use_mirror: bool = False, mirror_url: str | None = None):
"""初始化依赖管理器
Args:
@@ -46,7 +47,7 @@ class DependencyManager:
self.mirror_url = mirror_url or ""
self.install_timeout = 300
def check_dependencies(self, dependencies: Any, plugin_name: str = "") -> Tuple[bool, List[str], List[str]]:
def check_dependencies(self, dependencies: Any, plugin_name: str = "") -> tuple[bool, list[str], list[str]]:
"""检查依赖包是否满足要求
Args:
@@ -69,7 +70,7 @@ class DependencyManager:
logger.info(f"{log_prefix}缺少依赖包: {dep.get_pip_requirement()}")
missing_packages.append(dep.get_pip_requirement())
except Exception as e:
error_msg = f"检查依赖 {dep.package_name} 时发生错误: {str(e)}"
error_msg = f"检查依赖 {dep.package_name} 时发生错误: {e!s}"
error_messages.append(error_msg)
logger.error(f"{log_prefix}{error_msg}")
@@ -84,7 +85,7 @@ class DependencyManager:
return all_satisfied, missing_packages, error_messages
def install_dependencies(self, packages: List[str], plugin_name: str = "") -> Tuple[bool, List[str]]:
def install_dependencies(self, packages: list[str], plugin_name: str = "") -> tuple[bool, list[str]]:
"""自动安装缺失的依赖包
Args:
@@ -115,7 +116,7 @@ class DependencyManager:
logger.error(f"{log_prefix}❌ 安装失败: {package}")
except Exception as e:
failed_packages.append(package)
logger.error(f"{log_prefix}❌ 安装 {package} 时发生异常: {str(e)}")
logger.error(f"{log_prefix}❌ 安装 {package} 时发生异常: {e!s}")
success = len(failed_packages) == 0
if success:
@@ -125,7 +126,7 @@ class DependencyManager:
return success, failed_packages
def check_and_install_dependencies(self, dependencies: Any, plugin_name: str = "") -> Tuple[bool, List[str]]:
def check_and_install_dependencies(self, dependencies: Any, plugin_name: str = "") -> tuple[bool, list[str]]:
"""检查并自动安装依赖(组合操作)
Args:
@@ -163,7 +164,7 @@ class DependencyManager:
return False, all_errors
@staticmethod
def _normalize_dependencies(dependencies: Any) -> List[PythonDependency]:
def _normalize_dependencies(dependencies: Any) -> list[PythonDependency]:
"""将依赖列表标准化为PythonDependency对象"""
normalized = []
@@ -277,7 +278,7 @@ class DependencyManager:
# 全局依赖管理器实例
_global_dependency_manager: Optional[DependencyManager] = None
_global_dependency_manager: DependencyManager | None = None
def get_dependency_manager() -> DependencyManager:
@@ -288,7 +289,7 @@ def get_dependency_manager() -> DependencyManager:
return _global_dependency_manager
def configure_dependency_manager(auto_install: bool = True, use_mirror: bool = False, mirror_url: Optional[str] = None):
def configure_dependency_manager(auto_install: bool = True, use_mirror: bool = False, mirror_url: str | None = None):
"""配置全局依赖管理器"""
global _global_dependency_manager
_global_dependency_manager = DependencyManager(

View File

@@ -5,7 +5,8 @@
"""
import re
from typing import Dict, Any, Tuple
from typing import Any
from src.common.logger import get_logger
from src.config.config import MMC_VERSION
@@ -70,7 +71,7 @@ class VersionComparator:
return normalized
@staticmethod
def parse_version(version: str) -> Tuple[int, int, int]:
def parse_version(version: str) -> tuple[int, int, int]:
"""解析版本号为元组
Args:
@@ -109,7 +110,7 @@ class VersionComparator:
return 0
@staticmethod
def check_forward_compatibility(current_version: str, max_version: str) -> Tuple[bool, str]:
def check_forward_compatibility(current_version: str, max_version: str) -> tuple[bool, str]:
"""检查向前兼容性(仅使用兼容性映射表)
Args:
@@ -131,7 +132,7 @@ class VersionComparator:
return False, ""
@staticmethod
def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]:
def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> tuple[bool, str]:
"""检查版本是否在指定范围内,支持兼容性检查
Args:
@@ -195,7 +196,7 @@ class VersionComparator:
logger.info(f"添加兼容性映射:{base_normalized} -> {compatible_versions}")
@staticmethod
def get_compatibility_info() -> Dict[str, list]:
def get_compatibility_info() -> dict[str, list]:
"""获取当前的兼容性映射表
Returns:
@@ -232,7 +233,7 @@ class ManifestValidator:
self.validation_errors = []
self.validation_warnings = []
def validate_manifest(self, manifest_data: Dict[str, Any]) -> bool:
def validate_manifest(self, manifest_data: dict[str, Any]) -> bool:
"""验证manifest数据
Args:
@@ -266,7 +267,7 @@ class ManifestValidator:
if "name" not in author or not author["name"]:
self.validation_errors.append("作者信息缺少name字段或为空")
# url字段是可选的
if "url" in author and author["url"]:
if author.get("url"):
url = author["url"]
if not (url.startswith("http://") or url.startswith("https://")):
self.validation_warnings.append("作者URL建议使用完整的URL格式")
@@ -305,7 +306,7 @@ class ManifestValidator:
# 检查URL格式可选字段
for url_field in ["homepage_url", "repository_url"]:
if url_field in manifest_data and manifest_data[url_field]:
if manifest_data.get(url_field):
url: str = manifest_data[url_field]
if not (url.startswith("http://") or url.startswith("https://")):
self.validation_warnings.append(f"{url_field}建议使用完整的URL格式")

View File

@@ -4,19 +4,19 @@
提供方便的权限检查装饰器,用于插件命令和其他需要权限验证的地方。
"""
from collections.abc import Callable
from functools import wraps
from typing import Callable, Optional
from inspect import iscoroutinefunction
from src.chat.message_receive.chat_stream import ChatStream
from src.plugin_system.apis.logging_api import get_logger
from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.apis.send_api import text_to_stream
from src.plugin_system.apis.logging_api import get_logger
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger(__name__)
def require_permission(permission_node: str, deny_message: Optional[str] = None):
def require_permission(permission_node: str, deny_message: str | None = None):
"""
权限检查装饰器
@@ -90,7 +90,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
return decorator
def require_master(deny_message: Optional[str] = None):
def require_master(deny_message: str | None = None):
"""
Master权限检查装饰器
@@ -186,9 +186,7 @@ class PermissionChecker:
return permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id)
@staticmethod
async def ensure_permission(
chat_stream: ChatStream, permission_node: str, deny_message: Optional[str] = None
) -> bool:
async def ensure_permission(chat_stream: ChatStream, permission_node: str, deny_message: str | None = None) -> bool:
"""
确保用户拥有指定权限如果没有权限会发送消息并返回False
@@ -209,7 +207,7 @@ class PermissionChecker:
return has_permission
@staticmethod
async def ensure_master(chat_stream: ChatStream, deny_message: Optional[str] = None) -> bool:
async def ensure_master(chat_stream: ChatStream, deny_message: str | None = None) -> bool:
"""
确保用户为Master用户如果不是会发送消息并返回False