re-style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
00ba07e0e1
commit
a79253c714
@@ -5,33 +5,49 @@ MaiBot 插件系统
|
||||
"""
|
||||
|
||||
# 导出主要的公共接口
|
||||
from .apis import (
|
||||
chat_api,
|
||||
component_manage_api,
|
||||
config_api,
|
||||
database_api,
|
||||
emoji_api,
|
||||
generator_api,
|
||||
get_logger,
|
||||
llm_api,
|
||||
message_api,
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
register_plugin,
|
||||
send_api,
|
||||
tool_api,
|
||||
)
|
||||
from .base import (
|
||||
BasePlugin,
|
||||
ActionActivationType,
|
||||
ActionInfo,
|
||||
BaseAction,
|
||||
BaseCommand,
|
||||
BaseTool,
|
||||
ConfigField,
|
||||
ComponentType,
|
||||
ActionActivationType,
|
||||
ChatMode,
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
CommandInfo,
|
||||
PlusCommandInfo,
|
||||
PluginInfo,
|
||||
ToolInfo,
|
||||
PythonDependency,
|
||||
BaseEventHandler,
|
||||
BasePlugin,
|
||||
BaseTool,
|
||||
ChatMode,
|
||||
ChatType,
|
||||
CommandArgs,
|
||||
CommandInfo,
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
ConfigField,
|
||||
EventHandlerInfo,
|
||||
EventType,
|
||||
MaiMessages,
|
||||
ToolParamType,
|
||||
PluginInfo,
|
||||
# 新增的增强命令系统
|
||||
PlusCommand,
|
||||
CommandArgs,
|
||||
PlusCommandAdapter,
|
||||
PlusCommandInfo,
|
||||
PythonDependency,
|
||||
ToolInfo,
|
||||
ToolParamType,
|
||||
create_plus_command_adapter,
|
||||
ChatType,
|
||||
)
|
||||
|
||||
# 导入工具模块
|
||||
@@ -41,28 +57,10 @@ from .utils import (
|
||||
# validate_plugin_manifest,
|
||||
# generate_plugin_manifest,
|
||||
)
|
||||
from .utils.dependency_config import configure_dependency_settings, get_dependency_config
|
||||
|
||||
# 导入依赖管理模块
|
||||
from .utils.dependency_manager import get_dependency_manager, configure_dependency_manager
|
||||
from .utils.dependency_config import get_dependency_config, configure_dependency_settings
|
||||
|
||||
from .apis import (
|
||||
chat_api,
|
||||
tool_api,
|
||||
component_manage_api,
|
||||
config_api,
|
||||
database_api,
|
||||
emoji_api,
|
||||
generator_api,
|
||||
llm_api,
|
||||
message_api,
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
send_api,
|
||||
register_plugin,
|
||||
get_logger,
|
||||
)
|
||||
|
||||
from .utils.dependency_manager import configure_dependency_manager, get_dependency_manager
|
||||
|
||||
__version__ = "2.0.0"
|
||||
|
||||
|
||||
@@ -14,14 +14,15 @@ from src.plugin_system.apis import (
|
||||
generator_api,
|
||||
llm_api,
|
||||
message_api,
|
||||
permission_api,
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
schedule_api,
|
||||
send_api,
|
||||
tool_api,
|
||||
permission_api,
|
||||
schedule_api,
|
||||
)
|
||||
from src.plugin_system.apis.chat_api import ChatManager as context_api
|
||||
|
||||
from .logging_api import get_logger
|
||||
from .plugin_register_api import register_plugin
|
||||
|
||||
@@ -30,18 +31,18 @@ __all__ = [
|
||||
"chat_api",
|
||||
"component_manage_api",
|
||||
"config_api",
|
||||
"context_api",
|
||||
"database_api",
|
||||
"emoji_api",
|
||||
"generator_api",
|
||||
"get_logger",
|
||||
"llm_api",
|
||||
"message_api",
|
||||
"permission_api",
|
||||
"person_api",
|
||||
"plugin_manage_api",
|
||||
"send_api",
|
||||
"get_logger",
|
||||
"register_plugin",
|
||||
"tool_api",
|
||||
"permission_api",
|
||||
"context_api",
|
||||
"schedule_api",
|
||||
"send_api",
|
||||
"tool_api",
|
||||
]
|
||||
|
||||
@@ -12,11 +12,11 @@
|
||||
streams = chat.get_all_group_streams()
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("chat_api")
|
||||
|
||||
@@ -31,7 +31,7 @@ class ChatManager:
|
||||
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
|
||||
|
||||
@staticmethod
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有聊天流
|
||||
|
||||
@@ -57,7 +57,7 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有群聊聊天流
|
||||
|
||||
@@ -80,7 +80,7 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有私聊聊天流
|
||||
|
||||
@@ -107,8 +107,8 @@ class ChatManager:
|
||||
|
||||
@staticmethod
|
||||
def get_group_stream_by_group_id(
|
||||
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
|
||||
group_id: str, platform: str | None | SpecialTypes = "qq"
|
||||
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast
|
||||
"""根据群ID获取聊天流
|
||||
|
||||
Args:
|
||||
@@ -144,8 +144,8 @@ class ChatManager:
|
||||
|
||||
@staticmethod
|
||||
def get_private_stream_by_user_id(
|
||||
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
|
||||
user_id: str, platform: str | None | SpecialTypes = "qq"
|
||||
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast
|
||||
"""根据用户ID获取私聊流
|
||||
|
||||
Args:
|
||||
@@ -203,7 +203,7 @@ class ChatManager:
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]:
|
||||
"""获取聊天流详细信息
|
||||
|
||||
Args:
|
||||
@@ -222,7 +222,7 @@ class ChatManager:
|
||||
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||
|
||||
try:
|
||||
info: Dict[str, Any] = {
|
||||
info: dict[str, Any] = {
|
||||
"stream_id": chat_stream.stream_id,
|
||||
"platform": chat_stream.platform,
|
||||
"type": ChatManager.get_stream_type(chat_stream),
|
||||
@@ -250,7 +250,7 @@ class ChatManager:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
def get_streams_summary() -> dict[str, int]:
|
||||
"""获取聊天流统计摘要
|
||||
|
||||
Returns:
|
||||
@@ -285,27 +285,27 @@ class ChatManager:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
"""获取所有聊天流的便捷函数"""
|
||||
return ChatManager.get_all_streams(platform)
|
||||
|
||||
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
"""获取群聊聊天流的便捷函数"""
|
||||
return ChatManager.get_group_streams(platform)
|
||||
|
||||
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
"""获取私聊聊天流的便捷函数"""
|
||||
return ChatManager.get_private_streams(platform)
|
||||
|
||||
|
||||
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
|
||||
def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None:
|
||||
"""根据群ID获取聊天流的便捷函数"""
|
||||
return ChatManager.get_group_stream_by_group_id(group_id, platform)
|
||||
|
||||
|
||||
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
|
||||
def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None:
|
||||
"""根据用户ID获取私聊流的便捷函数"""
|
||||
return ChatManager.get_private_stream_by_user_id(user_id, platform)
|
||||
|
||||
@@ -315,11 +315,11 @@ def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
return ChatManager.get_stream_type(chat_stream)
|
||||
|
||||
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]:
|
||||
"""获取聊天流信息的便捷函数"""
|
||||
return ChatManager.get_stream_info(chat_stream)
|
||||
|
||||
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
def get_streams_summary() -> dict[str, int]:
|
||||
"""获取聊天流统计摘要的便捷函数"""
|
||||
return ChatManager.get_streams_summary()
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
from typing import Optional, Union, Dict
|
||||
from src.plugin_system.base.component_types import (
|
||||
CommandInfo,
|
||||
ActionInfo,
|
||||
CommandInfo,
|
||||
ComponentType,
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
|
||||
# === 插件信息查询 ===
|
||||
def get_all_plugin_info() -> Dict[str, PluginInfo]:
|
||||
def get_all_plugin_info() -> dict[str, PluginInfo]:
|
||||
"""
|
||||
获取所有插件的信息。
|
||||
|
||||
@@ -22,7 +21,7 @@ def get_all_plugin_info() -> Dict[str, PluginInfo]:
|
||||
return component_registry.get_all_plugins()
|
||||
|
||||
|
||||
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
|
||||
def get_plugin_info(plugin_name: str) -> PluginInfo | None:
|
||||
"""
|
||||
获取指定插件的信息。
|
||||
|
||||
@@ -40,7 +39,7 @@ def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
component_name: str, component_type: ComponentType
|
||||
) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
) -> CommandInfo | ActionInfo | EventHandlerInfo | None:
|
||||
"""
|
||||
获取指定组件的信息。
|
||||
|
||||
@@ -57,7 +56,7 @@ def get_component_info(
|
||||
|
||||
def get_components_info_by_type(
|
||||
component_type: ComponentType,
|
||||
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
) -> dict[str, CommandInfo | ActionInfo | EventHandlerInfo]:
|
||||
"""
|
||||
获取指定类型的所有组件信息。
|
||||
|
||||
@@ -74,7 +73,7 @@ def get_components_info_by_type(
|
||||
|
||||
def get_enabled_components_info_by_type(
|
||||
component_type: ComponentType,
|
||||
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
) -> dict[str, CommandInfo | ActionInfo | EventHandlerInfo]:
|
||||
"""
|
||||
获取指定类型的所有启用的组件信息。
|
||||
|
||||
@@ -90,7 +89,7 @@ def get_enabled_components_info_by_type(
|
||||
|
||||
|
||||
# === Action 查询方法 ===
|
||||
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
|
||||
def get_registered_action_info(action_name: str) -> ActionInfo | None:
|
||||
"""
|
||||
获取指定 Action 的注册信息。
|
||||
|
||||
@@ -105,7 +104,7 @@ def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
|
||||
return component_registry.get_registered_action_info(action_name)
|
||||
|
||||
|
||||
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
|
||||
def get_registered_command_info(command_name: str) -> CommandInfo | None:
|
||||
"""
|
||||
获取指定 Command 的注册信息。
|
||||
|
||||
@@ -120,7 +119,7 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
|
||||
return component_registry.get_registered_command_info(command_name)
|
||||
|
||||
|
||||
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
|
||||
def get_registered_tool_info(tool_name: str) -> ToolInfo | None:
|
||||
"""
|
||||
获取指定 Tool 的注册信息。
|
||||
|
||||
@@ -138,7 +137,7 @@ def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
|
||||
# === EventHandler 特定查询方法 ===
|
||||
def get_registered_event_handler_info(
|
||||
event_handler_name: str,
|
||||
) -> Optional[EventHandlerInfo]:
|
||||
) -> EventHandlerInfo | None:
|
||||
"""
|
||||
获取指定 EventHandler 的注册信息。
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -3,20 +3,20 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, Any, Optional, List
|
||||
from typing import Any
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
build_readable_messages_with_id,
|
||||
)
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
|
||||
logger = get_logger("cross_context_api")
|
||||
|
||||
|
||||
def get_context_groups(chat_id: str) -> Optional[List[List[str]]]:
|
||||
def get_context_groups(chat_id: str) -> list[list[str]] | None:
|
||||
"""
|
||||
获取当前聊天所在的共享组的其他聊天ID
|
||||
"""
|
||||
@@ -41,7 +41,7 @@ def get_context_groups(chat_id: str) -> Optional[List[List[str]]]:
|
||||
return None
|
||||
|
||||
|
||||
async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: List[List[str]]) -> str:
|
||||
async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: list[list[str]]) -> str:
|
||||
"""
|
||||
构建跨群聊/私聊上下文 (Normal模式)
|
||||
"""
|
||||
@@ -74,8 +74,8 @@ async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos:
|
||||
|
||||
async def build_cross_context_s4u(
|
||||
chat_stream: ChatStream,
|
||||
other_chat_infos: List[List[str]],
|
||||
target_user_info: Optional[Dict[str, Any]],
|
||||
other_chat_infos: list[list[str]],
|
||||
target_user_info: dict[str, Any] | None,
|
||||
) -> str:
|
||||
"""
|
||||
构建跨群聊/私聊上下文 (S4U模式)
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
注意:此模块现在使用SQLAlchemy实现,提供更好的连接管理和错误处理
|
||||
"""
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get, store_action_info, MODEL_MAPPING
|
||||
from src.common.database.sqlalchemy_database_api import MODEL_MAPPING, db_get, db_query, db_save, store_action_info
|
||||
|
||||
# 保持向后兼容性
|
||||
__all__ = ["db_query", "db_save", "db_get", "store_action_info", "MODEL_MAPPING"]
|
||||
__all__ = ["MODEL_MAPPING", "db_get", "db_query", "db_save", "store_action_info"]
|
||||
|
||||
@@ -10,10 +10,9 @@
|
||||
|
||||
import random
|
||||
|
||||
from typing import Optional, Tuple, List
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("emoji_api")
|
||||
|
||||
@@ -23,7 +22,7 @@ logger = get_logger("emoji_api")
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
|
||||
async def get_by_description(description: str) -> tuple[str, str, str] | None:
|
||||
"""根据描述选择表情包
|
||||
|
||||
Args:
|
||||
@@ -65,7 +64,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
|
||||
return None
|
||||
|
||||
|
||||
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||
async def get_random(count: int | None = 1) -> list[tuple[str, str, str]]:
|
||||
"""随机获取指定数量的表情包
|
||||
|
||||
Args:
|
||||
@@ -137,7 +136,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||
return []
|
||||
|
||||
|
||||
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
async def get_by_emotion(emotion: str) -> tuple[str, str, str] | None:
|
||||
"""根据情感标签获取表情包
|
||||
|
||||
Args:
|
||||
@@ -227,7 +226,7 @@ def get_info():
|
||||
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
||||
|
||||
|
||||
def get_emotions() -> List[str]:
|
||||
def get_emotions() -> list[str]:
|
||||
"""获取所有可用的情感标签
|
||||
|
||||
Returns:
|
||||
@@ -247,7 +246,7 @@ def get_emotions() -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
def get_descriptions() -> List[str]:
|
||||
def get_descriptions() -> list[str]:
|
||||
"""获取所有表情包描述
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -9,13 +9,15 @@
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Tuple, Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.replyer.replyer_manager import replyer_manager
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ActionInfo
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -30,10 +32,10 @@ logger = get_logger("generator_api")
|
||||
|
||||
|
||||
def get_replyer(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
chat_stream: ChatStream | None = None,
|
||||
chat_id: str | None = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer]:
|
||||
) -> DefaultReplyer | None:
|
||||
"""获取回复器对象
|
||||
|
||||
优先使用chat_stream,如果没有则使用chat_id直接查找。
|
||||
@@ -71,15 +73,13 @@ def get_replyer(
|
||||
|
||||
|
||||
async def generate_reply(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
action_data: Optional[Dict[str, Any]] = None,
|
||||
chat_stream: ChatStream | None = None,
|
||||
chat_id: str | None = None,
|
||||
action_data: dict[str, Any] | None = None,
|
||||
reply_to: str = "",
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
reply_message: dict[str, Any] | None = None,
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
choosen_actions: Optional[List[Dict[str, Any]]] = None,
|
||||
available_actions: dict[str, ActionInfo] | None = None,
|
||||
enable_tool: bool = False,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
@@ -87,7 +87,7 @@ async def generate_reply(
|
||||
request_type: str = "generator_api",
|
||||
from_plugin: bool = True,
|
||||
read_mark: float = 0.0,
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
) -> tuple[bool, list[tuple[str, Any]], str | None]:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
@@ -183,9 +183,9 @@ async def generate_reply(
|
||||
|
||||
|
||||
async def rewrite_reply(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
reply_data: Optional[Dict[str, Any]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
chat_stream: ChatStream | None = None,
|
||||
reply_data: dict[str, Any] | None = None,
|
||||
chat_id: str | None = None,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
raw_reply: str = "",
|
||||
@@ -193,7 +193,7 @@ async def rewrite_reply(
|
||||
reply_to: str = "",
|
||||
return_prompt: bool = False,
|
||||
request_type: str = "generator_api",
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
) -> tuple[bool, list[tuple[str, Any]], str | None]:
|
||||
"""重写回复
|
||||
|
||||
Args:
|
||||
@@ -252,7 +252,7 @@ async def rewrite_reply(
|
||||
return False, [], None
|
||||
|
||||
|
||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
|
||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> list[tuple[str, Any]]:
|
||||
"""将文本处理为更拟人化的文本
|
||||
|
||||
Args:
|
||||
@@ -281,11 +281,11 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
|
||||
|
||||
|
||||
async def generate_response_custom(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
chat_stream: ChatStream | None = None,
|
||||
chat_id: str | None = None,
|
||||
request_type: str = "generator_api",
|
||||
prompt: str = "",
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""
|
||||
使用自定义提示生成回复
|
||||
|
||||
|
||||
@@ -7,12 +7,13 @@
|
||||
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
|
||||
"""
|
||||
|
||||
from typing import Tuple, Dict, List, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
|
||||
logger = get_logger("llm_api")
|
||||
|
||||
@@ -21,7 +22,7 @@ logger = get_logger("llm_api")
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, TaskConfig]:
|
||||
def get_available_models() -> dict[str, TaskConfig]:
|
||||
"""获取所有可用的模型配置
|
||||
|
||||
Returns:
|
||||
@@ -31,7 +32,7 @@ def get_available_models() -> Dict[str, TaskConfig]:
|
||||
# 自动获取所有属性并转换为字典形式
|
||||
models = model_config.model_task_config
|
||||
attrs = dir(models)
|
||||
rets: Dict[str, TaskConfig] = {}
|
||||
rets: dict[str, TaskConfig] = {}
|
||||
for attr in attrs:
|
||||
if not attr.startswith("__"):
|
||||
try:
|
||||
@@ -52,9 +53,9 @@ async def generate_with_model(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str]:
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> tuple[bool, str, str, str]:
|
||||
"""使用指定模型生成内容
|
||||
|
||||
Args:
|
||||
@@ -78,7 +79,7 @@ async def generate_with_model(
|
||||
return True, response, reasoning_content, model_name
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
error_msg = f"生成内容时出错: {e!s}"
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
|
||||
@@ -86,11 +87,11 @@ async def generate_with_model(
|
||||
async def generate_with_model_with_tools(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
tool_options: List[Dict[str, Any]] | None = None,
|
||||
tool_options: list[dict[str, Any]] | None = None,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> tuple[bool, str, str, str, list[ToolCall] | None]:
|
||||
"""使用指定模型和工具生成内容
|
||||
|
||||
Args:
|
||||
@@ -117,6 +118,6 @@ async def generate_with_model_with_tools(
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
error_msg = f"生成内容时出错: {e!s}"
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", "", None
|
||||
|
||||
@@ -8,26 +8,26 @@
|
||||
readable_text = message_api.build_readable_messages(messages)
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
get_raw_msg_by_timestamp_with_chat_users,
|
||||
get_raw_msg_by_timestamp_random,
|
||||
get_raw_msg_by_timestamp_with_users,
|
||||
get_raw_msg_before_timestamp,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
get_raw_msg_before_timestamp_with_users,
|
||||
num_new_messages_since,
|
||||
num_new_messages_since_with_users,
|
||||
build_readable_messages,
|
||||
build_readable_messages_with_list,
|
||||
get_person_id_list,
|
||||
get_raw_msg_before_timestamp,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
get_raw_msg_before_timestamp_with_users,
|
||||
get_raw_msg_by_timestamp,
|
||||
get_raw_msg_by_timestamp_random,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
get_raw_msg_by_timestamp_with_chat_users,
|
||||
get_raw_msg_by_timestamp_with_users,
|
||||
num_new_messages_since,
|
||||
num_new_messages_since_with_users,
|
||||
)
|
||||
|
||||
from src.config.config import global_config
|
||||
|
||||
# =============================================================================
|
||||
# 消息查询API函数
|
||||
@@ -36,7 +36,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
|
||||
async def get_messages_by_time(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定时间范围内的消息
|
||||
|
||||
@@ -70,7 +70,7 @@ async def get_messages_by_time_in_chat(
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间范围内的消息
|
||||
|
||||
@@ -111,7 +111,7 @@ async def get_messages_by_time_in_chat_inclusive(
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间范围内的消息(包含边界)
|
||||
|
||||
@@ -152,10 +152,10 @@ async def get_messages_by_time_in_chat_for_users(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
person_ids: List[str],
|
||||
person_ids: list[str],
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定用户在指定时间范围内的消息
|
||||
|
||||
@@ -186,7 +186,7 @@ async def get_messages_by_time_in_chat_for_users(
|
||||
|
||||
async def get_random_chat_messages(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
随机选择一个聊天,返回该聊天在指定时间范围内的消息
|
||||
|
||||
@@ -213,8 +213,8 @@ async def get_random_chat_messages(
|
||||
|
||||
|
||||
async def get_messages_by_time_for_users(
|
||||
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
start_time: float, end_time: float, person_ids: list[str], limit: int = 0, limit_mode: str = "latest"
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户在所有聊天中指定时间范围内的消息
|
||||
|
||||
@@ -238,7 +238,7 @@ async def get_messages_by_time_for_users(
|
||||
return await get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
|
||||
|
||||
|
||||
async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]:
|
||||
async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定时间戳之前的消息
|
||||
|
||||
@@ -294,8 +294,8 @@ async def get_messages_before_time_in_chat(
|
||||
|
||||
|
||||
async def get_messages_before_time_for_users(
|
||||
timestamp: float, person_ids: List[str], limit: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
timestamp: float, person_ids: list[str], limit: int = 0
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户在指定时间戳之前的消息
|
||||
|
||||
@@ -319,7 +319,7 @@ async def get_messages_before_time_for_users(
|
||||
|
||||
async def get_recent_messages(
|
||||
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中最近一段时间的消息
|
||||
|
||||
@@ -358,7 +358,7 @@ async def get_recent_messages(
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
|
||||
async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: float | None = None) -> int:
|
||||
"""
|
||||
计算指定聊天中从开始时间到结束时间的新消息数量
|
||||
|
||||
@@ -382,7 +382,7 @@ async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Op
|
||||
return await num_new_messages_since(chat_id, start_time, end_time)
|
||||
|
||||
|
||||
async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
|
||||
async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: list[str]) -> int:
|
||||
"""
|
||||
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
|
||||
|
||||
@@ -413,7 +413,7 @@ async def count_new_messages_for_users(chat_id: str, start_time: float, end_time
|
||||
|
||||
|
||||
async def build_readable_messages_to_str(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
@@ -442,12 +442,12 @@ async def build_readable_messages_to_str(
|
||||
|
||||
|
||||
async def build_readable_messages_with_details(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
) -> tuple[str, list[tuple[float, str, str]]]:
|
||||
"""
|
||||
将消息列表构建成可读的字符串,并返回详细信息
|
||||
|
||||
@@ -464,7 +464,7 @@ async def build_readable_messages_with_details(
|
||||
return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate)
|
||||
|
||||
|
||||
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
async def get_person_ids_from_messages(messages: list[dict[str, Any]]) -> list[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的用户ID列表
|
||||
|
||||
@@ -482,7 +482,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
async def filter_mai_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
从消息列表中移除麦麦的消息
|
||||
Args:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""纯异步权限API定义。所有外部调用方必须使用 await。"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -48,18 +48,18 @@ class IPermissionManager(ABC):
|
||||
async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_permissions(self, user: UserInfo) -> List[str]: ...
|
||||
async def get_user_permissions(self, user: UserInfo) -> list[str]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_permission_nodes(self) -> List[PermissionNode]: ...
|
||||
async def get_all_permission_nodes(self) -> list[PermissionNode]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: ...
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]: ...
|
||||
|
||||
|
||||
class PermissionAPI:
|
||||
def __init__(self):
|
||||
self._permission_manager: Optional[IPermissionManager] = None
|
||||
self._permission_manager: IPermissionManager | None = None
|
||||
# 需要保留的前缀(视为绝对节点名,不再自动加 plugins.<plugin>. 前缀)
|
||||
self.RESERVED_PREFIXES: tuple[str, ...] = "system."
|
||||
# 系统节点列表 (name, description, default_granted)
|
||||
@@ -147,11 +147,11 @@ class PermissionAPI:
|
||||
self._ensure_manager()
|
||||
return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node)
|
||||
|
||||
async def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
|
||||
async def get_user_permissions(self, platform: str, user_id: str) -> list[str]:
|
||||
self._ensure_manager()
|
||||
return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id))
|
||||
|
||||
async def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
|
||||
async def get_all_permission_nodes(self) -> list[dict[str, Any]]:
|
||||
self._ensure_manager()
|
||||
nodes = await self._permission_manager.get_all_permission_nodes()
|
||||
return [
|
||||
@@ -164,7 +164,7 @@ class PermissionAPI:
|
||||
for n in nodes
|
||||
]
|
||||
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[dict[str, Any]]:
|
||||
self._ensure_manager()
|
||||
nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name)
|
||||
return [
|
||||
|
||||
@@ -8,8 +8,9 @@
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import Person
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
logger = get_logger("person_api")
|
||||
|
||||
@@ -63,7 +64,7 @@ async def get_person_value(person_id: str, field_name: str, default: Any = None)
|
||||
return default
|
||||
|
||||
|
||||
async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict:
|
||||
async def get_person_values(person_id: str, field_names: list, default_dict: dict | None = None) -> dict:
|
||||
"""批量获取用户信息字段值
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
from typing import Tuple, List
|
||||
|
||||
|
||||
def list_loaded_plugins() -> List[str]:
|
||||
def list_loaded_plugins() -> list[str]:
|
||||
"""
|
||||
列出所有当前加载的插件。
|
||||
|
||||
@@ -13,7 +10,7 @@ def list_loaded_plugins() -> List[str]:
|
||||
return plugin_manager.list_loaded_plugins()
|
||||
|
||||
|
||||
def list_registered_plugins() -> List[str]:
|
||||
def list_registered_plugins() -> list[str]:
|
||||
"""
|
||||
列出所有已注册的插件。
|
||||
|
||||
@@ -80,7 +77,7 @@ async def reload_plugin(plugin_name: str) -> bool:
|
||||
return await plugin_manager.reload_registered_plugin(plugin_name)
|
||||
|
||||
|
||||
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
|
||||
def load_plugin(plugin_name: str) -> tuple[bool, int]:
|
||||
"""
|
||||
加载指定的插件。
|
||||
|
||||
@@ -109,7 +106,7 @@ def add_plugin_directory(plugin_directory: str) -> bool:
|
||||
return plugin_manager.add_plugin_directory(plugin_directory)
|
||||
|
||||
|
||||
def rescan_plugin_directory() -> Tuple[int, int]:
|
||||
def rescan_plugin_directory() -> tuple[int, int]:
|
||||
"""
|
||||
重新扫描插件目录,加载新插件。
|
||||
Returns:
|
||||
|
||||
@@ -6,8 +6,8 @@ logger = get_logger("plugin_manager") # 复用plugin_manager名称
|
||||
|
||||
|
||||
def register_plugin(cls):
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
"""插件注册装饰器
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.sqlalchemy_models import MonthlyPlan
|
||||
from src.common.logger import get_logger
|
||||
@@ -44,7 +44,7 @@ class ScheduleAPI:
|
||||
"""日程表与月度计划API - 负责日程和计划信息的查询与管理"""
|
||||
|
||||
@staticmethod
|
||||
async def get_today_schedule() -> Optional[List[Dict[str, Any]]]:
|
||||
async def get_today_schedule() -> list[dict[str, Any]] | None:
|
||||
"""(异步) 获取今天的日程安排
|
||||
|
||||
Returns:
|
||||
@@ -58,7 +58,7 @@ class ScheduleAPI:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_current_activity() -> Optional[str]:
|
||||
async def get_current_activity() -> str | None:
|
||||
"""(异步) 获取当前正在进行的活动
|
||||
|
||||
Returns:
|
||||
@@ -87,7 +87,7 @@ class ScheduleAPI:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]:
|
||||
async def get_monthly_plans(target_month: str | None = None) -> list[MonthlyPlan]:
|
||||
"""(异步) 获取指定月份的有效月度计划
|
||||
|
||||
Args:
|
||||
@@ -106,7 +106,7 @@ class ScheduleAPI:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool:
|
||||
async def ensure_monthly_plans(target_month: str | None = None) -> bool:
|
||||
"""(异步) 确保指定月份存在月度计划,如果不存在则触发生成
|
||||
|
||||
Args:
|
||||
@@ -125,7 +125,7 @@ class ScheduleAPI:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def archive_monthly_plans(target_month: Optional[str] = None) -> bool:
|
||||
async def archive_monthly_plans(target_month: str | None = None) -> bool:
|
||||
"""(异步) 归档指定月份的月度计划
|
||||
|
||||
Args:
|
||||
@@ -150,12 +150,12 @@ class ScheduleAPI:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_today_schedule() -> Optional[List[Dict[str, Any]]]:
|
||||
async def get_today_schedule() -> list[dict[str, Any]] | None:
|
||||
"""(异步) 获取今天的日程安排的便捷函数"""
|
||||
return await ScheduleAPI.get_today_schedule()
|
||||
|
||||
|
||||
async def get_current_activity() -> Optional[str]:
|
||||
async def get_current_activity() -> str | None:
|
||||
"""(异步) 获取当前正在进行的活动的便捷函数"""
|
||||
return await ScheduleAPI.get_current_activity()
|
||||
|
||||
@@ -165,16 +165,16 @@ async def regenerate_schedule() -> bool:
|
||||
return await ScheduleAPI.regenerate_schedule()
|
||||
|
||||
|
||||
async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]:
|
||||
async def get_monthly_plans(target_month: str | None = None) -> list[MonthlyPlan]:
|
||||
"""(异步) 获取指定月份的有效月度计划的便捷函数"""
|
||||
return await ScheduleAPI.get_monthly_plans(target_month)
|
||||
|
||||
|
||||
async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool:
|
||||
async def ensure_monthly_plans(target_month: str | None = None) -> bool:
|
||||
"""(异步) 确保指定月份存在月度计划的便捷函数"""
|
||||
return await ScheduleAPI.ensure_monthly_plans(target_month)
|
||||
|
||||
|
||||
async def archive_monthly_plans(target_month: Optional[str] = None) -> bool:
|
||||
async def archive_monthly_plans(target_month: str | None = None) -> bool:
|
||||
"""(异步) 归档指定月份的月度计划的便捷函数"""
|
||||
return await ScheduleAPI.archive_monthly_plans(target_month)
|
||||
|
||||
@@ -28,29 +28,28 @@
|
||||
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Optional, Union, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from maim_message import Seg, UserInfo
|
||||
|
||||
# 导入依赖
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from maim_message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv
|
||||
from maim_message import Seg
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
# 日志记录器
|
||||
logger = get_logger("send_api")
|
||||
|
||||
# 适配器命令响应等待池
|
||||
_adapter_response_pool: Dict[str, asyncio.Future] = {}
|
||||
_adapter_response_pool: dict[str, asyncio.Future] = {}
|
||||
|
||||
|
||||
def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[MessageRecv]:
|
||||
def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None:
|
||||
"""查找要回复的消息
|
||||
|
||||
Args:
|
||||
@@ -134,13 +133,13 @@ async def wait_adapter_response(request_id: str, timeout: float = 30.0) -> dict:
|
||||
|
||||
async def _send_to_target(
|
||||
message_type: str,
|
||||
content: Union[str, dict],
|
||||
content: str | dict,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
set_reply: bool = False,
|
||||
reply_to_message: Optional[Dict[str, Any]] = None,
|
||||
reply_to_message: dict[str, Any] | None = None,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
selected_expressions:List[int] = None,
|
||||
@@ -249,7 +248,7 @@ async def text_to_stream(
|
||||
stream_id: str,
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
reply_to_message: Optional[Dict[str, Any]] = None,
|
||||
reply_to_message: dict[str, Any] | None = None,
|
||||
set_reply: bool = True,
|
||||
storage_message: bool = True,
|
||||
selected_expressions:List[int] = None,
|
||||
@@ -317,7 +316,7 @@ async def image_to_stream(
|
||||
|
||||
|
||||
async def command_to_stream(
|
||||
command: Union[str, dict],
|
||||
command: str | dict,
|
||||
stream_id: str,
|
||||
storage_message: bool = True,
|
||||
display_message: str = "",
|
||||
@@ -345,7 +344,7 @@ async def custom_to_stream(
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
reply_to_message: Optional[Dict[str, Any]] = None,
|
||||
reply_to_message: dict[str, Any] | None = None,
|
||||
set_reply: bool = True,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
@@ -381,8 +380,8 @@ async def custom_to_stream(
|
||||
async def adapter_command_to_stream(
|
||||
action: str,
|
||||
params: dict,
|
||||
platform: Optional[str] = "qq",
|
||||
stream_id: Optional[str] = None,
|
||||
platform: str | None = "qq",
|
||||
stream_id: str | None = None,
|
||||
timeout: float = 30.0,
|
||||
storage_message: bool = False,
|
||||
) -> dict:
|
||||
@@ -501,4 +500,4 @@ async def adapter_command_to_stream(
|
||||
except Exception as e:
|
||||
logger.error(f"[SendAPI] 发送适配器命令时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return {"status": "error", "message": f"发送适配器命令时出错: {str(e)}"}
|
||||
return {"status": "error", "message": f"发送适配器命令时出错: {e!s}"}
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
from typing import Optional, Type
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_api")
|
||||
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
def get_tool_instance(tool_name: str) -> BaseTool | None:
|
||||
"""获取公开工具实例"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
@@ -18,7 +16,7 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
else:
|
||||
plugin_config = None
|
||||
|
||||
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
||||
tool_class: type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
||||
return tool_class(plugin_config) if tool_class else None
|
||||
|
||||
|
||||
|
||||
@@ -4,31 +4,31 @@
|
||||
提供插件开发的基础类和类型定义
|
||||
"""
|
||||
|
||||
from .base_plugin import BasePlugin
|
||||
from .base_action import BaseAction
|
||||
from .base_tool import BaseTool
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .base_plugin import BasePlugin
|
||||
from .base_tool import BaseTool
|
||||
from .command_args import CommandArgs
|
||||
from .component_types import (
|
||||
ComponentType,
|
||||
ActionActivationType,
|
||||
ActionInfo,
|
||||
ChatMode,
|
||||
ChatType,
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
CommandInfo,
|
||||
PlusCommandInfo,
|
||||
ToolInfo,
|
||||
PluginInfo,
|
||||
PythonDependency,
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
EventHandlerInfo,
|
||||
EventType,
|
||||
MaiMessages,
|
||||
PluginInfo,
|
||||
PlusCommandInfo,
|
||||
PythonDependency,
|
||||
ToolInfo,
|
||||
ToolParamType,
|
||||
)
|
||||
from .config_types import ConfigField
|
||||
from .plus_command import PlusCommand, PlusCommandAdapter, create_plus_command_adapter
|
||||
from .command_args import CommandArgs
|
||||
|
||||
__all__ = [
|
||||
"BasePlugin",
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional, List, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType, ChatType
|
||||
from src.plugin_system.apis import send_api, database_api, message_api
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import database_api, message_api, send_api
|
||||
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ChatMode, ChatType, ComponentType
|
||||
|
||||
logger = get_logger("base_action")
|
||||
|
||||
|
||||
@@ -37,7 +36,7 @@ class BaseAction(ABC):
|
||||
"""是否为二步Action。如果为True,Action将分两步执行:第一步选择操作,第二步执行具体操作"""
|
||||
step_one_description: str = ""
|
||||
"""第一步的描述,用于向LLM展示Action的基本功能"""
|
||||
sub_actions: List[Tuple[str, str, Dict[str, str]]] = []
|
||||
sub_actions: list[tuple[str, str, dict[str, str]]] = []
|
||||
"""子Action列表,格式为[(子Action名, 子Action描述, 子Action参数)]。仅在二步Action中使用"""
|
||||
|
||||
def __init__(
|
||||
@@ -48,8 +47,8 @@ class BaseAction(ABC):
|
||||
thinking_id: str,
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str = "",
|
||||
plugin_config: Optional[dict] = None,
|
||||
action_message: Optional[dict] = None,
|
||||
plugin_config: dict | None = None,
|
||||
action_message: dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
|
||||
@@ -107,8 +106,8 @@ class BaseAction(ABC):
|
||||
# 二步Action相关实例属性
|
||||
self.is_two_step_action: bool = getattr(self.__class__, "is_two_step_action", False)
|
||||
self.step_one_description: str = getattr(self.__class__, "step_one_description", "")
|
||||
self.sub_actions: List[Tuple[str, str, Dict[str, str]]] = getattr(self.__class__, "sub_actions", []).copy()
|
||||
self._selected_sub_action: Optional[str] = None
|
||||
self.sub_actions: list[tuple[str, str, dict[str, str]]] = getattr(self.__class__, "sub_actions", []).copy()
|
||||
self._selected_sub_action: str | None = None
|
||||
"""当前选择的子Action名称,用于二步Action的状态管理"""
|
||||
|
||||
# =============================================================================
|
||||
@@ -198,7 +197,7 @@ class BaseAction(ABC):
|
||||
"""
|
||||
return self._validate_chat_type()
|
||||
|
||||
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
|
||||
async def wait_for_new_message(self, timeout: int = 1200) -> tuple[bool, str]:
|
||||
"""等待新消息或超时
|
||||
|
||||
在loop_start_time之后等待新消息,如果没有新消息且没有超时,就一直等待。
|
||||
@@ -230,7 +229,7 @@ class BaseAction(ABC):
|
||||
|
||||
# 检查新消息
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
new_message_count = await message_api.count_new_messages(
|
||||
chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time
|
||||
)
|
||||
|
||||
@@ -256,7 +255,7 @@ class BaseAction(ABC):
|
||||
return False, ""
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
||||
return False, f"等待新消息失败: {str(e)}"
|
||||
return False, f"等待新消息失败: {e!s}"
|
||||
|
||||
async def send_text(self, content: str, reply_to: str = "", typing: bool = False) -> bool:
|
||||
"""发送文本消息
|
||||
@@ -359,7 +358,7 @@ class BaseAction(ABC):
|
||||
)
|
||||
|
||||
async def send_command(
|
||||
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
|
||||
self, command_name: str, args: dict | None = None, display_message: str = "", storage_message: bool = True
|
||||
) -> bool:
|
||||
"""发送命令消息
|
||||
|
||||
@@ -400,7 +399,7 @@ class BaseAction(ABC):
|
||||
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
|
||||
return False
|
||||
|
||||
async def call_action(self, action_name: str, action_data: Optional[dict] = None) -> Tuple[bool, str]:
|
||||
async def call_action(self, action_name: str, action_data: dict | None = None) -> tuple[bool, str]:
|
||||
"""
|
||||
在当前Action中调用另一个Action。
|
||||
|
||||
@@ -515,7 +514,7 @@ class BaseAction(ABC):
|
||||
sub_actions=getattr(cls, "sub_actions", []).copy(),
|
||||
)
|
||||
|
||||
async def handle_step_one(self) -> Tuple[bool, str]:
|
||||
async def handle_step_one(self) -> tuple[bool, str]:
|
||||
"""处理二步Action的第一步
|
||||
|
||||
Returns:
|
||||
@@ -547,7 +546,7 @@ class BaseAction(ABC):
|
||||
# 调用第二步执行
|
||||
return await self.execute_step_two(selected_action)
|
||||
|
||||
async def execute_step_two(self, sub_action_name: str) -> Tuple[bool, str]:
|
||||
async def execute_step_two(self, sub_action_name: str) -> tuple[bool, str]:
|
||||
"""执行二步Action的第二步
|
||||
|
||||
Args:
|
||||
@@ -563,7 +562,7 @@ class BaseAction(ABC):
|
||||
return False, f"二步Action必须实现execute_step_two方法来处理操作: {sub_action_name}"
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行Action的抽象方法,子类必须实现
|
||||
|
||||
对于二步Action,会自动处理第一步逻辑
|
||||
@@ -578,7 +577,7 @@ class BaseAction(ABC):
|
||||
# 普通Action由子类实现
|
||||
pass
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
async def handle_action(self) -> tuple[bool, str]:
|
||||
"""兼容旧系统的handle_action接口,委托给execute方法
|
||||
|
||||
为了保持向后兼容性,旧系统的代码可能会调用handle_action方法。
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from .component_types import ChatType
|
||||
from src.plugin_system.base.component_types import ChatterInfo, ComponentType
|
||||
|
||||
from .component_types import ChatType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
|
||||
@@ -13,7 +15,7 @@ class BaseChatter(ABC):
|
||||
"""Chatter组件的名称"""
|
||||
chatter_description: str = ""
|
||||
"""Chatter组件的描述"""
|
||||
chat_types: List[ChatType] = [ChatType.PRIVATE, ChatType.GROUP]
|
||||
chat_types: list[ChatType] = [ChatType.PRIVATE, ChatType.GROUP]
|
||||
|
||||
def __init__(self, stream_id: str, action_manager: "ChatterActionManager"):
|
||||
"""
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Tuple, Optional, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import CommandInfo, ComponentType, ChatType
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType
|
||||
|
||||
logger = get_logger("base_command")
|
||||
|
||||
@@ -29,7 +29,7 @@ class BaseCommand(ABC):
|
||||
chat_type_allow: ChatType = ChatType.ALL
|
||||
"""允许的聊天类型,默认为所有类型"""
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
|
||||
"""初始化Command组件
|
||||
|
||||
Args:
|
||||
@@ -37,7 +37,7 @@ class BaseCommand(ABC):
|
||||
plugin_config: 插件配置字典
|
||||
"""
|
||||
self.message = message
|
||||
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
|
||||
self.matched_groups: dict[str, str] = {} # 存储正则表达式匹配的命名组
|
||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||
|
||||
self.log_prefix = "[Command]"
|
||||
@@ -55,7 +55,7 @@ class BaseCommand(ABC):
|
||||
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
||||
)
|
||||
|
||||
def set_matched_groups(self, groups: Dict[str, str]) -> None:
|
||||
def set_matched_groups(self, groups: dict[str, str]) -> None:
|
||||
"""设置正则表达式匹配的命名组
|
||||
|
||||
Args:
|
||||
@@ -93,7 +93,7 @@ class BaseCommand(ABC):
|
||||
return self._validate_chat_type()
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
async def execute(self) -> tuple[bool, str | None, bool]:
|
||||
"""执行Command的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
@@ -176,7 +176,7 @@ class BaseCommand(ABC):
|
||||
)
|
||||
|
||||
async def send_command(
|
||||
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
|
||||
self, command_name: str, args: dict | None = None, display_message: str = "", storage_message: bool = True
|
||||
) -> bool:
|
||||
"""发送命令消息
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -25,22 +25,22 @@ class HandlerResult:
|
||||
class HandlerResultsCollection:
|
||||
"""HandlerResult集合,提供便捷的查询方法"""
|
||||
|
||||
def __init__(self, results: List[HandlerResult]):
|
||||
def __init__(self, results: list[HandlerResult]):
|
||||
self.results = results
|
||||
|
||||
def all_continue_process(self) -> bool:
|
||||
"""检查是否所有handler的continue_process都为True"""
|
||||
return all(result.continue_process for result in self.results)
|
||||
|
||||
def get_all_results(self) -> List[HandlerResult]:
|
||||
def get_all_results(self) -> list[HandlerResult]:
|
||||
"""获取所有HandlerResult"""
|
||||
return self.results
|
||||
|
||||
def get_failed_handlers(self) -> List[HandlerResult]:
|
||||
def get_failed_handlers(self) -> list[HandlerResult]:
|
||||
"""获取执行失败的handler结果"""
|
||||
return [result for result in self.results if not result.success]
|
||||
|
||||
def get_stopped_handlers(self) -> List[HandlerResult]:
|
||||
def get_stopped_handlers(self) -> list[HandlerResult]:
|
||||
"""获取continue_process为False的handler结果"""
|
||||
return [result for result in self.results if not result.continue_process]
|
||||
|
||||
@@ -57,7 +57,7 @@ class HandlerResultsCollection:
|
||||
else:
|
||||
return {result.handler_name: result.message for result in self.results}
|
||||
|
||||
def get_handler_result(self, handler_name: str) -> Optional[HandlerResult]:
|
||||
def get_handler_result(self, handler_name: str) -> HandlerResult | None:
|
||||
"""获取指定handler的结果"""
|
||||
for result in self.results:
|
||||
if result.handler_name == handler_name:
|
||||
@@ -72,7 +72,7 @@ class HandlerResultsCollection:
|
||||
"""获取执行失败的handler数量"""
|
||||
return sum(1 for result in self.results if not result.success)
|
||||
|
||||
def get_summary(self) -> Dict[str, Any]:
|
||||
def get_summary(self) -> dict[str, Any]:
|
||||
"""获取执行摘要"""
|
||||
return {
|
||||
"total_handlers": len(self.results),
|
||||
@@ -85,13 +85,13 @@ class HandlerResultsCollection:
|
||||
|
||||
|
||||
class BaseEvent:
|
||||
def __init__(self, name: str, allowed_subscribers: List[str] = None, allowed_triggers: List[str] = None):
|
||||
def __init__(self, name: str, allowed_subscribers: list[str] = None, allowed_triggers: list[str] = None):
|
||||
self.name = name
|
||||
self.enabled = True
|
||||
self.allowed_subscribers = allowed_subscribers # 记录事件处理器名
|
||||
self.allowed_triggers = allowed_triggers # 记录插件名
|
||||
|
||||
self.subscribers: List["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表
|
||||
self.subscribers: list["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表
|
||||
|
||||
self.event_handle_lock = asyncio.Lock()
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional, List, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .component_types import EventType, EventHandlerInfo, ComponentType
|
||||
|
||||
from .component_types import ComponentType, EventHandlerInfo, EventType
|
||||
|
||||
logger = get_logger("base_event_handler")
|
||||
|
||||
@@ -21,7 +21,7 @@ class BaseEventHandler(ABC):
|
||||
"""处理器权重,越大权重越高"""
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,默认为否"""
|
||||
init_subscribe: List[Union[EventType, str]] = [EventType.UNKNOWN]
|
||||
init_subscribe: list[EventType | str] = [EventType.UNKNOWN]
|
||||
"""初始化时订阅的事件名称"""
|
||||
plugin_name = None
|
||||
|
||||
@@ -44,7 +44,7 @@ class BaseEventHandler(ABC):
|
||||
self.plugin_config = getattr(self.__class__, "plugin_config", {})
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, kwargs: dict | None) -> Tuple[bool, bool, Optional[str]]:
|
||||
async def execute(self, kwargs: dict | None) -> tuple[bool, bool, str | None]:
|
||||
"""执行事件处理的抽象方法,子类必须实现
|
||||
Args:
|
||||
kwargs (dict | None): 事件消息对象,当你注册的事件为ON_START和ON_STOP时message为None
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Type, Tuple, Union
|
||||
from .plugin_base import PluginBase
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, PlusCommandInfo, EventHandlerInfo, ToolInfo
|
||||
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, PlusCommandInfo, ToolInfo
|
||||
|
||||
from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .base_tool import BaseTool
|
||||
from .plugin_base import PluginBase
|
||||
from .plus_command import PlusCommand
|
||||
|
||||
logger = get_logger("base_plugin")
|
||||
@@ -28,14 +28,12 @@ class BasePlugin(PluginBase):
|
||||
@abstractmethod
|
||||
def get_plugin_components(
|
||||
self,
|
||||
) -> List[
|
||||
Union[
|
||||
Tuple[ActionInfo, Type[BaseAction]],
|
||||
Tuple[CommandInfo, Type[BaseCommand]],
|
||||
Tuple[PlusCommandInfo, Type[PlusCommand]],
|
||||
Tuple[EventHandlerInfo, Type[BaseEventHandler]],
|
||||
Tuple[ToolInfo, Type[BaseTool]],
|
||||
]
|
||||
) -> list[
|
||||
tuple[ActionInfo, type[BaseAction]]
|
||||
| tuple[CommandInfo, type[BaseCommand]]
|
||||
| tuple[PlusCommandInfo, type[PlusCommand]]
|
||||
| tuple[EventHandlerInfo, type[BaseEventHandler]]
|
||||
| tuple[ToolInfo, type[BaseTool]]
|
||||
]:
|
||||
"""获取插件包含的组件列表
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -17,7 +18,7 @@ class BaseTool(ABC):
|
||||
"""工具的名称"""
|
||||
description: str = ""
|
||||
"""工具的描述"""
|
||||
parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = []
|
||||
parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
|
||||
"""工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
|
||||
param_name: 参数名称
|
||||
param_type: 参数类型
|
||||
@@ -35,7 +36,7 @@ class BaseTool(ABC):
|
||||
"""是否为该工具启用缓存"""
|
||||
cache_ttl: int = 3600
|
||||
"""缓存的TTL值(秒),默认为3600秒(1小时)"""
|
||||
semantic_cache_query_key: Optional[str] = None
|
||||
semantic_cache_query_key: str | None = None
|
||||
"""用于语义缓存的查询参数键名。如果设置,将使用此参数的值进行语义相似度搜索"""
|
||||
|
||||
# 二步工具调用相关属性
|
||||
@@ -43,10 +44,10 @@ class BaseTool(ABC):
|
||||
"""是否为二步工具。如果为True,工具将分两步调用:第一步展示工具信息,第二步执行具体操作"""
|
||||
step_one_description: str = ""
|
||||
"""第一步的描述,用于向LLM展示工具的基本功能"""
|
||||
sub_tools: List[Tuple[str, str, List[Tuple[str, ToolParamType, str, bool, List[str] | None]]]] = []
|
||||
sub_tools: list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] = []
|
||||
"""子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用"""
|
||||
|
||||
def __init__(self, plugin_config: Optional[dict] = None):
|
||||
def __init__(self, plugin_config: dict | None = None):
|
||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||
|
||||
@classmethod
|
||||
@@ -101,7 +102,7 @@ class BaseTool(ABC):
|
||||
raise ValueError(f"未找到子工具: {sub_tool_name}")
|
||||
|
||||
@classmethod
|
||||
def get_all_sub_tool_definitions(cls) -> List[dict[str, Any]]:
|
||||
def get_all_sub_tool_definitions(cls) -> list[dict[str, Any]]:
|
||||
"""获取所有子工具的定义
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
提供简单易用的命令参数解析功能
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
import shlex
|
||||
|
||||
|
||||
@@ -20,7 +19,7 @@ class CommandArgs:
|
||||
raw_args: 原始参数字符串
|
||||
"""
|
||||
self._raw_args = raw_args.strip()
|
||||
self._parsed_args: Optional[List[str]] = None
|
||||
self._parsed_args: list[str] | None = None
|
||||
|
||||
def get_raw(self) -> str:
|
||||
"""获取完整的参数字符串
|
||||
@@ -30,7 +29,7 @@ class CommandArgs:
|
||||
"""
|
||||
return self._raw_args
|
||||
|
||||
def get_args(self) -> List[str]:
|
||||
def get_args(self) -> list[str]:
|
||||
"""获取解析后的参数列表
|
||||
|
||||
将参数按空格分割,支持引号包围的参数
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from maim_message import Seg
|
||||
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||
|
||||
|
||||
# 组件类型枚举
|
||||
@@ -114,7 +115,7 @@ class ComponentInfo:
|
||||
enabled: bool = True # 是否启用
|
||||
plugin_name: str = "" # 所属插件名称
|
||||
is_built_in: bool = False # 是否为内置组件
|
||||
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
|
||||
metadata: dict[str, Any] = field(default_factory=dict) # 额外元数据
|
||||
|
||||
def __post_init__(self):
|
||||
if self.metadata is None:
|
||||
@@ -125,18 +126,18 @@ class ComponentInfo:
|
||||
class ActionInfo(ComponentInfo):
|
||||
"""动作组件信息"""
|
||||
|
||||
action_parameters: Dict[str, str] = field(
|
||||
action_parameters: dict[str, str] = field(
|
||||
default_factory=dict
|
||||
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
|
||||
action_require: List[str] = field(default_factory=list) # 动作需求说明
|
||||
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
|
||||
action_require: list[str] = field(default_factory=list) # 动作需求说明
|
||||
associated_types: list[str] = field(default_factory=list) # 关联的消息类型
|
||||
# 激活类型相关
|
||||
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||
activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||
random_activation_probability: float = 0.0
|
||||
llm_judge_prompt: str = ""
|
||||
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
|
||||
activation_keywords: list[str] = field(default_factory=list) # 激活关键词列表
|
||||
keyword_case_sensitive: bool = False
|
||||
# 模式和并行设置
|
||||
mode_enable: ChatMode = ChatMode.ALL
|
||||
@@ -145,7 +146,7 @@ class ActionInfo(ComponentInfo):
|
||||
# 二步Action相关属性
|
||||
is_two_step_action: bool = False # 是否为二步Action
|
||||
step_one_description: str = "" # 第一步的描述
|
||||
sub_actions: List[Tuple[str, str, Dict[str, str]]] = field(default_factory=list) # 子Action列表
|
||||
sub_actions: list[tuple[str, str, dict[str, str]]] = field(default_factory=list) # 子Action列表
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
@@ -178,7 +179,7 @@ class CommandInfo(ComponentInfo):
|
||||
class PlusCommandInfo(ComponentInfo):
|
||||
"""增强命令组件信息"""
|
||||
|
||||
command_aliases: List[str] = field(default_factory=list) # 命令别名列表
|
||||
command_aliases: list[str] = field(default_factory=list) # 命令别名列表
|
||||
priority: int = 0 # 命令优先级
|
||||
chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型
|
||||
intercept_message: bool = False # 是否拦截消息
|
||||
@@ -194,7 +195,7 @@ class PlusCommandInfo(ComponentInfo):
|
||||
class ToolInfo(ComponentInfo):
|
||||
"""工具组件信息"""
|
||||
|
||||
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(
|
||||
tool_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = field(
|
||||
default_factory=list
|
||||
) # 工具参数定义
|
||||
tool_description: str = "" # 工具描述
|
||||
@@ -248,18 +249,18 @@ class PluginInfo:
|
||||
author: str = "" # 插件作者
|
||||
enabled: bool = True # 是否启用
|
||||
is_built_in: bool = False # 是否为内置插件
|
||||
components: List[ComponentInfo] = field(default_factory=list) # 包含的组件列表
|
||||
dependencies: List[str] = field(default_factory=list) # 依赖的其他插件
|
||||
python_dependencies: List[PythonDependency] = field(default_factory=list) # Python包依赖
|
||||
components: list[ComponentInfo] = field(default_factory=list) # 包含的组件列表
|
||||
dependencies: list[str] = field(default_factory=list) # 依赖的其他插件
|
||||
python_dependencies: list[PythonDependency] = field(default_factory=list) # Python包依赖
|
||||
config_file: str = "" # 配置文件路径
|
||||
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
|
||||
metadata: dict[str, Any] = field(default_factory=dict) # 额外元数据
|
||||
# 新增:manifest相关信息
|
||||
manifest_data: Dict[str, Any] = field(default_factory=dict) # manifest文件数据
|
||||
manifest_data: dict[str, Any] = field(default_factory=dict) # manifest文件数据
|
||||
license: str = "" # 插件许可证
|
||||
homepage_url: str = "" # 插件主页
|
||||
repository_url: str = "" # 插件仓库地址
|
||||
keywords: List[str] = field(default_factory=list) # 插件关键词
|
||||
categories: List[str] = field(default_factory=list) # 插件分类
|
||||
keywords: list[str] = field(default_factory=list) # 插件关键词
|
||||
categories: list[str] = field(default_factory=list) # 插件分类
|
||||
min_host_version: str = "" # 最低主机版本要求
|
||||
max_host_version: str = "" # 最高主机版本要求
|
||||
|
||||
@@ -279,7 +280,7 @@ class PluginInfo:
|
||||
if self.categories is None:
|
||||
self.categories = []
|
||||
|
||||
def get_missing_packages(self) -> List[PythonDependency]:
|
||||
def get_missing_packages(self) -> list[PythonDependency]:
|
||||
"""检查缺失的Python包"""
|
||||
missing = []
|
||||
for dep in self.python_dependencies:
|
||||
@@ -290,7 +291,7 @@ class PluginInfo:
|
||||
missing.append(dep)
|
||||
return missing
|
||||
|
||||
def get_pip_requirements(self) -> List[str]:
|
||||
def get_pip_requirements(self) -> list[str]:
|
||||
"""获取所有pip安装格式的依赖"""
|
||||
return [dep.get_pip_requirement() for dep in self.python_dependencies]
|
||||
|
||||
@@ -299,16 +300,16 @@ class PluginInfo:
|
||||
class MaiMessages:
|
||||
"""MaiM插件消息"""
|
||||
|
||||
message_segments: List[Seg] = field(default_factory=list)
|
||||
message_segments: list[Seg] = field(default_factory=list)
|
||||
"""消息段列表,支持多段消息"""
|
||||
|
||||
message_base_info: Dict[str, Any] = field(default_factory=dict)
|
||||
message_base_info: dict[str, Any] = field(default_factory=dict)
|
||||
"""消息基本信息,包含平台,用户信息等数据"""
|
||||
|
||||
plain_text: str = ""
|
||||
"""纯文本消息内容"""
|
||||
|
||||
raw_message: Optional[str] = None
|
||||
raw_message: str | None = None
|
||||
"""原始消息内容"""
|
||||
|
||||
is_group_message: bool = False
|
||||
@@ -317,28 +318,28 @@ class MaiMessages:
|
||||
is_private_message: bool = False
|
||||
"""是否为私聊消息"""
|
||||
|
||||
stream_id: Optional[str] = None
|
||||
stream_id: str | None = None
|
||||
"""流ID,用于标识消息流"""
|
||||
|
||||
llm_prompt: Optional[str] = None
|
||||
llm_prompt: str | None = None
|
||||
"""LLM提示词"""
|
||||
|
||||
llm_response_content: Optional[str] = None
|
||||
llm_response_content: str | None = None
|
||||
"""LLM响应内容"""
|
||||
|
||||
llm_response_reasoning: Optional[str] = None
|
||||
llm_response_reasoning: str | None = None
|
||||
"""LLM响应推理内容"""
|
||||
|
||||
llm_response_model: Optional[str] = None
|
||||
llm_response_model: str | None = None
|
||||
"""LLM响应模型名称"""
|
||||
|
||||
llm_response_tool_call: Optional[List[ToolCall]] = None
|
||||
llm_response_tool_call: list[ToolCall] | None = None
|
||||
"""LLM使用的工具调用"""
|
||||
|
||||
action_usage: Optional[List[str]] = None
|
||||
action_usage: list[str] | None = None
|
||||
"""使用的Action"""
|
||||
|
||||
additional_data: Dict[Any, Any] = field(default_factory=dict)
|
||||
additional_data: dict[Any, Any] = field(default_factory=dict)
|
||||
"""附加数据,可以存储额外信息"""
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
插件系统配置类型定义
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -13,6 +13,6 @@ class ConfigField:
|
||||
type: type # 字段类型
|
||||
default: Any # 默认值
|
||||
description: str # 字段描述
|
||||
example: Optional[str] = None # 示例值
|
||||
example: str | None = None # 示例值
|
||||
required: bool = False # 是否必需
|
||||
choices: Optional[List[Any]] = field(default_factory=list) # 可选值列表
|
||||
choices: list[Any] | None = field(default_factory=list) # 可选值列表
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any, Union
|
||||
import os
|
||||
import toml
|
||||
import orjson
|
||||
import shutil
|
||||
import datetime
|
||||
import os
|
||||
import shutil
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
import toml
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import CONFIG_DIR
|
||||
@@ -38,12 +39,12 @@ class PluginBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dependencies(self) -> List[str]:
|
||||
def dependencies(self) -> list[str]:
|
||||
return [] # 依赖的其他插件
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def python_dependencies(self) -> List[Union[str, PythonDependency]]:
|
||||
def python_dependencies(self) -> list[str | PythonDependency]:
|
||||
return [] # Python包依赖,支持字符串列表或PythonDependency对象列表
|
||||
|
||||
@property
|
||||
@@ -53,15 +54,15 @@ class PluginBase(ABC):
|
||||
|
||||
# manifest文件相关
|
||||
manifest_file_name: str = "_manifest.json" # manifest文件名
|
||||
manifest_data: Dict[str, Any] = {} # manifest数据
|
||||
manifest_data: dict[str, Any] = {} # manifest数据
|
||||
|
||||
# 配置定义
|
||||
@property
|
||||
@abstractmethod
|
||||
def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]:
|
||||
def config_schema(self) -> dict[str, dict[str, ConfigField] | str]:
|
||||
return {}
|
||||
|
||||
config_section_descriptions: Dict[str, str] = {}
|
||||
config_section_descriptions: dict[str, str] = {}
|
||||
|
||||
def __init__(self, plugin_dir: str):
|
||||
"""初始化插件
|
||||
@@ -69,7 +70,7 @@ class PluginBase(ABC):
|
||||
Args:
|
||||
plugin_dir: 插件目录路径,由插件管理器传递
|
||||
"""
|
||||
self.config: Dict[str, Any] = {} # 插件配置
|
||||
self.config: dict[str, Any] = {} # 插件配置
|
||||
self.plugin_dir = plugin_dir # 插件目录路径
|
||||
self.log_prefix = f"[Plugin:{self.plugin_name}]"
|
||||
self._is_enabled = self.enable_plugin # 从插件定义中获取默认启用状态
|
||||
@@ -144,7 +145,7 @@ class PluginBase(ABC):
|
||||
raise FileNotFoundError(error_msg)
|
||||
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
with open(manifest_path, encoding="utf-8") as f:
|
||||
self.manifest_data = orjson.loads(f.read())
|
||||
|
||||
logger.debug(f"{self.log_prefix} 成功加载manifest文件: {manifest_path}")
|
||||
@@ -155,8 +156,8 @@ class PluginBase(ABC):
|
||||
except orjson.JSONDecodeError as e:
|
||||
error_msg = f"{self.log_prefix} manifest文件格式错误: {e}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg) # noqa
|
||||
except IOError as e:
|
||||
raise ValueError(error_msg)
|
||||
except OSError as e:
|
||||
error_msg = f"{self.log_prefix} 读取manifest文件失败: {e}"
|
||||
logger.error(error_msg)
|
||||
raise IOError(error_msg) # noqa
|
||||
@@ -266,7 +267,7 @@ class PluginBase(ABC):
|
||||
with open(config_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(toml_str)
|
||||
logger.info(f"{self.log_prefix} 已生成默认配置文件: {config_file_path}")
|
||||
except IOError as e:
|
||||
except OSError as e:
|
||||
logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True)
|
||||
|
||||
def _backup_config_file(self, config_file_path: str) -> str:
|
||||
@@ -288,13 +289,13 @@ class PluginBase(ABC):
|
||||
return ""
|
||||
|
||||
def _synchronize_config(
|
||||
self, schema_config: Dict[str, Any], user_config: Dict[str, Any]
|
||||
) -> tuple[Dict[str, Any], bool]:
|
||||
self, schema_config: dict[str, Any], user_config: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], bool]:
|
||||
"""递归地将用户配置与 schema 同步,返回同步后的配置和是否发生变化的标志"""
|
||||
changed = False
|
||||
|
||||
# 内部递归函数
|
||||
def _sync_dicts(schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, Any]:
|
||||
def _sync_dicts(schema_dict: dict[str, Any], user_dict: dict[str, Any], parent_key: str = "") -> dict[str, Any]:
|
||||
nonlocal changed
|
||||
synced_dict = schema_dict.copy()
|
||||
|
||||
@@ -326,7 +327,7 @@ class PluginBase(ABC):
|
||||
final_config = _sync_dicts(schema_config, user_config)
|
||||
return final_config, changed
|
||||
|
||||
def _generate_config_from_schema(self) -> Dict[str, Any]:
|
||||
def _generate_config_from_schema(self) -> dict[str, Any]:
|
||||
# sourcery skip: dict-comprehension
|
||||
"""根据schema生成配置数据结构(不写入文件)"""
|
||||
if not self.config_schema:
|
||||
@@ -348,7 +349,7 @@ class PluginBase(ABC):
|
||||
|
||||
return config_data
|
||||
|
||||
def _save_config_to_file(self, config_data: Dict[str, Any], config_file_path: str):
|
||||
def _save_config_to_file(self, config_data: dict[str, Any], config_file_path: str):
|
||||
"""将配置数据保存为TOML文件(包含注释)"""
|
||||
if not self.config_schema:
|
||||
logger.debug(f"{self.log_prefix} 插件未定义config_schema,不生成配置文件")
|
||||
@@ -410,7 +411,7 @@ class PluginBase(ABC):
|
||||
with open(config_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(toml_str)
|
||||
logger.info(f"{self.log_prefix} 配置文件已保存: {config_file_path}")
|
||||
except IOError as e:
|
||||
except OSError as e:
|
||||
logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True)
|
||||
|
||||
def _load_plugin_config(self): # sourcery skip: extract-method
|
||||
@@ -456,7 +457,7 @@ class PluginBase(ABC):
|
||||
return
|
||||
|
||||
try:
|
||||
with open(user_config_path, "r", encoding="utf-8") as f:
|
||||
with open(user_config_path, encoding="utf-8") as f:
|
||||
user_config = toml.load(f) or {}
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True)
|
||||
@@ -520,7 +521,7 @@ class PluginBase(ABC):
|
||||
|
||||
return current
|
||||
|
||||
def _normalize_python_dependencies(self, dependencies: Any) -> List[PythonDependency]:
|
||||
def _normalize_python_dependencies(self, dependencies: Any) -> list[PythonDependency]:
|
||||
"""将依赖列表标准化为PythonDependency对象"""
|
||||
from packaging.requirements import Requirement
|
||||
|
||||
@@ -549,7 +550,7 @@ class PluginBase(ABC):
|
||||
|
||||
return normalized
|
||||
|
||||
def _check_python_dependencies(self, dependencies: List[PythonDependency]) -> bool:
|
||||
def _check_python_dependencies(self, dependencies: list[PythonDependency]) -> bool:
|
||||
"""检查Python依赖并尝试自动安装"""
|
||||
if not dependencies:
|
||||
logger.info(f"{self.log_prefix} 无Python依赖需要检查")
|
||||
|
||||
@@ -3,17 +3,16 @@
|
||||
提供更简单易用的命令处理方式,无需手写正则表达式
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional, List
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import PlusCommandInfo, ComponentType, ChatType
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo
|
||||
|
||||
logger = get_logger("plus_command")
|
||||
|
||||
@@ -39,7 +38,7 @@ class PlusCommand(ABC):
|
||||
command_description: str = ""
|
||||
"""命令描述"""
|
||||
|
||||
command_aliases: List[str] = []
|
||||
command_aliases: list[str] = []
|
||||
"""命令别名列表,如 ['say', 'repeat']"""
|
||||
|
||||
priority: int = 0
|
||||
@@ -51,7 +50,7 @@ class PlusCommand(ABC):
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,不进行后续处理"""
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
|
||||
"""初始化命令组件
|
||||
|
||||
Args:
|
||||
@@ -172,7 +171,7 @@ class PlusCommand(ABC):
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
||||
"""执行命令的抽象方法,子类必须实现
|
||||
|
||||
Args:
|
||||
@@ -341,7 +340,7 @@ class PlusCommandAdapter(BaseCommand):
|
||||
将PlusCommand适配到现有的插件系统,继承BaseCommand
|
||||
"""
|
||||
|
||||
def __init__(self, plus_command_class, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
def __init__(self, plus_command_class, message: MessageRecv, plugin_config: dict | None = None):
|
||||
"""初始化适配器
|
||||
|
||||
Args:
|
||||
@@ -363,7 +362,7 @@ class PlusCommandAdapter(BaseCommand):
|
||||
# 创建PlusCommand实例
|
||||
self.plus_command = plus_command_class(message, plugin_config)
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
async def execute(self) -> tuple[bool, str | None, bool]:
|
||||
"""执行命令
|
||||
|
||||
Returns:
|
||||
@@ -382,7 +381,7 @@ class PlusCommandAdapter(BaseCommand):
|
||||
return await self.plus_command.execute(self.plus_command.args)
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {e}", exc_info=True)
|
||||
return False, f"命令执行出错: {str(e)}", self.intercept_message
|
||||
return False, f"命令执行出错: {e!s}", self.intercept_message
|
||||
|
||||
|
||||
def create_plus_command_adapter(plus_command_class):
|
||||
@@ -401,13 +400,13 @@ def create_plus_command_adapter(plus_command_class):
|
||||
command_pattern = plus_command_class._generate_command_pattern()
|
||||
chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
|
||||
super().__init__(message, plugin_config)
|
||||
self.plus_command = plus_command_class(message, plugin_config)
|
||||
self.priority = getattr(plus_command_class, "priority", 0)
|
||||
self.intercept_message = getattr(plus_command_class, "intercept_message", False)
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
async def execute(self) -> tuple[bool, str | None, bool]:
|
||||
"""执行命令"""
|
||||
# 从BaseCommand的正则匹配结果中提取参数
|
||||
args_text = ""
|
||||
@@ -429,7 +428,7 @@ def create_plus_command_adapter(plus_command_class):
|
||||
return await self.plus_command.execute(command_args)
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {e}", exc_info=True)
|
||||
return False, f"命令执行出错: {str(e)}", self.intercept_message
|
||||
return False, f"命令执行出错: {e!s}", self.intercept_message
|
||||
|
||||
return AdapterClass
|
||||
|
||||
|
||||
@@ -4,14 +4,14 @@
|
||||
提供插件的加载、注册和管理功能
|
||||
"""
|
||||
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"event_manager",
|
||||
"global_announcement_manager",
|
||||
"plugin_manager",
|
||||
]
|
||||
|
||||
@@ -1,27 +1,26 @@
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
ToolInfo,
|
||||
CommandInfo,
|
||||
PlusCommandInfo,
|
||||
EventHandlerInfo,
|
||||
ChatterInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
)
|
||||
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import (
|
||||
ActionInfo,
|
||||
ChatterInfo,
|
||||
CommandInfo,
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
PlusCommandInfo,
|
||||
ToolInfo,
|
||||
)
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
|
||||
logger = get_logger("component_registry")
|
||||
|
||||
@@ -34,46 +33,46 @@ class ComponentRegistry:
|
||||
|
||||
def __init__(self):
|
||||
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
|
||||
self._components: Dict[str, "ComponentInfo"] = {}
|
||||
self._components: dict[str, "ComponentInfo"] = {}
|
||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||
self._components_by_type: Dict["ComponentType", Dict[str, "ComponentInfo"]] = {
|
||||
self._components_by_type: dict["ComponentType", dict[str, "ComponentInfo"]] = {
|
||||
types: {} for types in ComponentType
|
||||
}
|
||||
"""类型 -> 组件原名称 -> 组件信息"""
|
||||
self._components_classes: Dict[
|
||||
str, Type[Union["BaseCommand", "BaseAction", "BaseTool", "BaseEventHandler", "PlusCommand", "BaseChatter"]]
|
||||
self._components_classes: dict[
|
||||
str, type["BaseCommand" | "BaseAction" | "BaseTool" | "BaseEventHandler" | "PlusCommand" | "BaseChatter"]
|
||||
] = {}
|
||||
"""命名空间式组件名 -> 组件类"""
|
||||
|
||||
# 插件注册表
|
||||
self._plugins: Dict[str, "PluginInfo"] = {}
|
||||
self._plugins: dict[str, "PluginInfo"] = {}
|
||||
"""插件名 -> 插件信息"""
|
||||
|
||||
# Action特定注册表
|
||||
self._action_registry: Dict[str, Type["BaseAction"]] = {}
|
||||
self._action_registry: dict[str, type["BaseAction"]] = {}
|
||||
"""Action注册表 action名 -> action类"""
|
||||
self._default_actions: Dict[str, "ActionInfo"] = {}
|
||||
self._default_actions: dict[str, "ActionInfo"] = {}
|
||||
"""默认动作集,即启用的Action集,用于重置ActionManager状态"""
|
||||
|
||||
# Command特定注册表
|
||||
self._command_registry: Dict[str, Type["BaseCommand"]] = {}
|
||||
self._command_registry: dict[str, type["BaseCommand"]] = {}
|
||||
"""Command类注册表 command名 -> command类"""
|
||||
self._command_patterns: Dict[Pattern, str] = {}
|
||||
self._command_patterns: dict[Pattern, str] = {}
|
||||
"""编译后的正则 -> command名"""
|
||||
|
||||
# 工具特定注册表
|
||||
self._tool_registry: Dict[str, Type["BaseTool"]] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: Dict[str, Type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
|
||||
self._tool_registry: dict[str, type["BaseTool"]] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: dict[str, type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
|
||||
|
||||
# EventHandler特定注册表
|
||||
self._event_handler_registry: Dict[str, Type["BaseEventHandler"]] = {}
|
||||
self._event_handler_registry: dict[str, type["BaseEventHandler"]] = {}
|
||||
"""event_handler名 -> event_handler类"""
|
||||
self._enabled_event_handlers: Dict[str, Type["BaseEventHandler"]] = {}
|
||||
self._enabled_event_handlers: dict[str, type["BaseEventHandler"]] = {}
|
||||
"""启用的事件处理器 event_handler名 -> event_handler类"""
|
||||
|
||||
self._chatter_registry: Dict[str, Type["BaseChatter"]] = {}
|
||||
self._chatter_registry: dict[str, type["BaseChatter"]] = {}
|
||||
"""chatter名 -> chatter类"""
|
||||
self._enabled_chatter_registry: Dict[str, Type["BaseChatter"]] = {}
|
||||
self._enabled_chatter_registry: dict[str, type["BaseChatter"]] = {}
|
||||
"""启用的chatter名 -> chatter类"""
|
||||
logger.info("组件注册中心初始化完成")
|
||||
|
||||
@@ -101,7 +100,7 @@ class ComponentRegistry:
|
||||
def register_component(
|
||||
self,
|
||||
component_info: ComponentInfo,
|
||||
component_class: Type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]],
|
||||
component_class: type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]],
|
||||
) -> bool:
|
||||
"""注册组件
|
||||
|
||||
@@ -174,7 +173,7 @@ class ComponentRegistry:
|
||||
)
|
||||
return True
|
||||
|
||||
def _register_action_component(self, action_info: "ActionInfo", action_class: Type["BaseAction"]) -> bool:
|
||||
def _register_action_component(self, action_info: "ActionInfo", action_class: type["BaseAction"]) -> bool:
|
||||
"""注册Action组件到Action特定注册表"""
|
||||
if not (action_name := action_info.name):
|
||||
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
|
||||
@@ -194,7 +193,7 @@ class ComponentRegistry:
|
||||
|
||||
return True
|
||||
|
||||
def _register_command_component(self, command_info: "CommandInfo", command_class: Type["BaseCommand"]) -> bool:
|
||||
def _register_command_component(self, command_info: "CommandInfo", command_class: type["BaseCommand"]) -> bool:
|
||||
"""注册Command组件到Command特定注册表"""
|
||||
if not (command_name := command_info.name):
|
||||
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
|
||||
@@ -221,7 +220,7 @@ class ComponentRegistry:
|
||||
return True
|
||||
|
||||
def _register_plus_command_component(
|
||||
self, plus_command_info: "PlusCommandInfo", plus_command_class: Type["PlusCommand"]
|
||||
self, plus_command_info: "PlusCommandInfo", plus_command_class: type["PlusCommand"]
|
||||
) -> bool:
|
||||
"""注册PlusCommand组件到特定注册表"""
|
||||
plus_command_name = plus_command_info.name
|
||||
@@ -235,7 +234,7 @@ class ComponentRegistry:
|
||||
|
||||
# 创建专门的PlusCommand注册表(如果还没有)
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry: Dict[str, Type["PlusCommand"]] = {}
|
||||
self._plus_command_registry: dict[str, type["PlusCommand"]] = {}
|
||||
|
||||
plus_command_class.plugin_name = plus_command_info.plugin_name
|
||||
# 设置插件配置
|
||||
@@ -245,7 +244,7 @@ class ComponentRegistry:
|
||||
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
|
||||
return True
|
||||
|
||||
def _register_tool_component(self, tool_info: "ToolInfo", tool_class: Type["BaseTool"]) -> bool:
|
||||
def _register_tool_component(self, tool_info: "ToolInfo", tool_class: type["BaseTool"]) -> bool:
|
||||
"""注册Tool组件到Tool特定注册表"""
|
||||
tool_name = tool_info.name
|
||||
|
||||
@@ -261,7 +260,7 @@ class ComponentRegistry:
|
||||
return True
|
||||
|
||||
def _register_event_handler_component(
|
||||
self, handler_info: "EventHandlerInfo", handler_class: Type["BaseEventHandler"]
|
||||
self, handler_info: "EventHandlerInfo", handler_class: type["BaseEventHandler"]
|
||||
) -> bool:
|
||||
if not (handler_name := handler_info.name):
|
||||
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
|
||||
@@ -287,7 +286,7 @@ class ComponentRegistry:
|
||||
handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
|
||||
)
|
||||
|
||||
def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: Type["BaseChatter"]) -> bool:
|
||||
def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: type["BaseChatter"]) -> bool:
|
||||
"""注册Chatter组件到Chatter特定注册表"""
|
||||
chatter_name = chatter_info.name
|
||||
|
||||
@@ -532,7 +531,7 @@ class ComponentRegistry:
|
||||
self,
|
||||
component_name: str,
|
||||
component_type: Optional["ComponentType"] = None,
|
||||
) -> Optional[Union[Type["BaseCommand"], Type["BaseAction"], Type["BaseEventHandler"], Type["BaseTool"]]]:
|
||||
) -> type["BaseCommand"] | type["BaseAction"] | type["BaseEventHandler"] | type["BaseTool"] | None:
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
@@ -574,18 +573,18 @@ class ComponentRegistry:
|
||||
# 4. 都没找到
|
||||
return None
|
||||
|
||||
def get_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]:
|
||||
def get_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]:
|
||||
"""获取指定类型的所有组件"""
|
||||
return self._components_by_type.get(component_type, {}).copy()
|
||||
|
||||
def get_enabled_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]:
|
||||
def get_enabled_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]:
|
||||
"""获取指定类型的所有启用组件"""
|
||||
components = self.get_components_by_type(component_type)
|
||||
return {name: info for name, info in components.items() if info.enabled}
|
||||
|
||||
# === Action特定查询方法 ===
|
||||
|
||||
def get_action_registry(self) -> Dict[str, Type["BaseAction"]]:
|
||||
def get_action_registry(self) -> dict[str, type["BaseAction"]]:
|
||||
"""获取Action注册表"""
|
||||
return self._action_registry.copy()
|
||||
|
||||
@@ -594,13 +593,13 @@ class ComponentRegistry:
|
||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
||||
return info if isinstance(info, ActionInfo) else None
|
||||
|
||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||
def get_default_actions(self) -> dict[str, ActionInfo]:
|
||||
"""获取默认动作集"""
|
||||
return self._default_actions.copy()
|
||||
|
||||
# === Command特定查询方法 ===
|
||||
|
||||
def get_command_registry(self) -> Dict[str, Type["BaseCommand"]]:
|
||||
def get_command_registry(self) -> dict[str, type["BaseCommand"]]:
|
||||
"""获取Command注册表"""
|
||||
return self._command_registry.copy()
|
||||
|
||||
@@ -609,11 +608,11 @@ class ComponentRegistry:
|
||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
||||
return info if isinstance(info, CommandInfo) else None
|
||||
|
||||
def get_command_patterns(self) -> Dict[Pattern, str]:
|
||||
def get_command_patterns(self) -> dict[Pattern, str]:
|
||||
"""获取Command模式注册表"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type["BaseCommand"], dict, "CommandInfo"]]:
|
||||
def find_command_by_text(self, text: str) -> tuple[type["BaseCommand"], dict, "CommandInfo"] | None:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
@@ -640,11 +639,11 @@ class ComponentRegistry:
|
||||
return None
|
||||
|
||||
# === Tool 特定查询方法 ===
|
||||
def get_tool_registry(self) -> Dict[str, Type["BaseTool"]]:
|
||||
def get_tool_registry(self) -> dict[str, type["BaseTool"]]:
|
||||
"""获取Tool注册表"""
|
||||
return self._tool_registry.copy()
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, Type["BaseTool"]]:
|
||||
def get_llm_available_tools(self) -> dict[str, type["BaseTool"]]:
|
||||
"""获取LLM可用的Tool列表"""
|
||||
return self._llm_available_tools.copy()
|
||||
|
||||
@@ -661,10 +660,10 @@ class ComponentRegistry:
|
||||
return info if isinstance(info, ToolInfo) else None
|
||||
|
||||
# === PlusCommand 特定查询方法 ===
|
||||
def get_plus_command_registry(self) -> Dict[str, Type["PlusCommand"]]:
|
||||
def get_plus_command_registry(self) -> dict[str, type["PlusCommand"]]:
|
||||
"""获取PlusCommand注册表"""
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
pass
|
||||
self._plus_command_registry: dict[str, type[PlusCommand]] = {}
|
||||
return self._plus_command_registry.copy()
|
||||
|
||||
def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]:
|
||||
@@ -681,7 +680,7 @@ class ComponentRegistry:
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
|
||||
def get_event_handler_registry(self) -> Dict[str, Type["BaseEventHandler"]]:
|
||||
def get_event_handler_registry(self) -> dict[str, type["BaseEventHandler"]]:
|
||||
"""获取事件处理器注册表"""
|
||||
return self._event_handler_registry.copy()
|
||||
|
||||
@@ -690,21 +689,21 @@ class ComponentRegistry:
|
||||
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
|
||||
return info if isinstance(info, EventHandlerInfo) else None
|
||||
|
||||
def get_enabled_event_handlers(self) -> Dict[str, Type["BaseEventHandler"]]:
|
||||
def get_enabled_event_handlers(self) -> dict[str, type["BaseEventHandler"]]:
|
||||
"""获取启用的事件处理器"""
|
||||
return self._enabled_event_handlers.copy()
|
||||
|
||||
# === Chatter 特定查询方法 ===
|
||||
def get_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]:
|
||||
def get_chatter_registry(self) -> dict[str, type["BaseChatter"]]:
|
||||
"""获取Chatter注册表"""
|
||||
if not hasattr(self, "_chatter_registry"):
|
||||
self._chatter_registry: Dict[str, Type[BaseChatter]] = {}
|
||||
self._chatter_registry: dict[str, type[BaseChatter]] = {}
|
||||
return self._chatter_registry.copy()
|
||||
|
||||
def get_enabled_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]:
|
||||
def get_enabled_chatter_registry(self) -> dict[str, type["BaseChatter"]]:
|
||||
"""获取启用的Chatter注册表"""
|
||||
if not hasattr(self, "_enabled_chatter_registry"):
|
||||
self._enabled_chatter_registry: Dict[str, Type[BaseChatter]] = {}
|
||||
self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {}
|
||||
return self._enabled_chatter_registry.copy()
|
||||
|
||||
def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]:
|
||||
@@ -718,7 +717,7 @@ class ComponentRegistry:
|
||||
"""获取插件信息"""
|
||||
return self._plugins.get(plugin_name)
|
||||
|
||||
def get_all_plugins(self) -> Dict[str, "PluginInfo"]:
|
||||
def get_all_plugins(self) -> dict[str, "PluginInfo"]:
|
||||
"""获取所有插件"""
|
||||
return self._plugins.copy()
|
||||
|
||||
@@ -726,7 +725,7 @@ class ComponentRegistry:
|
||||
# """获取所有启用的插件"""
|
||||
# return {name: info for name, info in self._plugins.items() if info.enabled}
|
||||
|
||||
def get_plugin_components(self, plugin_name: str) -> List["ComponentInfo"]:
|
||||
def get_plugin_components(self, plugin_name: str) -> list["ComponentInfo"]:
|
||||
"""获取插件的所有组件"""
|
||||
plugin_info = self.get_plugin_info(plugin_name)
|
||||
return plugin_info.components if plugin_info else []
|
||||
@@ -754,7 +753,7 @@ class ComponentRegistry:
|
||||
|
||||
config_path = Path("config") / "plugins" / plugin_name / "config.toml"
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
|
||||
return config_data
|
||||
@@ -763,7 +762,7 @@ class ComponentRegistry:
|
||||
|
||||
return {}
|
||||
|
||||
def get_registry_stats(self) -> Dict[str, Any]:
|
||||
def get_registry_stats(self) -> dict[str, Any]:
|
||||
"""获取注册中心统计信息"""
|
||||
action_components: int = 0
|
||||
command_components: int = 0
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
提供统一的事件注册、管理和触发接口
|
||||
"""
|
||||
|
||||
from typing import Dict, Type, List, Optional, Any, Union
|
||||
from threading import Lock
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseEventHandler
|
||||
@@ -37,17 +37,17 @@ class EventManager:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._events: Dict[str, BaseEvent] = {}
|
||||
self._event_handlers: Dict[str, Type[BaseEventHandler]] = {}
|
||||
self._pending_subscriptions: Dict[str, List[str]] = {} # 缓存失败的订阅
|
||||
self._events: dict[str, BaseEvent] = {}
|
||||
self._event_handlers: dict[str, type[BaseEventHandler]] = {}
|
||||
self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅
|
||||
self._initialized = True
|
||||
logger.info("EventManager 单例初始化完成")
|
||||
|
||||
def register_event(
|
||||
self,
|
||||
event_name: Union[EventType, str],
|
||||
allowed_subscribers: List[str] = None,
|
||||
allowed_triggers: List[str] = None,
|
||||
event_name: EventType | str,
|
||||
allowed_subscribers: list[str] = None,
|
||||
allowed_triggers: list[str] = None,
|
||||
) -> bool:
|
||||
"""注册一个新的事件
|
||||
|
||||
@@ -75,7 +75,7 @@ class EventManager:
|
||||
|
||||
return True
|
||||
|
||||
def get_event(self, event_name: Union[EventType, str]) -> Optional[BaseEvent]:
|
||||
def get_event(self, event_name: EventType | str) -> BaseEvent | None:
|
||||
"""获取指定事件实例
|
||||
|
||||
Args:
|
||||
@@ -86,7 +86,7 @@ class EventManager:
|
||||
"""
|
||||
return self._events.get(event_name)
|
||||
|
||||
def get_all_events(self) -> Dict[str, BaseEvent]:
|
||||
def get_all_events(self) -> dict[str, BaseEvent]:
|
||||
"""获取所有已注册的事件
|
||||
|
||||
Returns:
|
||||
@@ -94,7 +94,7 @@ class EventManager:
|
||||
"""
|
||||
return self._events.copy()
|
||||
|
||||
def get_enabled_events(self) -> Dict[str, BaseEvent]:
|
||||
def get_enabled_events(self) -> dict[str, BaseEvent]:
|
||||
"""获取所有已启用的事件
|
||||
|
||||
Returns:
|
||||
@@ -102,7 +102,7 @@ class EventManager:
|
||||
"""
|
||||
return {name: event for name, event in self._events.items() if event.enabled}
|
||||
|
||||
def get_disabled_events(self) -> Dict[str, BaseEvent]:
|
||||
def get_disabled_events(self) -> dict[str, BaseEvent]:
|
||||
"""获取所有已禁用的事件
|
||||
|
||||
Returns:
|
||||
@@ -110,7 +110,7 @@ class EventManager:
|
||||
"""
|
||||
return {name: event for name, event in self._events.items() if not event.enabled}
|
||||
|
||||
def enable_event(self, event_name: Union[EventType, str]) -> bool:
|
||||
def enable_event(self, event_name: EventType | str) -> bool:
|
||||
"""启用指定事件
|
||||
|
||||
Args:
|
||||
@@ -128,7 +128,7 @@ class EventManager:
|
||||
logger.info(f"事件 {event_name} 已启用")
|
||||
return True
|
||||
|
||||
def disable_event(self, event_name: Union[EventType, str]) -> bool:
|
||||
def disable_event(self, event_name: EventType | str) -> bool:
|
||||
"""禁用指定事件
|
||||
|
||||
Args:
|
||||
@@ -146,9 +146,7 @@ class EventManager:
|
||||
logger.info(f"事件 {event_name} 已禁用")
|
||||
return True
|
||||
|
||||
def register_event_handler(
|
||||
self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None
|
||||
) -> bool:
|
||||
def register_event_handler(self, handler_class: type[BaseEventHandler], plugin_config: dict | None = None) -> bool:
|
||||
"""注册事件处理器
|
||||
|
||||
Args:
|
||||
@@ -190,7 +188,7 @@ class EventManager:
|
||||
logger.info(f"事件处理器 {handler_name} 注册成功")
|
||||
return True
|
||||
|
||||
def get_event_handler(self, handler_name: str) -> Optional[Type[BaseEventHandler]]:
|
||||
def get_event_handler(self, handler_name: str) -> type[BaseEventHandler] | None:
|
||||
"""获取指定事件处理器实例
|
||||
|
||||
Args:
|
||||
@@ -209,7 +207,7 @@ class EventManager:
|
||||
"""
|
||||
return self._event_handlers.copy()
|
||||
|
||||
def subscribe_handler_to_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
|
||||
def subscribe_handler_to_event(self, handler_name: str, event_name: EventType | str) -> bool:
|
||||
"""订阅事件处理器到指定事件
|
||||
|
||||
Args:
|
||||
@@ -246,7 +244,7 @@ class EventManager:
|
||||
logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成")
|
||||
return True
|
||||
|
||||
def unsubscribe_handler_from_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
|
||||
def unsubscribe_handler_from_event(self, handler_name: str, event_name: EventType | str) -> bool:
|
||||
"""从指定事件取消订阅事件处理器
|
||||
|
||||
Args:
|
||||
@@ -276,7 +274,7 @@ class EventManager:
|
||||
|
||||
return removed
|
||||
|
||||
def get_event_subscribers(self, event_name: Union[EventType, str]) -> Dict[str, BaseEventHandler]:
|
||||
def get_event_subscribers(self, event_name: EventType | str) -> dict[str, BaseEventHandler]:
|
||||
"""获取订阅指定事件的所有事件处理器
|
||||
|
||||
Args:
|
||||
@@ -292,8 +290,8 @@ class EventManager:
|
||||
return {handler.handler_name: handler for handler in event.subscribers}
|
||||
|
||||
async def trigger_event(
|
||||
self, event_name: Union[EventType, str], permission_group: Optional[str] = "", **kwargs
|
||||
) -> Optional[HandlerResultsCollection]:
|
||||
self, event_name: EventType | str, permission_group: str | None = "", **kwargs
|
||||
) -> HandlerResultsCollection | None:
|
||||
"""触发指定事件
|
||||
|
||||
Args:
|
||||
@@ -345,7 +343,7 @@ class EventManager:
|
||||
self._event_handlers.clear()
|
||||
logger.info("所有事件和处理器已清除")
|
||||
|
||||
def get_event_summary(self) -> Dict[str, Any]:
|
||||
def get_event_summary(self) -> dict[str, Any]:
|
||||
"""获取事件系统摘要
|
||||
|
||||
Returns:
|
||||
@@ -364,7 +362,7 @@ class EventManager:
|
||||
"pending_subscriptions": len(self._pending_subscriptions),
|
||||
}
|
||||
|
||||
def _process_pending_subscriptions(self, event_name: Union[EventType, str]) -> None:
|
||||
def _process_pending_subscriptions(self, event_name: EventType | str) -> None:
|
||||
"""处理指定事件的缓存订阅
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import List, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("global_announcement_manager")
|
||||
@@ -8,13 +6,13 @@ logger = get_logger("global_announcement_manager")
|
||||
class GlobalAnnouncementManager:
|
||||
def __init__(self) -> None:
|
||||
# 用户禁用的动作,chat_id -> [action_name]
|
||||
self._user_disabled_actions: Dict[str, List[str]] = {}
|
||||
self._user_disabled_actions: dict[str, list[str]] = {}
|
||||
# 用户禁用的命令,chat_id -> [command_name]
|
||||
self._user_disabled_commands: Dict[str, List[str]] = {}
|
||||
self._user_disabled_commands: dict[str, list[str]] = {}
|
||||
# 用户禁用的事件处理器,chat_id -> [handler_name]
|
||||
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
|
||||
self._user_disabled_event_handlers: dict[str, list[str]] = {}
|
||||
# 用户禁用的工具,chat_id -> [tool_name]
|
||||
self._user_disabled_tools: Dict[str, List[str]] = {}
|
||||
self._user_disabled_tools: dict[str, list[str]] = {}
|
||||
|
||||
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""禁用特定聊天的某个动作"""
|
||||
@@ -100,19 +98,19 @@ class GlobalAnnouncementManager:
|
||||
return False
|
||||
return False
|
||||
|
||||
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
|
||||
def get_disabled_chat_actions(self, chat_id: str) -> list[str]:
|
||||
"""获取特定聊天禁用的所有动作"""
|
||||
return self._user_disabled_actions.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
|
||||
def get_disabled_chat_commands(self, chat_id: str) -> list[str]:
|
||||
"""获取特定聊天禁用的所有命令"""
|
||||
return self._user_disabled_commands.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> list[str]:
|
||||
"""获取特定聊天禁用的所有事件处理器"""
|
||||
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
|
||||
def get_disabled_chat_tools(self, chat_id: str) -> list[str]:
|
||||
"""获取特定聊天禁用的所有工具"""
|
||||
return self._user_disabled_tools.get(chat_id, []).copy()
|
||||
|
||||
|
||||
@@ -4,16 +4,16 @@
|
||||
这个模块提供了权限系统的核心实现,包括权限检查、权限节点管理、用户权限管理等功能。
|
||||
"""
|
||||
|
||||
from typing import List, Set, Tuple
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, delete
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from src.common.database.sqlalchemy_models import PermissionNodes, UserPermissions, get_engine
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import get_engine, PermissionNodes, UserPermissions
|
||||
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -24,7 +24,7 @@ class PermissionManager(IPermissionManager):
|
||||
def __init__(self):
|
||||
self.engine = None
|
||||
self.SessionLocal = None
|
||||
self._master_users: Set[Tuple[str, str]] = set()
|
||||
self._master_users: set[tuple[str, str]] = set()
|
||||
self._load_master_users()
|
||||
|
||||
async def initialize(self):
|
||||
@@ -276,7 +276,7 @@ class PermissionManager(IPermissionManager):
|
||||
logger.error(f"撤销权限时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
async def get_user_permissions(self, user: UserInfo) -> List[str]:
|
||||
async def get_user_permissions(self, user: UserInfo) -> list[str]:
|
||||
"""
|
||||
获取用户拥有的所有权限节点
|
||||
|
||||
@@ -328,7 +328,7 @@ class PermissionManager(IPermissionManager):
|
||||
logger.error(f"获取用户权限时发生未知错误: {e}")
|
||||
return []
|
||||
|
||||
async def get_all_permission_nodes(self) -> List[PermissionNode]:
|
||||
async def get_all_permission_nodes(self) -> list[PermissionNode]:
|
||||
"""
|
||||
获取所有已注册的权限节点
|
||||
|
||||
@@ -356,7 +356,7 @@ class PermissionManager(IPermissionManager):
|
||||
logger.error(f"获取所有权限节点时发生未知错误: {e}")
|
||||
return []
|
||||
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]:
|
||||
"""
|
||||
获取指定插件的所有权限节点
|
||||
|
||||
@@ -431,7 +431,7 @@ class PermissionManager(IPermissionManager):
|
||||
logger.error(f"删除插件权限时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
async def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]:
|
||||
async def get_users_with_permission(self, permission_node: str) -> list[tuple[str, str]]:
|
||||
"""
|
||||
获取拥有指定权限的所有用户
|
||||
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import traceback
|
||||
import importlib
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Type, Any
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.plugin_base import PluginBase
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.plugin_system.base.plugin_base import PluginBase
|
||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
||||
from .component_registry import component_registry
|
||||
|
||||
from .component_registry import component_registry
|
||||
|
||||
logger = get_logger("plugin_manager")
|
||||
|
||||
@@ -26,12 +24,12 @@ class PluginManager:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.plugin_directories: List[str] = [] # 插件根目录列表
|
||||
self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
|
||||
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
|
||||
self.plugin_directories: list[str] = [] # 插件根目录列表
|
||||
self.plugin_classes: dict[str, type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
|
||||
self.plugin_paths: dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
|
||||
|
||||
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
|
||||
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
|
||||
self.loaded_plugins: dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
|
||||
self.failed_plugins: dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
|
||||
|
||||
# 确保插件目录存在
|
||||
self._ensure_plugin_directories()
|
||||
@@ -54,7 +52,7 @@ class PluginManager:
|
||||
|
||||
# === 插件加载管理 ===
|
||||
|
||||
def load_all_plugins(self) -> Tuple[int, int]:
|
||||
def load_all_plugins(self) -> tuple[int, int]:
|
||||
"""加载所有插件
|
||||
|
||||
Returns:
|
||||
@@ -87,7 +85,7 @@ class PluginManager:
|
||||
|
||||
return total_registered, total_failed_registration
|
||||
|
||||
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
|
||||
def load_registered_plugin_classes(self, plugin_name: str) -> tuple[bool, int]:
|
||||
# sourcery skip: extract-duplicate-method, extract-method
|
||||
"""
|
||||
加载已经注册的插件类
|
||||
@@ -142,7 +140,7 @@ class PluginManager:
|
||||
|
||||
except FileNotFoundError as e:
|
||||
# manifest文件缺失
|
||||
error_msg = f"缺少manifest文件: {str(e)}"
|
||||
error_msg = f"缺少manifest文件: {e!s}"
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||
return False, 1
|
||||
@@ -150,14 +148,14 @@ class PluginManager:
|
||||
except ValueError as e:
|
||||
# manifest文件格式错误或验证失败
|
||||
traceback.print_exc()
|
||||
error_msg = f"manifest验证失败: {str(e)}"
|
||||
error_msg = f"manifest验证失败: {e!s}"
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||
return False, 1
|
||||
|
||||
except Exception as e:
|
||||
# 其他错误
|
||||
error_msg = f"未知错误: {str(e)}"
|
||||
error_msg = f"未知错误: {e!s}"
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||
logger.debug("详细错误信息: ", exc_info=True)
|
||||
@@ -192,7 +190,7 @@ class PluginManager:
|
||||
logger.debug(f"插件 {plugin_name} 重载成功")
|
||||
return True
|
||||
|
||||
def rescan_plugin_directory(self) -> Tuple[int, int]:
|
||||
def rescan_plugin_directory(self) -> tuple[int, int]:
|
||||
"""
|
||||
重新扫描插件根目录
|
||||
"""
|
||||
@@ -220,7 +218,7 @@ class PluginManager:
|
||||
return self.loaded_plugins.get(plugin_name)
|
||||
|
||||
# === 查询方法 ===
|
||||
def list_loaded_plugins(self) -> List[str]:
|
||||
def list_loaded_plugins(self) -> list[str]:
|
||||
"""
|
||||
列出所有当前加载的插件。
|
||||
|
||||
@@ -229,7 +227,7 @@ class PluginManager:
|
||||
"""
|
||||
return list(self.loaded_plugins.keys())
|
||||
|
||||
def list_registered_plugins(self) -> List[str]:
|
||||
def list_registered_plugins(self) -> list[str]:
|
||||
"""
|
||||
列出所有已注册的插件类。
|
||||
|
||||
@@ -238,7 +236,7 @@ class PluginManager:
|
||||
"""
|
||||
return list(self.plugin_classes.keys())
|
||||
|
||||
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
|
||||
def get_plugin_path(self, plugin_name: str) -> str | None:
|
||||
"""
|
||||
获取指定插件的路径。
|
||||
|
||||
@@ -329,7 +327,7 @@ class PluginManager:
|
||||
# == 兼容性检查 ==
|
||||
|
||||
@staticmethod
|
||||
def _check_plugin_version_compatibility(plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
def _check_plugin_version_compatibility(plugin_name: str, manifest_data: dict[str, Any]) -> tuple[bool, str]:
|
||||
"""检查插件版本兼容性
|
||||
|
||||
Args:
|
||||
@@ -569,7 +567,7 @@ class PluginManager:
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}", exc_info=True)
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name} - {e!s}", exc_info=True)
|
||||
return False
|
||||
|
||||
def reload_plugin(self, plugin_name: str) -> bool:
|
||||
@@ -606,7 +604,7 @@ class PluginManager:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}", exc_info=True)
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - {e!s}", exc_info=True)
|
||||
return False
|
||||
|
||||
def force_reload_plugin(self, plugin_name: str) -> bool:
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
import inspect
|
||||
import time
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.cache_manager import tool_cache
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
import inspect
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
@@ -56,14 +57,14 @@ class ToolExecutor:
|
||||
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
|
||||
|
||||
# 二步工具调用状态管理
|
||||
self._pending_step_two_tools: Dict[str, Dict[str, Any]] = {}
|
||||
self._pending_step_two_tools: dict[str, dict[str, Any]] = {}
|
||||
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
|
||||
|
||||
logger.info(f"{self.log_prefix}工具执行器初始化完成")
|
||||
|
||||
async def execute_from_chat_message(
|
||||
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
||||
) -> Tuple[List[Dict[str, Any]], List[str], str]:
|
||||
) -> tuple[list[dict[str, Any]], list[str], str]:
|
||||
"""从聊天消息执行工具
|
||||
|
||||
Args:
|
||||
@@ -113,7 +114,7 @@ class ToolExecutor:
|
||||
else:
|
||||
return tool_results, [], ""
|
||||
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
def _get_tool_definitions(self) -> list[dict[str, Any]]:
|
||||
all_tools = get_llm_available_tool_definitions()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
|
||||
@@ -129,7 +130,7 @@ class ToolExecutor:
|
||||
|
||||
return tool_definitions
|
||||
|
||||
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
|
||||
"""执行工具调用
|
||||
|
||||
Args:
|
||||
@@ -138,7 +139,7 @@ class ToolExecutor:
|
||||
Returns:
|
||||
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
|
||||
"""
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
tool_results: list[dict[str, Any]] = []
|
||||
used_tools = []
|
||||
|
||||
if not tool_calls:
|
||||
@@ -192,7 +193,7 @@ class ToolExecutor:
|
||||
error_info = {
|
||||
"type": "tool_error",
|
||||
"id": f"tool_error_{time.time()}",
|
||||
"content": f"工具{tool_name}执行失败: {str(e)}",
|
||||
"content": f"工具{tool_name}执行失败: {e!s}",
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
@@ -201,8 +202,8 @@ class ToolExecutor:
|
||||
return tool_results, used_tools
|
||||
|
||||
async def execute_tool_call(
|
||||
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""执行单个工具调用,并处理缓存"""
|
||||
|
||||
function_args = tool_call.args or {}
|
||||
@@ -256,8 +257,8 @@ class ToolExecutor:
|
||||
return result
|
||||
|
||||
async def _original_execute_tool_call(
|
||||
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""执行单个工具调用的原始逻辑"""
|
||||
try:
|
||||
function_name = tool_call.func_name
|
||||
@@ -323,10 +324,10 @@ class ToolExecutor:
|
||||
logger.warning(f"{self.log_prefix}工具 {function_name} 返回空结果")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
logger.error(f"执行工具调用时发生错误: {e!s}")
|
||||
raise e
|
||||
|
||||
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
|
||||
async def execute_specific_tool_simple(self, tool_name: str, tool_args: dict) -> dict | None:
|
||||
"""直接执行指定工具
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
本模块包含一个从Python包的“安装名”到其“导入名”的映射。
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("dependency_config")
|
||||
@@ -66,7 +65,7 @@ class DependencyConfig:
|
||||
|
||||
|
||||
# 全局配置实例
|
||||
_global_dependency_config: Optional[DependencyConfig] = None
|
||||
_global_dependency_config: DependencyConfig | None = None
|
||||
|
||||
|
||||
def get_dependency_config() -> DependencyConfig:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import importlib
|
||||
import importlib.util
|
||||
from typing import List, Tuple, Optional, Any
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from packaging import version
|
||||
from packaging.requirements import Requirement
|
||||
|
||||
@@ -19,7 +20,7 @@ class DependencyManager:
|
||||
负责检查和自动安装插件的Python包依赖
|
||||
"""
|
||||
|
||||
def __init__(self, auto_install: bool = True, use_mirror: bool = False, mirror_url: Optional[str] = None):
|
||||
def __init__(self, auto_install: bool = True, use_mirror: bool = False, mirror_url: str | None = None):
|
||||
"""初始化依赖管理器
|
||||
|
||||
Args:
|
||||
@@ -46,7 +47,7 @@ class DependencyManager:
|
||||
self.mirror_url = mirror_url or ""
|
||||
self.install_timeout = 300
|
||||
|
||||
def check_dependencies(self, dependencies: Any, plugin_name: str = "") -> Tuple[bool, List[str], List[str]]:
|
||||
def check_dependencies(self, dependencies: Any, plugin_name: str = "") -> tuple[bool, list[str], list[str]]:
|
||||
"""检查依赖包是否满足要求
|
||||
|
||||
Args:
|
||||
@@ -69,7 +70,7 @@ class DependencyManager:
|
||||
logger.info(f"{log_prefix}缺少依赖包: {dep.get_pip_requirement()}")
|
||||
missing_packages.append(dep.get_pip_requirement())
|
||||
except Exception as e:
|
||||
error_msg = f"检查依赖 {dep.package_name} 时发生错误: {str(e)}"
|
||||
error_msg = f"检查依赖 {dep.package_name} 时发生错误: {e!s}"
|
||||
error_messages.append(error_msg)
|
||||
logger.error(f"{log_prefix}{error_msg}")
|
||||
|
||||
@@ -84,7 +85,7 @@ class DependencyManager:
|
||||
|
||||
return all_satisfied, missing_packages, error_messages
|
||||
|
||||
def install_dependencies(self, packages: List[str], plugin_name: str = "") -> Tuple[bool, List[str]]:
|
||||
def install_dependencies(self, packages: list[str], plugin_name: str = "") -> tuple[bool, list[str]]:
|
||||
"""自动安装缺失的依赖包
|
||||
|
||||
Args:
|
||||
@@ -115,7 +116,7 @@ class DependencyManager:
|
||||
logger.error(f"{log_prefix}❌ 安装失败: {package}")
|
||||
except Exception as e:
|
||||
failed_packages.append(package)
|
||||
logger.error(f"{log_prefix}❌ 安装 {package} 时发生异常: {str(e)}")
|
||||
logger.error(f"{log_prefix}❌ 安装 {package} 时发生异常: {e!s}")
|
||||
|
||||
success = len(failed_packages) == 0
|
||||
if success:
|
||||
@@ -125,7 +126,7 @@ class DependencyManager:
|
||||
|
||||
return success, failed_packages
|
||||
|
||||
def check_and_install_dependencies(self, dependencies: Any, plugin_name: str = "") -> Tuple[bool, List[str]]:
|
||||
def check_and_install_dependencies(self, dependencies: Any, plugin_name: str = "") -> tuple[bool, list[str]]:
|
||||
"""检查并自动安装依赖(组合操作)
|
||||
|
||||
Args:
|
||||
@@ -163,7 +164,7 @@ class DependencyManager:
|
||||
return False, all_errors
|
||||
|
||||
@staticmethod
|
||||
def _normalize_dependencies(dependencies: Any) -> List[PythonDependency]:
|
||||
def _normalize_dependencies(dependencies: Any) -> list[PythonDependency]:
|
||||
"""将依赖列表标准化为PythonDependency对象"""
|
||||
normalized = []
|
||||
|
||||
@@ -277,7 +278,7 @@ class DependencyManager:
|
||||
|
||||
|
||||
# 全局依赖管理器实例
|
||||
_global_dependency_manager: Optional[DependencyManager] = None
|
||||
_global_dependency_manager: DependencyManager | None = None
|
||||
|
||||
|
||||
def get_dependency_manager() -> DependencyManager:
|
||||
@@ -288,7 +289,7 @@ def get_dependency_manager() -> DependencyManager:
|
||||
return _global_dependency_manager
|
||||
|
||||
|
||||
def configure_dependency_manager(auto_install: bool = True, use_mirror: bool = False, mirror_url: Optional[str] = None):
|
||||
def configure_dependency_manager(auto_install: bool = True, use_mirror: bool = False, mirror_url: str | None = None):
|
||||
"""配置全局依赖管理器"""
|
||||
global _global_dependency_manager
|
||||
_global_dependency_manager = DependencyManager(
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any, Tuple
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import MMC_VERSION
|
||||
|
||||
@@ -70,7 +71,7 @@ class VersionComparator:
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def parse_version(version: str) -> Tuple[int, int, int]:
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
"""解析版本号为元组
|
||||
|
||||
Args:
|
||||
@@ -109,7 +110,7 @@ class VersionComparator:
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def check_forward_compatibility(current_version: str, max_version: str) -> Tuple[bool, str]:
|
||||
def check_forward_compatibility(current_version: str, max_version: str) -> tuple[bool, str]:
|
||||
"""检查向前兼容性(仅使用兼容性映射表)
|
||||
|
||||
Args:
|
||||
@@ -131,7 +132,7 @@ class VersionComparator:
|
||||
return False, ""
|
||||
|
||||
@staticmethod
|
||||
def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]:
|
||||
def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> tuple[bool, str]:
|
||||
"""检查版本是否在指定范围内,支持兼容性检查
|
||||
|
||||
Args:
|
||||
@@ -195,7 +196,7 @@ class VersionComparator:
|
||||
logger.info(f"添加兼容性映射:{base_normalized} -> {compatible_versions}")
|
||||
|
||||
@staticmethod
|
||||
def get_compatibility_info() -> Dict[str, list]:
|
||||
def get_compatibility_info() -> dict[str, list]:
|
||||
"""获取当前的兼容性映射表
|
||||
|
||||
Returns:
|
||||
@@ -232,7 +233,7 @@ class ManifestValidator:
|
||||
self.validation_errors = []
|
||||
self.validation_warnings = []
|
||||
|
||||
def validate_manifest(self, manifest_data: Dict[str, Any]) -> bool:
|
||||
def validate_manifest(self, manifest_data: dict[str, Any]) -> bool:
|
||||
"""验证manifest数据
|
||||
|
||||
Args:
|
||||
@@ -266,7 +267,7 @@ class ManifestValidator:
|
||||
if "name" not in author or not author["name"]:
|
||||
self.validation_errors.append("作者信息缺少name字段或为空")
|
||||
# url字段是可选的
|
||||
if "url" in author and author["url"]:
|
||||
if author.get("url"):
|
||||
url = author["url"]
|
||||
if not (url.startswith("http://") or url.startswith("https://")):
|
||||
self.validation_warnings.append("作者URL建议使用完整的URL格式")
|
||||
@@ -305,7 +306,7 @@ class ManifestValidator:
|
||||
|
||||
# 检查URL格式(可选字段)
|
||||
for url_field in ["homepage_url", "repository_url"]:
|
||||
if url_field in manifest_data and manifest_data[url_field]:
|
||||
if manifest_data.get(url_field):
|
||||
url: str = manifest_data[url_field]
|
||||
if not (url.startswith("http://") or url.startswith("https://")):
|
||||
self.validation_warnings.append(f"{url_field}建议使用完整的URL格式")
|
||||
|
||||
@@ -4,19 +4,19 @@
|
||||
提供方便的权限检查装饰器,用于插件命令和其他需要权限验证的地方。
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional
|
||||
from inspect import iscoroutinefunction
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.plugin_system.apis.logging_api import get_logger
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.apis.send_api import text_to_stream
|
||||
from src.plugin_system.apis.logging_api import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def require_permission(permission_node: str, deny_message: Optional[str] = None):
|
||||
def require_permission(permission_node: str, deny_message: str | None = None):
|
||||
"""
|
||||
权限检查装饰器
|
||||
|
||||
@@ -90,7 +90,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
|
||||
return decorator
|
||||
|
||||
|
||||
def require_master(deny_message: Optional[str] = None):
|
||||
def require_master(deny_message: str | None = None):
|
||||
"""
|
||||
Master权限检查装饰器
|
||||
|
||||
@@ -186,9 +186,7 @@ class PermissionChecker:
|
||||
return permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id)
|
||||
|
||||
@staticmethod
|
||||
async def ensure_permission(
|
||||
chat_stream: ChatStream, permission_node: str, deny_message: Optional[str] = None
|
||||
) -> bool:
|
||||
async def ensure_permission(chat_stream: ChatStream, permission_node: str, deny_message: str | None = None) -> bool:
|
||||
"""
|
||||
确保用户拥有指定权限,如果没有权限会发送消息并返回False
|
||||
|
||||
@@ -209,7 +207,7 @@ class PermissionChecker:
|
||||
return has_permission
|
||||
|
||||
@staticmethod
|
||||
async def ensure_master(chat_stream: ChatStream, deny_message: Optional[str] = None) -> bool:
|
||||
async def ensure_master(chat_stream: ChatStream, deny_message: str | None = None) -> bool:
|
||||
"""
|
||||
确保用户为Master用户,如果不是会发送消息并返回False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user