re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -14,14 +14,15 @@ from src.plugin_system.apis import (
generator_api,
llm_api,
message_api,
permission_api,
person_api,
plugin_manage_api,
schedule_api,
send_api,
tool_api,
permission_api,
schedule_api,
)
from src.plugin_system.apis.chat_api import ChatManager as context_api
from .logging_api import get_logger
from .plugin_register_api import register_plugin
@@ -30,18 +31,18 @@ __all__ = [
"chat_api",
"component_manage_api",
"config_api",
"context_api",
"database_api",
"emoji_api",
"generator_api",
"get_logger",
"llm_api",
"message_api",
"permission_api",
"person_api",
"plugin_manage_api",
"send_api",
"get_logger",
"register_plugin",
"tool_api",
"permission_api",
"context_api",
"schedule_api",
"send_api",
"tool_api",
]

View File

@@ -12,11 +12,11 @@
streams = chat.get_all_group_streams()
"""
from typing import List, Dict, Any, Optional
from enum import Enum
from typing import Any
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.common.logger import get_logger
logger = get_logger("chat_api")
@@ -31,7 +31,7 @@ class ChatManager:
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
# sourcery skip: for-append-to-extend
"""获取所有聊天流
@@ -57,7 +57,7 @@ class ChatManager:
return streams
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
# sourcery skip: for-append-to-extend
"""获取所有群聊聊天流
@@ -80,7 +80,7 @@ class ChatManager:
return streams
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
# sourcery skip: for-append-to-extend
"""获取所有私聊聊天流
@@ -107,8 +107,8 @@ class ChatManager:
@staticmethod
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
group_id: str, platform: str | None | SpecialTypes = "qq"
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast
"""根据群ID获取聊天流
Args:
@@ -144,8 +144,8 @@ class ChatManager:
@staticmethod
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
user_id: str, platform: str | None | SpecialTypes = "qq"
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast
"""根据用户ID获取私聊流
Args:
@@ -203,7 +203,7 @@ class ChatManager:
return "unknown"
@staticmethod
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]:
"""获取聊天流详细信息
Args:
@@ -222,7 +222,7 @@ class ChatManager:
raise TypeError("chat_stream 必须是 ChatStream 类型")
try:
info: Dict[str, Any] = {
info: dict[str, Any] = {
"stream_id": chat_stream.stream_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
@@ -250,7 +250,7 @@ class ChatManager:
return {}
@staticmethod
def get_streams_summary() -> Dict[str, int]:
def get_streams_summary() -> dict[str, int]:
"""获取聊天流统计摘要
Returns:
@@ -285,27 +285,27 @@ class ChatManager:
# =============================================================================
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
"""获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform)
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
"""获取群聊聊天流的便捷函数"""
return ChatManager.get_group_streams(platform)
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
"""获取私聊聊天流的便捷函数"""
return ChatManager.get_private_streams(platform)
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None:
"""根据群ID获取聊天流的便捷函数"""
return ChatManager.get_group_stream_by_group_id(group_id, platform)
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None:
"""根据用户ID获取私聊流的便捷函数"""
return ChatManager.get_private_stream_by_user_id(user_id, platform)
@@ -315,11 +315,11 @@ def get_stream_type(chat_stream: ChatStream) -> str:
return ChatManager.get_stream_type(chat_stream)
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]:
"""获取聊天流信息的便捷函数"""
return ChatManager.get_stream_info(chat_stream)
def get_streams_summary() -> Dict[str, int]:
def get_streams_summary() -> dict[str, int]:
"""获取聊天流统计摘要的便捷函数"""
return ChatManager.get_streams_summary()

View File

@@ -1,16 +1,15 @@
from typing import Optional, Union, Dict
from src.plugin_system.base.component_types import (
CommandInfo,
ActionInfo,
CommandInfo,
ComponentType,
EventHandlerInfo,
PluginInfo,
ComponentType,
ToolInfo,
)
# === 插件信息查询 ===
def get_all_plugin_info() -> Dict[str, PluginInfo]:
def get_all_plugin_info() -> dict[str, PluginInfo]:
"""
获取所有插件的信息。
@@ -22,7 +21,7 @@ def get_all_plugin_info() -> Dict[str, PluginInfo]:
return component_registry.get_all_plugins()
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
def get_plugin_info(plugin_name: str) -> PluginInfo | None:
"""
获取指定插件的信息。
@@ -40,7 +39,7 @@ def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
# === 组件查询方法 ===
def get_component_info(
component_name: str, component_type: ComponentType
) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
) -> CommandInfo | ActionInfo | EventHandlerInfo | None:
"""
获取指定组件的信息。
@@ -57,7 +56,7 @@ def get_component_info(
def get_components_info_by_type(
component_type: ComponentType,
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
) -> dict[str, CommandInfo | ActionInfo | EventHandlerInfo]:
"""
获取指定类型的所有组件信息。
@@ -74,7 +73,7 @@ def get_components_info_by_type(
def get_enabled_components_info_by_type(
component_type: ComponentType,
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
) -> dict[str, CommandInfo | ActionInfo | EventHandlerInfo]:
"""
获取指定类型的所有启用的组件信息。
@@ -90,7 +89,7 @@ def get_enabled_components_info_by_type(
# === Action 查询方法 ===
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
def get_registered_action_info(action_name: str) -> ActionInfo | None:
"""
获取指定 Action 的注册信息。
@@ -105,7 +104,7 @@ def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
return component_registry.get_registered_action_info(action_name)
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
def get_registered_command_info(command_name: str) -> CommandInfo | None:
"""
获取指定 Command 的注册信息。
@@ -120,7 +119,7 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
return component_registry.get_registered_command_info(command_name)
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
def get_registered_tool_info(tool_name: str) -> ToolInfo | None:
"""
获取指定 Tool 的注册信息。
@@ -138,7 +137,7 @@ def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
# === EventHandler 特定查询方法 ===
def get_registered_event_handler_info(
event_handler_name: str,
) -> Optional[EventHandlerInfo]:
) -> EventHandlerInfo | None:
"""
获取指定 EventHandler 的注册信息。

View File

@@ -8,6 +8,7 @@
"""
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config

View File

@@ -3,20 +3,20 @@
"""
import time
from typing import Dict, Any, Optional, List
from typing import Any
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
)
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
build_readable_messages_with_id,
)
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
logger = get_logger("cross_context_api")
def get_context_groups(chat_id: str) -> Optional[List[List[str]]]:
def get_context_groups(chat_id: str) -> list[list[str]] | None:
"""
获取当前聊天所在的共享组的其他聊天ID
"""
@@ -41,7 +41,7 @@ def get_context_groups(chat_id: str) -> Optional[List[List[str]]]:
return None
async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: List[List[str]]) -> str:
async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: list[list[str]]) -> str:
"""
构建跨群聊/私聊上下文 (Normal模式)
"""
@@ -74,8 +74,8 @@ async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos:
async def build_cross_context_s4u(
chat_stream: ChatStream,
other_chat_infos: List[List[str]],
target_user_info: Optional[Dict[str, Any]],
other_chat_infos: list[list[str]],
target_user_info: dict[str, Any] | None,
) -> str:
"""
构建跨群聊/私聊上下文 (S4U模式)

View File

@@ -9,7 +9,7 @@
注意此模块现在使用SQLAlchemy实现提供更好的连接管理和错误处理
"""
from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get, store_action_info, MODEL_MAPPING
from src.common.database.sqlalchemy_database_api import MODEL_MAPPING, db_get, db_query, db_save, store_action_info
# 保持向后兼容性
__all__ = ["db_query", "db_save", "db_get", "store_action_info", "MODEL_MAPPING"]
__all__ = ["MODEL_MAPPING", "db_get", "db_query", "db_save", "store_action_info"]

