re-style: 格式化代码
This commit is contained in:
@@ -14,14 +14,15 @@ from src.plugin_system.apis import (
|
||||
generator_api,
|
||||
llm_api,
|
||||
message_api,
|
||||
permission_api,
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
schedule_api,
|
||||
send_api,
|
||||
tool_api,
|
||||
permission_api,
|
||||
schedule_api,
|
||||
)
|
||||
from src.plugin_system.apis.chat_api import ChatManager as context_api
|
||||
|
||||
from .logging_api import get_logger
|
||||
from .plugin_register_api import register_plugin
|
||||
|
||||
@@ -30,18 +31,18 @@ __all__ = [
|
||||
"chat_api",
|
||||
"component_manage_api",
|
||||
"config_api",
|
||||
"context_api",
|
||||
"database_api",
|
||||
"emoji_api",
|
||||
"generator_api",
|
||||
"get_logger",
|
||||
"llm_api",
|
||||
"message_api",
|
||||
"permission_api",
|
||||
"person_api",
|
||||
"plugin_manage_api",
|
||||
"send_api",
|
||||
"get_logger",
|
||||
"register_plugin",
|
||||
"tool_api",
|
||||
"permission_api",
|
||||
"context_api",
|
||||
"schedule_api",
|
||||
"send_api",
|
||||
"tool_api",
|
||||
]
|
||||
|
||||
@@ -12,11 +12,11 @@
|
||||
streams = chat.get_all_group_streams()
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("chat_api")
|
||||
|
||||
@@ -31,7 +31,7 @@ class ChatManager:
|
||||
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
|
||||
|
||||
@staticmethod
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有聊天流
|
||||
|
||||
@@ -57,7 +57,7 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有群聊聊天流
|
||||
|
||||
@@ -80,7 +80,7 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有私聊聊天流
|
||||
|
||||
@@ -107,8 +107,8 @@ class ChatManager:
|
||||
|
||||
@staticmethod
|
||||
def get_group_stream_by_group_id(
|
||||
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
|
||||
group_id: str, platform: str | None | SpecialTypes = "qq"
|
||||
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast
|
||||
"""根据群ID获取聊天流
|
||||
|
||||
Args:
|
||||
@@ -144,8 +144,8 @@ class ChatManager:
|
||||
|
||||
@staticmethod
|
||||
def get_private_stream_by_user_id(
|
||||
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
|
||||
user_id: str, platform: str | None | SpecialTypes = "qq"
|
||||
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast
|
||||
"""根据用户ID获取私聊流
|
||||
|
||||
Args:
|
||||
@@ -203,7 +203,7 @@ class ChatManager:
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]:
|
||||
"""获取聊天流详细信息
|
||||
|
||||
Args:
|
||||
@@ -222,7 +222,7 @@ class ChatManager:
|
||||
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||
|
||||
try:
|
||||
info: Dict[str, Any] = {
|
||||
info: dict[str, Any] = {
|
||||
"stream_id": chat_stream.stream_id,
|
||||
"platform": chat_stream.platform,
|
||||
"type": ChatManager.get_stream_type(chat_stream),
|
||||
@@ -250,7 +250,7 @@ class ChatManager:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
def get_streams_summary() -> dict[str, int]:
|
||||
"""获取聊天流统计摘要
|
||||
|
||||
Returns:
|
||||
@@ -285,27 +285,27 @@ class ChatManager:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
"""获取所有聊天流的便捷函数"""
|
||||
return ChatManager.get_all_streams(platform)
|
||||
|
||||
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
"""获取群聊聊天流的便捷函数"""
|
||||
return ChatManager.get_group_streams(platform)
|
||||
|
||||
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
"""获取私聊聊天流的便捷函数"""
|
||||
return ChatManager.get_private_streams(platform)
|
||||
|
||||
|
||||
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
|
||||
def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None:
|
||||
"""根据群ID获取聊天流的便捷函数"""
|
||||
return ChatManager.get_group_stream_by_group_id(group_id, platform)
|
||||
|
||||
|
||||
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
|
||||
def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None:
|
||||
"""根据用户ID获取私聊流的便捷函数"""
|
||||
return ChatManager.get_private_stream_by_user_id(user_id, platform)
|
||||
|
||||
@@ -315,11 +315,11 @@ def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
return ChatManager.get_stream_type(chat_stream)
|
||||
|
||||
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]:
|
||||
"""获取聊天流信息的便捷函数"""
|
||||
return ChatManager.get_stream_info(chat_stream)
|
||||
|
||||
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
def get_streams_summary() -> dict[str, int]:
|
||||
"""获取聊天流统计摘要的便捷函数"""
|
||||
return ChatManager.get_streams_summary()
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
from typing import Optional, Union, Dict
|
||||
from src.plugin_system.base.component_types import (
|
||||
CommandInfo,
|
||||
ActionInfo,
|
||||
CommandInfo,
|
||||
ComponentType,
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
|
||||
# === 插件信息查询 ===
|
||||
def get_all_plugin_info() -> Dict[str, PluginInfo]:
|
||||
def get_all_plugin_info() -> dict[str, PluginInfo]:
|
||||
"""
|
||||
获取所有插件的信息。
|
||||
|
||||
@@ -22,7 +21,7 @@ def get_all_plugin_info() -> Dict[str, PluginInfo]:
|
||||
return component_registry.get_all_plugins()
|
||||
|
||||
|
||||
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
|
||||
def get_plugin_info(plugin_name: str) -> PluginInfo | None:
|
||||
"""
|
||||
获取指定插件的信息。
|
||||
|
||||
@@ -40,7 +39,7 @@ def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
component_name: str, component_type: ComponentType
|
||||
) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
) -> CommandInfo | ActionInfo | EventHandlerInfo | None:
|
||||
"""
|
||||
获取指定组件的信息。
|
||||
|
||||
@@ -57,7 +56,7 @@ def get_component_info(
|
||||
|
||||
def get_components_info_by_type(
|
||||
component_type: ComponentType,
|
||||
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
) -> dict[str, CommandInfo | ActionInfo | EventHandlerInfo]:
|
||||
"""
|
||||
获取指定类型的所有组件信息。
|
||||
|
||||
@@ -74,7 +73,7 @@ def get_components_info_by_type(
|
||||
|
||||
def get_enabled_components_info_by_type(
|
||||
component_type: ComponentType,
|
||||
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
) -> dict[str, CommandInfo | ActionInfo | EventHandlerInfo]:
|
||||
"""
|
||||
获取指定类型的所有启用的组件信息。
|
||||
|
||||
@@ -90,7 +89,7 @@ def get_enabled_components_info_by_type(
|
||||
|
||||
|
||||
# === Action 查询方法 ===
|
||||
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
|
||||
def get_registered_action_info(action_name: str) -> ActionInfo | None:
|
||||
"""
|
||||
获取指定 Action 的注册信息。
|
||||
|
||||
@@ -105,7 +104,7 @@ def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
|
||||
return component_registry.get_registered_action_info(action_name)
|
||||
|
||||
|
||||
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
|
||||
def get_registered_command_info(command_name: str) -> CommandInfo | None:
|
||||
"""
|
||||
获取指定 Command 的注册信息。
|
||||
|
||||
@@ -120,7 +119,7 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
|
||||
return component_registry.get_registered_command_info(command_name)
|
||||
|
||||
|
||||
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
|
||||
def get_registered_tool_info(tool_name: str) -> ToolInfo | None:
|
||||
"""
|
||||
获取指定 Tool 的注册信息。
|
||||
|
||||
@@ -138,7 +137,7 @@ def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
|
||||
# === EventHandler 特定查询方法 ===
|
||||
def get_registered_event_handler_info(
|
||||
event_handler_name: str,
|
||||
) -> Optional[EventHandlerInfo]:
|
||||
) -> EventHandlerInfo | None:
|
||||
"""
|
||||
获取指定 EventHandler 的注册信息。
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -3,20 +3,20 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, Any, Optional, List
|
||||
from typing import Any
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
build_readable_messages_with_id,
|
||||
)
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
|
||||
logger = get_logger("cross_context_api")
|
||||
|
||||
|
||||
def get_context_groups(chat_id: str) -> Optional[List[List[str]]]:
|
||||
def get_context_groups(chat_id: str) -> list[list[str]] | None:
|
||||
"""
|
||||
获取当前聊天所在的共享组的其他聊天ID
|
||||
"""
|
||||
@@ -41,7 +41,7 @@ def get_context_groups(chat_id: str) -> Optional[List[List[str]]]:
|
||||
return None
|
||||
|
||||
|
||||
async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: List[List[str]]) -> str:
|
||||
async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: list[list[str]]) -> str:
|
||||
"""
|
||||
构建跨群聊/私聊上下文 (Normal模式)
|
||||
"""
|
||||
@@ -74,8 +74,8 @@ async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos:
|
||||
|
||||
async def build_cross_context_s4u(
|
||||
chat_stream: ChatStream,
|
||||
other_chat_infos: List[List[str]],
|
||||
target_user_info: Optional[Dict[str, Any]],
|
||||
other_chat_infos: list[list[str]],
|
||||
target_user_info: dict[str, Any] | None,
|
||||
) -> str:
|
||||
"""
|
||||
构建跨群聊/私聊上下文 (S4U模式)
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
注意:此模块现在使用SQLAlchemy实现,提供更好的连接管理和错误处理
|
||||
"""
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get, store_action_info, MODEL_MAPPING
|
||||
from src.common.database.sqlalchemy_database_api import MODEL_MAPPING, db_get, db_query, db_save, store_action_info
|
||||
|
||||
# 保持向后兼容性
|
||||
__all__ = ["db_query", "db_save", "db_get", "store_action_info", "MODEL_MAPPING"]
|
||||
__all__ = ["MODEL_MAPPING", "db_get", "db_query", "db_save", "store_action_info"]
|
||||
|
||||
@@ -10,10 +10,9 @@
|
||||
|
||||
import random
|
||||
|
||||
from typing import Optional, Tuple, List
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("emoji_api")
|
||||
|
||||
@@ -23,7 +22,7 @@ logger = get_logger("emoji_api")
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
|
||||
async def get_by_description(description: str) -> tuple[str, str, str] | None:
|
||||
"""根据描述选择表情包
|
||||
|
||||
Args:
|
||||
@@ -65,7 +64,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
|
||||
return None
|
||||
|
||||
|
||||
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||
async def get_random(count: int | None = 1) -> list[tuple[str, str, str]]:
|
||||
"""随机获取指定数量的表情包
|
||||
|
||||
Args:
|
||||
@@ -137,7 +136,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||
return []
|
||||
|
||||
|
||||
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
async def get_by_emotion(emotion: str) -> tuple[str, str, str] | None:
|
||||
"""根据情感标签获取表情包
|
||||
|
||||
Args:
|
||||
@@ -227,7 +226,7 @@ def get_info():
|
||||
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
||||
|
||||
|
||||
def get_emotions() -> List[str]:
|
||||
def get_emotions() -> list[str]:
|
||||
"""获取所有可用的情感标签
|
||||
|
||||
Returns:
|
||||
@@ -247,7 +246,7 @@ def get_emotions() -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
def get_descriptions() -> List[str]:
|
||||
def get_descriptions() -> list[str]:
|
||||
"""获取所有表情包描述
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -9,13 +9,15 @@
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Tuple, Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.replyer.replyer_manager import replyer_manager
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ActionInfo
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -30,10 +32,10 @@ logger = get_logger("generator_api")
|
||||
|
||||
|
||||
def get_replyer(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
chat_stream: ChatStream | None = None,
|
||||
chat_id: str | None = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer]:
|
||||
) -> DefaultReplyer | None:
|
||||
"""获取回复器对象
|
||||
|
||||
优先使用chat_stream,如果没有则使用chat_id直接查找。
|
||||
@@ -71,13 +73,13 @@ def get_replyer(
|
||||
|
||||
|
||||
async def generate_reply(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
action_data: Optional[Dict[str, Any]] = None,
|
||||
chat_stream: ChatStream | None = None,
|
||||
chat_id: str | None = None,
|
||||
action_data: dict[str, Any] | None = None,
|
||||
reply_to: str = "",
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
reply_message: dict[str, Any] | None = None,
|
||||
extra_info: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
available_actions: dict[str, ActionInfo] | None = None,
|
||||
enable_tool: bool = False,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
@@ -85,7 +87,7 @@ async def generate_reply(
|
||||
request_type: str = "generator_api",
|
||||
from_plugin: bool = True,
|
||||
read_mark: float = 0.0,
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
) -> tuple[bool, list[tuple[str, Any]], str | None]:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
@@ -168,9 +170,9 @@ async def generate_reply(
|
||||
|
||||
|
||||
async def rewrite_reply(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
reply_data: Optional[Dict[str, Any]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
chat_stream: ChatStream | None = None,
|
||||
reply_data: dict[str, Any] | None = None,
|
||||
chat_id: str | None = None,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
raw_reply: str = "",
|
||||
@@ -178,7 +180,7 @@ async def rewrite_reply(
|
||||
reply_to: str = "",
|
||||
return_prompt: bool = False,
|
||||
request_type: str = "generator_api",
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
) -> tuple[bool, list[tuple[str, Any]], str | None]:
|
||||
"""重写回复
|
||||
|
||||
Args:
|
||||
@@ -237,7 +239,7 @@ async def rewrite_reply(
|
||||
return False, [], None
|
||||
|
||||
|
||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
|
||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> list[tuple[str, Any]]:
|
||||
"""将文本处理为更拟人化的文本
|
||||
|
||||
Args:
|
||||
@@ -266,11 +268,11 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
|
||||
|
||||
|
||||
async def generate_response_custom(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
chat_stream: ChatStream | None = None,
|
||||
chat_id: str | None = None,
|
||||
request_type: str = "generator_api",
|
||||
prompt: str = "",
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""
|
||||
使用自定义提示生成回复
|
||||
|
||||
|
||||
@@ -7,12 +7,13 @@
|
||||
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
|
||||
"""
|
||||
|
||||
from typing import Tuple, Dict, List, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
|
||||
logger = get_logger("llm_api")
|
||||
|
||||
@@ -21,7 +22,7 @@ logger = get_logger("llm_api")
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, TaskConfig]:
|
||||
def get_available_models() -> dict[str, TaskConfig]:
|
||||
"""获取所有可用的模型配置
|
||||
|
||||
Returns:
|
||||
@@ -31,7 +32,7 @@ def get_available_models() -> Dict[str, TaskConfig]:
|
||||
# 自动获取所有属性并转换为字典形式
|
||||
models = model_config.model_task_config
|
||||
attrs = dir(models)
|
||||
rets: Dict[str, TaskConfig] = {}
|
||||
rets: dict[str, TaskConfig] = {}
|
||||
for attr in attrs:
|
||||
if not attr.startswith("__"):
|
||||
try:
|
||||
@@ -52,9 +53,9 @@ async def generate_with_model(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str]:
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> tuple[bool, str, str, str]:
|
||||
"""使用指定模型生成内容
|
||||
|
||||
Args:
|
||||
@@ -78,7 +79,7 @@ async def generate_with_model(
|
||||
return True, response, reasoning_content, model_name
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
error_msg = f"生成内容时出错: {e!s}"
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
|
||||
@@ -86,11 +87,11 @@ async def generate_with_model(
|
||||
async def generate_with_model_with_tools(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
tool_options: List[Dict[str, Any]] | None = None,
|
||||
tool_options: list[dict[str, Any]] | None = None,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> tuple[bool, str, str, str, list[ToolCall] | None]:
|
||||
"""使用指定模型和工具生成内容
|
||||
|
||||
Args:
|
||||
@@ -117,6 +118,6 @@ async def generate_with_model_with_tools(
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
error_msg = f"生成内容时出错: {e!s}"
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", "", None
|
||||
|
||||
@@ -8,26 +8,26 @@
|
||||
readable_text = message_api.build_readable_messages(messages)
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
get_raw_msg_by_timestamp_with_chat_users,
|
||||
get_raw_msg_by_timestamp_random,
|
||||
get_raw_msg_by_timestamp_with_users,
|
||||
get_raw_msg_before_timestamp,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
get_raw_msg_before_timestamp_with_users,
|
||||
num_new_messages_since,
|
||||
num_new_messages_since_with_users,
|
||||
build_readable_messages,
|
||||
build_readable_messages_with_list,
|
||||
get_person_id_list,
|
||||
get_raw_msg_before_timestamp,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
get_raw_msg_before_timestamp_with_users,
|
||||
get_raw_msg_by_timestamp,
|
||||
get_raw_msg_by_timestamp_random,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
get_raw_msg_by_timestamp_with_chat_users,
|
||||
get_raw_msg_by_timestamp_with_users,
|
||||
num_new_messages_since,
|
||||
num_new_messages_since_with_users,
|
||||
)
|
||||
|
||||
from src.config.config import global_config
|
||||
|
||||
# =============================================================================
|
||||
# 消息查询API函数
|
||||
@@ -36,7 +36,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
|
||||
async def get_messages_by_time(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定时间范围内的消息
|
||||
|
||||
@@ -70,7 +70,7 @@ async def get_messages_by_time_in_chat(
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间范围内的消息
|
||||
|
||||
@@ -111,7 +111,7 @@ async def get_messages_by_time_in_chat_inclusive(
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间范围内的消息(包含边界)
|
||||
|
||||
@@ -152,10 +152,10 @@ async def get_messages_by_time_in_chat_for_users(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
person_ids: List[str],
|
||||
person_ids: list[str],
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定用户在指定时间范围内的消息
|
||||
|
||||
@@ -186,7 +186,7 @@ async def get_messages_by_time_in_chat_for_users(
|
||||
|
||||
async def get_random_chat_messages(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
随机选择一个聊天,返回该聊天在指定时间范围内的消息
|
||||
|
||||
@@ -213,8 +213,8 @@ async def get_random_chat_messages(
|
||||
|
||||
|
||||
async def get_messages_by_time_for_users(
|
||||
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
start_time: float, end_time: float, person_ids: list[str], limit: int = 0, limit_mode: str = "latest"
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户在所有聊天中指定时间范围内的消息
|
||||
|
||||
@@ -238,7 +238,7 @@ async def get_messages_by_time_for_users(
|
||||
return await get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
|
||||
|
||||
|
||||
async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]:
|
||||
async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定时间戳之前的消息
|
||||
|
||||
@@ -294,8 +294,8 @@ async def get_messages_before_time_in_chat(
|
||||
|
||||
|
||||
async def get_messages_before_time_for_users(
|
||||
timestamp: float, person_ids: List[str], limit: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
timestamp: float, person_ids: list[str], limit: int = 0
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户在指定时间戳之前的消息
|
||||
|
||||
@@ -319,7 +319,7 @@ async def get_messages_before_time_for_users(
|
||||
|
||||
async def get_recent_messages(
|
||||
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中最近一段时间的消息
|
||||
|
||||
@@ -358,7 +358,7 @@ async def get_recent_messages(
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
|
||||
async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: float | None = None) -> int:
|
||||
"""
|
||||
计算指定聊天中从开始时间到结束时间的新消息数量
|
||||
|
||||
@@ -382,7 +382,7 @@ async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Op
|
||||
return await num_new_messages_since(chat_id, start_time, end_time)
|
||||
|
||||
|
||||
async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
|
||||
async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: list[str]) -> int:
|
||||
"""
|
||||
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
|
||||
|
||||
@@ -413,7 +413,7 @@ async def count_new_messages_for_users(chat_id: str, start_time: float, end_time
|
||||
|
||||
|
||||
async def build_readable_messages_to_str(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
@@ -442,12 +442,12 @@ async def build_readable_messages_to_str(
|
||||
|
||||
|
||||
async def build_readable_messages_with_details(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
) -> tuple[str, list[tuple[float, str, str]]]:
|
||||
"""
|
||||
将消息列表构建成可读的字符串,并返回详细信息
|
||||
|
||||
@@ -464,7 +464,7 @@ async def build_readable_messages_with_details(
|
||||
return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate)
|
||||
|
||||
|
||||
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
async def get_person_ids_from_messages(messages: list[dict[str, Any]]) -> list[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的用户ID列表
|
||||
|
||||
@@ -482,7 +482,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
async def filter_mai_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
从消息列表中移除麦麦的消息
|
||||
Args:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""纯异步权限API定义。所有外部调用方必须使用 await。"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -48,18 +48,18 @@ class IPermissionManager(ABC):
|
||||
async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_permissions(self, user: UserInfo) -> List[str]: ...
|
||||
async def get_user_permissions(self, user: UserInfo) -> list[str]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_permission_nodes(self) -> List[PermissionNode]: ...
|
||||
async def get_all_permission_nodes(self) -> list[PermissionNode]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: ...
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]: ...
|
||||
|
||||
|
||||
class PermissionAPI:
|
||||
def __init__(self):
|
||||
self._permission_manager: Optional[IPermissionManager] = None
|
||||
self._permission_manager: IPermissionManager | None = None
|
||||
# 需要保留的前缀(视为绝对节点名,不再自动加 plugins.<plugin>. 前缀)
|
||||
self.RESERVED_PREFIXES: tuple[str, ...] = "system."
|
||||
# 系统节点列表 (name, description, default_granted)
|
||||
@@ -147,11 +147,11 @@ class PermissionAPI:
|
||||
self._ensure_manager()
|
||||
return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node)
|
||||
|
||||
async def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
|
||||
async def get_user_permissions(self, platform: str, user_id: str) -> list[str]:
|
||||
self._ensure_manager()
|
||||
return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id))
|
||||
|
||||
async def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
|
||||
async def get_all_permission_nodes(self) -> list[dict[str, Any]]:
|
||||
self._ensure_manager()
|
||||
nodes = await self._permission_manager.get_all_permission_nodes()
|
||||
return [
|
||||
@@ -164,7 +164,7 @@ class PermissionAPI:
|
||||
for n in nodes
|
||||
]
|
||||
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[dict[str, Any]]:
|
||||
self._ensure_manager()
|
||||
nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name)
|
||||
return [
|
||||
|
||||
@@ -7,9 +7,10 @@
|
||||
value = await person_api.get_person_value(person_id, "nickname")
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
logger = get_logger("person_api")
|
||||
|
||||
@@ -63,7 +64,7 @@ async def get_person_value(person_id: str, field_name: str, default: Any = None)
|
||||
return default
|
||||
|
||||
|
||||
async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict:
|
||||
async def get_person_values(person_id: str, field_names: list, default_dict: dict | None = None) -> dict:
|
||||
"""批量获取用户信息字段值
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
from typing import Tuple, List
|
||||
|
||||
|
||||
def list_loaded_plugins() -> List[str]:
|
||||
def list_loaded_plugins() -> list[str]:
|
||||
"""
|
||||
列出所有当前加载的插件。
|
||||
|
||||
@@ -13,7 +10,7 @@ def list_loaded_plugins() -> List[str]:
|
||||
return plugin_manager.list_loaded_plugins()
|
||||
|
||||
|
||||
def list_registered_plugins() -> List[str]:
|
||||
def list_registered_plugins() -> list[str]:
|
||||
"""
|
||||
列出所有已注册的插件。
|
||||
|
||||
@@ -80,7 +77,7 @@ async def reload_plugin(plugin_name: str) -> bool:
|
||||
return await plugin_manager.reload_registered_plugin(plugin_name)
|
||||
|
||||
|
||||
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
|
||||
def load_plugin(plugin_name: str) -> tuple[bool, int]:
|
||||
"""
|
||||
加载指定的插件。
|
||||
|
||||
@@ -109,7 +106,7 @@ def add_plugin_directory(plugin_directory: str) -> bool:
|
||||
return plugin_manager.add_plugin_directory(plugin_directory)
|
||||
|
||||
|
||||
def rescan_plugin_directory() -> Tuple[int, int]:
|
||||
def rescan_plugin_directory() -> tuple[int, int]:
|
||||
"""
|
||||
重新扫描插件目录,加载新插件。
|
||||
Returns:
|
||||
|
||||
@@ -6,8 +6,8 @@ logger = get_logger("plugin_manager") # 复用plugin_manager名称
|
||||
|
||||
|
||||
def register_plugin(cls):
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
"""插件注册装饰器
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.sqlalchemy_models import MonthlyPlan
|
||||
from src.common.logger import get_logger
|
||||
@@ -44,7 +44,7 @@ class ScheduleAPI:
|
||||
"""日程表与月度计划API - 负责日程和计划信息的查询与管理"""
|
||||
|
||||
@staticmethod
|
||||
async def get_today_schedule() -> Optional[List[Dict[str, Any]]]:
|
||||
async def get_today_schedule() -> list[dict[str, Any]] | None:
|
||||
"""(异步) 获取今天的日程安排
|
||||
|
||||
Returns:
|
||||
@@ -58,7 +58,7 @@ class ScheduleAPI:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_current_activity() -> Optional[str]:
|
||||
async def get_current_activity() -> str | None:
|
||||
"""(异步) 获取当前正在进行的活动
|
||||
|
||||
Returns:
|
||||
@@ -87,7 +87,7 @@ class ScheduleAPI:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]:
|
||||
async def get_monthly_plans(target_month: str | None = None) -> list[MonthlyPlan]:
|
||||
"""(异步) 获取指定月份的有效月度计划
|
||||
|
||||
Args:
|
||||
@@ -106,7 +106,7 @@ class ScheduleAPI:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool:
|
||||
async def ensure_monthly_plans(target_month: str | None = None) -> bool:
|
||||
"""(异步) 确保指定月份存在月度计划,如果不存在则触发生成
|
||||
|
||||
Args:
|
||||
@@ -125,7 +125,7 @@ class ScheduleAPI:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def archive_monthly_plans(target_month: Optional[str] = None) -> bool:
|
||||
async def archive_monthly_plans(target_month: str | None = None) -> bool:
|
||||
"""(异步) 归档指定月份的月度计划
|
||||
|
||||
Args:
|
||||
@@ -150,12 +150,12 @@ class ScheduleAPI:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_today_schedule() -> Optional[List[Dict[str, Any]]]:
|
||||
async def get_today_schedule() -> list[dict[str, Any]] | None:
|
||||
"""(异步) 获取今天的日程安排的便捷函数"""
|
||||
return await ScheduleAPI.get_today_schedule()
|
||||
|
||||
|
||||
async def get_current_activity() -> Optional[str]:
|
||||
async def get_current_activity() -> str | None:
|
||||
"""(异步) 获取当前正在进行的活动的便捷函数"""
|
||||
return await ScheduleAPI.get_current_activity()
|
||||
|
||||
@@ -165,16 +165,16 @@ async def regenerate_schedule() -> bool:
|
||||
return await ScheduleAPI.regenerate_schedule()
|
||||
|
||||
|
||||
async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]:
|
||||
async def get_monthly_plans(target_month: str | None = None) -> list[MonthlyPlan]:
|
||||
"""(异步) 获取指定月份的有效月度计划的便捷函数"""
|
||||
return await ScheduleAPI.get_monthly_plans(target_month)
|
||||
|
||||
|
||||
async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool:
|
||||
async def ensure_monthly_plans(target_month: str | None = None) -> bool:
|
||||
"""(异步) 确保指定月份存在月度计划的便捷函数"""
|
||||
return await ScheduleAPI.ensure_monthly_plans(target_month)
|
||||
|
||||
|
||||
async def archive_monthly_plans(target_month: Optional[str] = None) -> bool:
|
||||
async def archive_monthly_plans(target_month: str | None = None) -> bool:
|
||||
"""(异步) 归档指定月份的月度计划的便捷函数"""
|
||||
return await ScheduleAPI.archive_monthly_plans(target_month)
|
||||
|
||||
@@ -28,29 +28,28 @@
|
||||
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Optional, Union, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from maim_message import Seg, UserInfo
|
||||
|
||||
# 导入依赖
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from maim_message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv
|
||||
from maim_message import Seg
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
# 日志记录器
|
||||
logger = get_logger("send_api")
|
||||
|
||||
# 适配器命令响应等待池
|
||||
_adapter_response_pool: Dict[str, asyncio.Future] = {}
|
||||
_adapter_response_pool: dict[str, asyncio.Future] = {}
|
||||
|
||||
|
||||
def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[MessageRecv]:
|
||||
def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None:
|
||||
"""查找要回复的消息
|
||||
|
||||
Args:
|
||||
@@ -134,13 +133,13 @@ async def wait_adapter_response(request_id: str, timeout: float = 30.0) -> dict:
|
||||
|
||||
async def _send_to_target(
|
||||
message_type: str,
|
||||
content: Union[str, dict],
|
||||
content: str | dict,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
set_reply: bool = False,
|
||||
reply_to_message: Optional[Dict[str, Any]] = None,
|
||||
reply_to_message: dict[str, Any] | None = None,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
@@ -247,7 +246,7 @@ async def text_to_stream(
|
||||
stream_id: str,
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
reply_to_message: Optional[Dict[str, Any]] = None,
|
||||
reply_to_message: dict[str, Any] | None = None,
|
||||
set_reply: bool = True,
|
||||
storage_message: bool = True,
|
||||
) -> bool:
|
||||
@@ -313,7 +312,7 @@ async def image_to_stream(
|
||||
|
||||
|
||||
async def command_to_stream(
|
||||
command: Union[str, dict],
|
||||
command: str | dict,
|
||||
stream_id: str,
|
||||
storage_message: bool = True,
|
||||
display_message: str = "",
|
||||
@@ -341,7 +340,7 @@ async def custom_to_stream(
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
reply_to_message: Optional[Dict[str, Any]] = None,
|
||||
reply_to_message: dict[str, Any] | None = None,
|
||||
set_reply: bool = True,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
@@ -377,8 +376,8 @@ async def custom_to_stream(
|
||||
async def adapter_command_to_stream(
|
||||
action: str,
|
||||
params: dict,
|
||||
platform: Optional[str] = "qq",
|
||||
stream_id: Optional[str] = None,
|
||||
platform: str | None = "qq",
|
||||
stream_id: str | None = None,
|
||||
timeout: float = 30.0,
|
||||
storage_message: bool = False,
|
||||
) -> dict:
|
||||
@@ -497,4 +496,4 @@ async def adapter_command_to_stream(
|
||||
except Exception as e:
|
||||
logger.error(f"[SendAPI] 发送适配器命令时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return {"status": "error", "message": f"发送适配器命令时出错: {str(e)}"}
|
||||
return {"status": "error", "message": f"发送适配器命令时出错: {e!s}"}
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
from typing import Optional, Type
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_api")
|
||||
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
def get_tool_instance(tool_name: str) -> BaseTool | None:
|
||||
"""获取公开工具实例"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
@@ -18,7 +16,7 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
else:
|
||||
plugin_config = None
|
||||
|
||||
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
||||
tool_class: type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
||||
return tool_class(plugin_config) if tool_class else None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user