View File

@@ -10,10 +10,9 @@
import random
from typing import Optional, Tuple, List
from src.common.logger import get_logger
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.utils.utils_image import image_path_to_base64
from src.common.logger import get_logger
logger = get_logger("emoji_api")
@@ -23,7 +22,7 @@ logger = get_logger("emoji_api")
# =============================================================================
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
async def get_by_description(description: str) -> tuple[str, str, str] | None:
"""根据描述选择表情包
Args:
@@ -65,7 +64,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
return None
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
async def get_random(count: int | None = 1) -> list[tuple[str, str, str]]:
"""随机获取指定数量的表情包
Args:
@@ -137,7 +136,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
return []
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
async def get_by_emotion(emotion: str) -> tuple[str, str, str] | None:
"""根据情感标签获取表情包
Args:
@@ -227,7 +226,7 @@ def get_info():
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
def get_emotions() -> List[str]:
def get_emotions() -> list[str]:
"""获取所有可用的情感标签
Returns:
@@ -247,7 +246,7 @@ def get_emotions() -> List[str]:
return []
def get_descriptions() -> List[str]:
def get_descriptions() -> list[str]:
"""获取所有表情包描述
Returns:

View File

@@ -9,13 +9,15 @@
"""
import traceback
from typing import Tuple, Any, Dict, List, Optional
from typing import Any
from rich.traceback import install
from src.common.logger import get_logger
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils import process_llm_response
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.replyer.replyer_manager import replyer_manager
from src.chat.utils.utils import process_llm_response
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ActionInfo
install(extra_lines=3)
@@ -30,10 +32,10 @@ logger = get_logger("generator_api")
def get_replyer(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
chat_stream: ChatStream | None = None,
chat_id: str | None = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
) -> DefaultReplyer | None:
"""获取回复器对象
优先使用chat_stream如果没有则使用chat_id直接查找。
@@ -71,13 +73,13 @@ def get_replyer(
async def generate_reply(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None,
chat_stream: ChatStream | None = None,
chat_id: str | None = None,
action_data: dict[str, Any] | None = None,
reply_to: str = "",
reply_message: Optional[Dict[str, Any]] = None,
reply_message: dict[str, Any] | None = None,
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
available_actions: dict[str, ActionInfo] | None = None,
enable_tool: bool = False,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
@@ -85,7 +87,7 @@ async def generate_reply(
request_type: str = "generator_api",
from_plugin: bool = True,
read_mark: float = 0.0,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
) -> tuple[bool, list[tuple[str, Any]], str | None]:
"""生成回复
Args:
@@ -168,9 +170,9 @@ async def generate_reply(
async def rewrite_reply(
chat_stream: Optional[ChatStream] = None,
reply_data: Optional[Dict[str, Any]] = None,
chat_id: Optional[str] = None,
chat_stream: ChatStream | None = None,
reply_data: dict[str, Any] | None = None,
chat_id: str | None = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
raw_reply: str = "",
@@ -178,7 +180,7 @@ async def rewrite_reply(
reply_to: str = "",
return_prompt: bool = False,
request_type: str = "generator_api",
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
) -> tuple[bool, list[tuple[str, Any]], str | None]:
"""重写回复
Args:
@@ -237,7 +239,7 @@ async def rewrite_reply(
return False, [], None
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> list[tuple[str, Any]]:
"""将文本处理为更拟人化的文本
Args:
@@ -266,11 +268,11 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
async def generate_response_custom(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
chat_stream: ChatStream | None = None,
chat_id: str | None = None,
request_type: str = "generator_api",
prompt: str = "",
) -> Optional[str]:
) -> str | None:
"""
使用自定义提示生成回复

View File

@@ -7,12 +7,13 @@
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
"""
from typing import Tuple, Dict, List, Any, Optional
from typing import Any
from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.config.config import model_config
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.config.api_ada_configs import TaskConfig
logger = get_logger("llm_api")
@@ -21,7 +22,7 @@ logger = get_logger("llm_api")
# =============================================================================
def get_available_models() -> Dict[str, TaskConfig]:
def get_available_models() -> dict[str, TaskConfig]:
"""获取所有可用的模型配置
Returns:
@@ -31,7 +32,7 @@ def get_available_models() -> Dict[str, TaskConfig]:
# 自动获取所有属性并转换为字典形式
models = model_config.model_task_config
attrs = dir(models)
rets: Dict[str, TaskConfig] = {}
rets: dict[str, TaskConfig] = {}
for attr in attrs:
if not attr.startswith("__"):
try:
@@ -52,9 +53,9 @@ async def generate_with_model(
prompt: str,
model_config: TaskConfig,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str]:
temperature: float | None = None,
max_tokens: int | None = None,
) -> tuple[bool, str, str, str]:
"""使用指定模型生成内容
Args:
@@ -78,7 +79,7 @@ async def generate_with_model(
return True, response, reasoning_content, model_name
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
error_msg = f"生成内容时出错: {e!s}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", ""
@@ -86,11 +87,11 @@ async def generate_with_model(
async def generate_with_model_with_tools(
prompt: str,
model_config: TaskConfig,
tool_options: List[Dict[str, Any]] | None = None,
tool_options: list[dict[str, Any]] | None = None,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
temperature: float | None = None,
max_tokens: int | None = None,
) -> tuple[bool, str, str, str, list[ToolCall] | None]:
"""使用指定模型和工具生成内容
Args:
@@ -117,6 +118,6 @@ async def generate_with_model_with_tools(
return True, response, reasoning_content, model_name, tool_call
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
error_msg = f"生成内容时出错: {e!s}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", "", None

View File

@@ -8,26 +8,26 @@
readable_text = message_api.build_readable_messages(messages)
"""
from typing import List, Dict, Any, Tuple, Optional
from src.config.config import global_config
import time
from typing import Any
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_users,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
build_readable_messages,
build_readable_messages_with_list,
get_person_id_list,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
)
from src.config.config import global_config
# =============================================================================
# 消息查询API函数
@@ -36,7 +36,7 @@ from src.chat.utils.chat_message_builder import (
async def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
获取指定时间范围内的消息
@@ -70,7 +70,7 @@ async def get_messages_by_time_in_chat(
limit_mode: str = "latest",
filter_mai: bool = False,
filter_command: bool = False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
获取指定聊天中指定时间范围内的消息
@@ -111,7 +111,7 @@ async def get_messages_by_time_in_chat_inclusive(
limit_mode: str = "latest",
filter_mai: bool = False,
filter_command: bool = False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
获取指定聊天中指定时间范围内的消息(包含边界)
@@ -152,10 +152,10 @@ async def get_messages_by_time_in_chat_for_users(
chat_id: str,
start_time: float,
end_time: float,
person_ids: List[str],
person_ids: list[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
获取指定聊天中指定用户在指定时间范围内的消息
@@ -186,7 +186,7 @@ async def get_messages_by_time_in_chat_for_users(
async def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
随机选择一个聊天,返回该聊天在指定时间范围内的消息
@@ -213,8 +213,8 @@ async def get_random_chat_messages(
async def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
start_time: float, end_time: float, person_ids: list[str], limit: int = 0, limit_mode: str = "latest"
) -> list[dict[str, Any]]:
"""
获取指定用户在所有聊天中指定时间范围内的消息
@@ -238,7 +238,7 @@ async def get_messages_by_time_for_users(
return await get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]:
async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> list[dict[str, Any]]:
"""
获取指定时间戳之前的消息
@@ -294,8 +294,8 @@ async def get_messages_before_time_in_chat(
async def get_messages_before_time_for_users(
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[Dict[str, Any]]:
timestamp: float, person_ids: list[str], limit: int = 0
) -> list[dict[str, Any]]:
"""
获取指定用户在指定时间戳之前的消息
@@ -319,7 +319,7 @@ async def get_messages_before_time_for_users(
async def get_recent_messages(
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
获取指定聊天中最近一段时间的消息
@@ -358,7 +358,7 @@ async def get_recent_messages(
# =============================================================================
async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: float | None = None) -> int:
"""
计算指定聊天中从开始时间到结束时间的新消息数量
@@ -382,7 +382,7 @@ async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Op
return await num_new_messages_since(chat_id, start_time, end_time)
async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: list[str]) -> int:
"""
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
@@ -413,7 +413,7 @@ async def count_new_messages_for_users(chat_id: str, start_time: float, end_time
async def build_readable_messages_to_str(
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
@@ -442,12 +442,12 @@ async def build_readable_messages_to_str(
async def build_readable_messages_with_details(
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
) -> tuple[str, list[tuple[float, str, str]]]:
"""
将消息列表构建成可读的字符串,并返回详细信息
@@ -464,7 +464,7 @@ async def build_readable_messages_with_details(
return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate)
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
async def get_person_ids_from_messages(messages: list[dict[str, Any]]) -> list[str]:
"""
从消息列表中提取不重复的用户ID列表
@@ -482,7 +482,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
# =============================================================================
async def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
async def filter_mai_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
从消息列表中移除麦麦的消息
Args:

View File

@@ -1,9 +1,9 @@
"""纯异步权限API定义。所有外部调用方必须使用 await。"""
from typing import Optional, List, Dict, Any
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from abc import ABC, abstractmethod
from typing import Any
from src.common.logger import get_logger
@@ -48,18 +48,18 @@ class IPermissionManager(ABC):
async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: ...
@abstractmethod
async def get_user_permissions(self, user: UserInfo) -> List[str]: ...
async def get_user_permissions(self, user: UserInfo) -> list[str]: ...
@abstractmethod
async def get_all_permission_nodes(self) -> List[PermissionNode]: ...
async def get_all_permission_nodes(self) -> list[PermissionNode]: ...
@abstractmethod
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: ...
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]: ...
class PermissionAPI:
def __init__(self):
self._permission_manager: Optional[IPermissionManager] = None
self._permission_manager: IPermissionManager | None = None
# 需要保留的前缀(视为绝对节点名,不再自动加 plugins.<plugin>. 前缀)
self.RESERVED_PREFIXES: tuple[str, ...] = "system."
# 系统节点列表 (name, description, default_granted)
@@ -147,11 +147,11 @@ class PermissionAPI:
self._ensure_manager()
return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node)
async def get_user_permissions(self, platform: str, user_id: str) -> List[str]:
async def get_user_permissions(self, platform: str, user_id: str) -> list[str]:
self._ensure_manager()
return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id))
async def get_all_permission_nodes(self) -> List[Dict[str, Any]]:
async def get_all_permission_nodes(self) -> list[dict[str, Any]]:
self._ensure_manager()
nodes = await self._permission_manager.get_all_permission_nodes()
return [
@@ -164,7 +164,7 @@ class PermissionAPI:
for n in nodes
]
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]:
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[dict[str, Any]]:
self._ensure_manager()
nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name)
return [

View File

@@ -7,9 +7,10 @@
value = await person_api.get_person_value(person_id, "nickname")
"""
from typing import Any, Optional
from typing import Any
from src.common.logger import get_logger
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
logger = get_logger("person_api")
@@ -63,7 +64,7 @@ async def get_person_value(person_id: str, field_name: str, default: Any = None)
return default
async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict:
async def get_person_values(person_id: str, field_names: list, default_dict: dict | None = None) -> dict:
"""批量获取用户信息字段值
Args:

View File

@@ -1,7 +1,4 @@
from typing import Tuple, List
def list_loaded_plugins() -> List[str]:
def list_loaded_plugins() -> list[str]:
"""
列出所有当前加载的插件。
@@ -13,7 +10,7 @@ def list_loaded_plugins() -> List[str]:
return plugin_manager.list_loaded_plugins()
def list_registered_plugins() -> List[str]:
def list_registered_plugins() -> list[str]:
"""
列出所有已注册的插件。
@@ -80,7 +77,7 @@ async def reload_plugin(plugin_name: str) -> bool:
return await plugin_manager.reload_registered_plugin(plugin_name)
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
def load_plugin(plugin_name: str) -> tuple[bool, int]:
"""
加载指定的插件。
@@ -109,7 +106,7 @@ def add_plugin_directory(plugin_directory: str) -> bool:
return plugin_manager.add_plugin_directory(plugin_directory)
def rescan_plugin_directory() -> Tuple[int, int]:
def rescan_plugin_directory() -> tuple[int, int]:
"""
重新扫描插件目录,加载新插件。
Returns:

View File

@@ -6,8 +6,8 @@ logger = get_logger("plugin_manager") # 复用plugin_manager名称
def register_plugin(cls):
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.core.plugin_manager import plugin_manager
"""插件注册装饰器

View File

@@ -30,7 +30,7 @@
"""
from datetime import datetime
from typing import List, Dict, Any, Optional
from typing import Any
from src.common.database.sqlalchemy_models import MonthlyPlan
from src.common.logger import get_logger
@@ -44,7 +44,7 @@ class ScheduleAPI:
"""日程表与月度计划API - 负责日程和计划信息的查询与管理"""
@staticmethod
async def get_today_schedule() -> Optional[List[Dict[str, Any]]]:
async def get_today_schedule() -> list[dict[str, Any]] | None:
"""(异步) 获取今天的日程安排
Returns:
@@ -58,7 +58,7 @@ class ScheduleAPI:
return None
@staticmethod
async def get_current_activity() -> Optional[str]:
async def get_current_activity() -> str | None:
"""(异步) 获取当前正在进行的活动
Returns:
@@ -87,7 +87,7 @@ class ScheduleAPI:
return False
@staticmethod
async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]:
async def get_monthly_plans(target_month: str | None = None) -> list[MonthlyPlan]:
"""(异步) 获取指定月份的有效月度计划
Args:
@@ -106,7 +106,7 @@ class ScheduleAPI:
return []
@staticmethod
async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool:
async def ensure_monthly_plans(target_month: str | None = None) -> bool:
"""(异步) 确保指定月份存在月度计划,如果不存在则触发生成
Args:
@@ -125,7 +125,7 @@ class ScheduleAPI:
return False
@staticmethod
async def archive_monthly_plans(target_month: Optional[str] = None) -> bool:
async def archive_monthly_plans(target_month: str | None = None) -> bool:
"""(异步) 归档指定月份的月度计划
Args:
@@ -150,12 +150,12 @@ class ScheduleAPI:
# =============================================================================
async def get_today_schedule() -> Optional[List[Dict[str, Any]]]:
async def get_today_schedule() -> list[dict[str, Any]] | None:
"""(异步) 获取今天的日程安排的便捷函数"""
return await ScheduleAPI.get_today_schedule()
async def get_current_activity() -> Optional[str]:
async def get_current_activity() -> str | None:
"""(异步) 获取当前正在进行的活动的便捷函数"""
return await ScheduleAPI.get_current_activity()
@@ -165,16 +165,16 @@ async def regenerate_schedule() -> bool:
return await ScheduleAPI.regenerate_schedule()
async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]:
async def get_monthly_plans(target_month: str | None = None) -> list[MonthlyPlan]:
"""(异步) 获取指定月份的有效月度计划的便捷函数"""
return await ScheduleAPI.get_monthly_plans(target_month)
async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool:
async def ensure_monthly_plans(target_month: str | None = None) -> bool:
"""(异步) 确保指定月份存在月度计划的便捷函数"""
return await ScheduleAPI.ensure_monthly_plans(target_month)
async def archive_monthly_plans(target_month: Optional[str] = None) -> bool:
async def archive_monthly_plans(target_month: str | None = None) -> bool:
"""(异步) 归档指定月份的月度计划的便捷函数"""
return await ScheduleAPI.archive_monthly_plans(target_month)

View File

@@ -28,29 +28,28 @@
"""
import traceback
import time
import asyncio
from typing import Optional, Union, Dict, Any
from src.common.logger import get_logger
import time
import traceback
from typing import Any
from maim_message import Seg, UserInfo
# 导入依赖
from src.chat.message_receive.chat_stream import get_chat_manager
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageSending
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending, MessageRecv
from maim_message import Seg
from src.common.logger import get_logger
from src.config.config import global_config
# 日志记录器
logger = get_logger("send_api")
# 适配器命令响应等待池
_adapter_response_pool: Dict[str, asyncio.Future] = {}
_adapter_response_pool: dict[str, asyncio.Future] = {}
def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[MessageRecv]:
def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None:
"""查找要回复的消息
Args:
@@ -134,13 +133,13 @@ async def wait_adapter_response(request_id: str, timeout: float = 30.0) -> dict:
async def _send_to_target(
message_type: str,
content: Union[str, dict],
content: str | dict,
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_to: str = "",
set_reply: bool = False,
reply_to_message: Optional[Dict[str, Any]] = None,
reply_to_message: dict[str, Any] | None = None,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
@@ -247,7 +246,7 @@ async def text_to_stream(
stream_id: str,
typing: bool = False,
reply_to: str = "",
reply_to_message: Optional[Dict[str, Any]] = None,
reply_to_message: dict[str, Any] | None = None,
set_reply: bool = True,
storage_message: bool = True,
) -> bool:
@@ -313,7 +312,7 @@ async def image_to_stream(
async def command_to_stream(
command: Union[str, dict],
command: str | dict,
stream_id: str,
storage_message: bool = True,
display_message: str = "",
@@ -341,7 +340,7 @@ async def custom_to_stream(
display_message: str = "",
typing: bool = False,
reply_to: str = "",
reply_to_message: Optional[Dict[str, Any]] = None,
reply_to_message: dict[str, Any] | None = None,
set_reply: bool = True,
storage_message: bool = True,
show_log: bool = True,
@@ -377,8 +376,8 @@ async def custom_to_stream(
async def adapter_command_to_stream(
action: str,
params: dict,
platform: Optional[str] = "qq",
stream_id: Optional[str] = None,
platform: str | None = "qq",
stream_id: str | None = None,
timeout: float = 30.0,
storage_message: bool = False,
) -> dict:
@@ -497,4 +496,4 @@ async def adapter_command_to_stream(
except Exception as e:
logger.error(f"[SendAPI] 发送适配器命令时出错: {e}")
traceback.print_exc()
return {"status": "error", "message": f"发送适配器命令时出错: {str(e)}"}
return {"status": "error", "message": f"发送适配器命令时出错: {e!s}"}

View File

@@ -1,13 +1,11 @@
from typing import Optional, Type
from src.common.logger import get_logger
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
from src.common.logger import get_logger
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
def get_tool_instance(tool_name: str) -> BaseTool | None:
"""获取公开工具实例"""
from src.plugin_system.core import component_registry
@@ -18,7 +16,7 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
else:
plugin_config = None
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
tool_class: type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
return tool_class(plugin_config) if tool_class else None