refactor: 统一类型注解风格并优化代码结构

- 将裸 except 改为显式 Exception 捕获
- 用列表推导式替换冗余 for 循环
- 为类属性添加 ClassVar 注解
- 统一 Union/Optional 写法为 |
- 移除未使用的导入
- 修复 SQLAlchemy 空值比较语法
- 优化字符串拼接与字典更新逻辑
- 补充缺失的 noqa 注释与异常链

BREAKING CHANGE: 所有插件基类的类级字段现要求显式 ClassVar 注解,自定义插件需同步更新
This commit is contained in:
明天好像没什么
2025-10-31 22:42:39 +08:00
parent 5080cfccfc
commit 0e129d385e
105 changed files with 592 additions and 561 deletions

View File

@@ -44,7 +44,6 @@ from .base import (
# 新增的增强命令系统
PlusCommand,
PlusCommandAdapter,
PlusCommandInfo,
PythonDependency,
ToolInfo,
ToolParamType,

View File

@@ -48,9 +48,10 @@ class ChatManager:
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for stream in get_chat_manager().streams.values():
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream)
streams.extend(
stream for stream in get_chat_manager().streams.values()
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的聊天流")
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流失败: {e}")
@@ -71,9 +72,10 @@ class ChatManager:
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for stream in get_chat_manager().streams.values():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info:
streams.append(stream)
streams.extend(
stream for stream in get_chat_manager().streams.values()
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info
)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
logger.error(f"[ChatAPI] 获取群聊流失败: {e}")
@@ -97,9 +99,10 @@ class ChatManager:
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for stream in get_chat_manager().streams.values():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info:
streams.append(stream)
streams.extend(
stream for stream in get_chat_manager().streams.values()
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info
)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
logger.error(f"[ChatAPI] 获取私聊流失败: {e}")

View File

@@ -183,9 +183,10 @@ async def build_cross_context_s4u(
blacklisted_streams.add(stream_id)
except ValueError:
logger.warning(f"无效的S4U黑名单格式: {chat_str}")
for stream_id in chat_manager.streams:
if stream_id != chat_stream.stream_id and stream_id not in blacklisted_streams:
streams_to_scan.append(stream_id)
streams_to_scan.extend(
stream_id for stream_id in chat_manager.streams
if stream_id != chat_stream.stream_id and stream_id not in blacklisted_streams
)
logger.debug(f"[S4U] Found {len(streams_to_scan)} group streams to scan.")

View File

@@ -47,7 +47,7 @@ class ScoringAPI:
return await relationship_service.get_user_relationship_data(user_id)
@staticmethod
async def update_user_relationship(user_id: str, relationship_score: float, relationship_text: str = None, user_name: str = None):
async def update_user_relationship(user_id: str, relationship_score: float, relationship_text: str | None = None, user_name: str | None = None):
"""
更新用户关系数据
@@ -71,7 +71,7 @@ class ScoringAPI:
await interest_service.initialize_smart_interests(personality_description, personality_id)
@staticmethod
async def calculate_interest_match(content: str, keywords: list[str] = None):
async def calculate_interest_match(content: str, keywords: list[str] | None = None):
"""
计算内容与兴趣的匹配度
@@ -98,7 +98,7 @@ class ScoringAPI:
}
@staticmethod
def clear_caches(user_id: str = None):
def clear_caches(user_id: str | None = None):
"""
清理缓存

View File

@@ -9,7 +9,7 @@
import json
import os
import threading
from typing import Any
from typing import Any, ClassVar
from src.common.logger import get_logger
@@ -26,7 +26,7 @@ class PluginStorageManager:
现在它和API住在一起了希望它们能和睦相处。
"""
_instances: dict[str, "PluginStorage"] = {}
_instances: ClassVar[dict[str, "PluginStorage"] ] = {}
_lock = threading.Lock()
_base_path = os.path.join("data", "plugin_data")

View File

@@ -9,11 +9,11 @@ logger = get_logger("tool_api")
def get_tool_instance(tool_name: str, chat_stream: Any = None) -> BaseTool | None:
"""获取公开工具实例
Args:
tool_name: 工具名称
chat_stream: 聊天流对象,用于提供上下文信息
Returns:
BaseTool: 工具实例如果工具不存在则返回None
"""

View File

@@ -3,7 +3,7 @@ import asyncio
import random
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar
from src.chat.message_receive.chat_stream import ChatStream
from src.common.data_models.database_data_model import DatabaseMessages
@@ -26,30 +26,30 @@ class BaseAction(ABC):
新的激活机制 (推荐使用)
==================================================================================
推荐通过重写 go_activate() 方法来自定义激活逻辑:
示例 1 - 关键词激活:
async def go_activate(self, llm_judge_model=None) -> bool:
return await self._keyword_match(["你好", "hello"])
示例 2 - LLM 判断激活:
async def go_activate(self, llm_judge_model=None) -> bool:
return await self._llm_judge_activation(
"当用户询问天气信息时激活",
llm_judge_model
)
示例 3 - 组合多种条件:
async def go_activate(self, llm_judge_model=None) -> bool:
# 30% 随机概率,或者匹配关键词
if await self._random_activation(0.3):
return True
return await self._keyword_match(["表情", "emoji"])
提供的工具函数:
- _random_activation(probability): 随机激活
- _keyword_match(keywords, case_sensitive): 关键词匹配(自动获取聊天内容)
- _llm_judge_activation(judge_prompt, llm_judge_model): LLM 判断(自动获取聊天内容)
注意:聊天内容会自动从实例属性中获取,无需手动传入。
==================================================================================
@@ -68,7 +68,7 @@ class BaseAction(ABC):
==================================================================================
- mode_enable: 启用的聊天模式
- parallel_action: 是否允许并行执行
二步Action相关属性
- is_two_step_action: 是否为二步Action
- step_one_description: 第一步的描述
@@ -80,7 +80,7 @@ class BaseAction(ABC):
"""是否为二步Action。如果为TrueAction将分两步执行第一步选择操作第二步执行具体操作"""
step_one_description: str = ""
"""第一步的描述用于向LLM展示Action的基本功能"""
sub_actions: list[tuple[str, str, dict[str, str]]] = []
sub_actions: ClassVar[list[tuple[str, str, dict[str, str]]] ] = []
"""子Action列表格式为[(子Action名, 子Action描述, 子Action参数)]。仅在二步Action中使用"""
def __init__(
@@ -110,7 +110,7 @@ class BaseAction(ABC):
**kwargs: 其他参数
"""
if plugin_config is None:
plugin_config = {}
plugin_config: ClassVar = {}
self.action_data = action_data
self.reasoning = reasoning
self.cycle_timers = cycle_timers
@@ -489,7 +489,7 @@ class BaseAction(ABC):
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
# 3. 实例化被调用的Action
action_params = {
action_params: ClassVar = {
"action_data": called_action_data,
"reasoning": f"Called by {self.action_name}",
"cycle_timers": self.cycle_timers,
@@ -615,9 +615,9 @@ class BaseAction(ABC):
def _get_chat_content(self) -> str:
"""获取聊天内容用于激活判断
从实例属性中获取聊天内容。子类可以重写此方法来自定义获取逻辑。
Returns:
str: 聊天内容
"""
@@ -645,7 +645,7 @@ class BaseAction(ABC):
也可以使用提供的工具函数来简化常见的激活判断。
默认实现会检查类属性中的激活类型配置,提供向后兼容支持。
聊天内容会自动从实例属性中获取,不需要手动传入。
Args:
@@ -721,7 +721,7 @@ class BaseAction(ABC):
case_sensitive: bool = False,
) -> bool:
"""关键词匹配工具函数
聊天内容会自动从实例属性中获取。
Args:
@@ -742,7 +742,7 @@ class BaseAction(ABC):
if not case_sensitive:
search_text = search_text.lower()
matched_keywords = []
matched_keywords: ClassVar = []
for keyword in keywords:
check_keyword = keyword if case_sensitive else keyword.lower()
if check_keyword in search_text:
@@ -766,7 +766,7 @@ class BaseAction(ABC):
使用 LLM 来判断是否应该激活此 Action。
会自动构建完整的判断提示词,只需要提供核心判断逻辑即可。
聊天内容会自动从实例属性中获取。
Args:

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatterInfo, ComponentType
@@ -15,7 +15,7 @@ class BaseChatter(ABC):
"""Chatter组件的名称"""
chatter_description: str = ""
"""Chatter组件的描述"""
chat_types: list[ChatType] = [ChatType.PRIVATE, ChatType.GROUP]
chat_types: ClassVar[list[ChatType]] = [ChatType.PRIVATE, ChatType.GROUP]
def __init__(self, stream_id: str, action_manager: "ChatterActionManager"):
"""

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import ClassVar
from src.common.logger import get_logger
@@ -21,7 +22,7 @@ class BaseEventHandler(ABC):
"""处理器权重,越大权重越高"""
intercept_message: bool = False
"""是否拦截消息,默认为否"""
init_subscribe: list[EventType | str] = [EventType.UNKNOWN]
init_subscribe: ClassVar[list[EventType | str]] = [EventType.UNKNOWN]
"""初始化时订阅的事件名称"""
plugin_name = None

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, ClassVar
from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger
@@ -27,7 +27,7 @@ class BasePrompt(ABC):
# 定义此组件希望如何注入到核心Prompt中
# 这是一个 InjectionRule 对象的列表,可以实现复杂的注入逻辑
# 例如: [InjectionRule(target_prompt="planner_prompt", injection_type=InjectionType.APPEND, priority=50)]
injection_rules: list[InjectionRule] = []
injection_rules: ClassVar[list[InjectionRule] ] = []
"""定义注入规则的列表"""
# 旧的注入点定义,用于向后兼容。如果定义了这个,它将被自动转换为 injection_rules。

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, ClassVar
from rich.traceback import install
@@ -18,7 +18,7 @@ class BaseTool(ABC):
"""工具的名称"""
description: str = ""
"""工具的描述"""
parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]] ] = []
"""工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
param_name: 参数名称
param_type: 参数类型
@@ -44,7 +44,7 @@ class BaseTool(ABC):
"""是否为二步工具。如果为True工具将分两步调用第一步展示工具信息第二步执行具体操作"""
step_one_description: str = ""
"""第一步的描述用于向LLM展示工具的基本功能"""
sub_tools: list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] = []
sub_tools: ClassVar[list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] ] = []
"""子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用"""
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
@@ -112,7 +112,7 @@ class BaseTool(ABC):
if not cls.is_two_step_tool:
return []
definitions = []
definitions: ClassVar = []
for sub_name, sub_desc, sub_params in cls.sub_tools:
definitions.append({"name": f"{cls.name}_{sub_name}", "description": sub_desc, "parameters": sub_params})
return definitions

View File

@@ -3,7 +3,7 @@ import os
import shutil
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
from typing import Any, ClassVar
import toml
@@ -30,11 +30,11 @@ class PluginBase(ABC):
config_file_name: str
enable_plugin: bool = True
config_schema: dict[str, dict[str, ConfigField] | str] = {}
config_schema: ClassVar[dict[str, dict[str, ConfigField] | str] ] = {}
permission_nodes: list["PermissionNodeField"] = []
permission_nodes: ClassVar[list["PermissionNodeField"] ] = []
config_section_descriptions: dict[str, str] = {}
config_section_descriptions: ClassVar[dict[str, str] ] = {}
def __init__(self, plugin_dir: str, metadata: PluginMetadata):
"""初始化插件
@@ -206,12 +206,12 @@ class PluginBase(ABC):
if not self.config_schema:
return {}
config_data = {}
config_data: ClassVar = {}
# 遍历每个配置节
for section, fields in self.config_schema.items():
if isinstance(fields, dict):
section_data = {}
section_data: ClassVar = {}
# 遍历节内的字段
for field_name, field in fields.items():
@@ -331,7 +331,7 @@ class PluginBase(ABC):
try:
with open(user_config_path, encoding="utf-8") as f:
user_config = toml.load(f) or {}
user_config: ClassVar = toml.load(f) or {}
except Exception as e:
logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True)
self.config = self._generate_config_from_schema() # 加载失败时使用默认 schema

View File

@@ -5,7 +5,7 @@
import re
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
@@ -42,7 +42,7 @@ class PlusCommand(ABC):
command_description: str = ""
"""命令描述"""
command_aliases: list[str] = []
command_aliases: ClassVar[list[str] ] = []
"""命令别名列表,如 ['say', 'repeat']"""
priority: int = 0
@@ -435,7 +435,3 @@ def create_plus_command_adapter(plus_command_class):
return AdapterClass
# 兼容旧的命名
PlusCommandAdapter = create_plus_command_adapter

View File

@@ -87,8 +87,8 @@ class ComponentRegistry:
self._tool_registry: dict[str, type["BaseTool"]] = {} # 工具名 -> 工具类
self._llm_available_tools: dict[str, type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
# MCP 工具注册表运行时动态加载
self._mcp_tools: list["BaseTool"] = [] # MCP 工具适配器实例列表
# MCP 工具注册表(运行时动态加载)
self._mcp_tools: list[Any] = [] # MCP 工具适配器实例列表
self._mcp_tools_loaded = False # MCP 工具是否已加载
# EventHandler特定注册表

View File

@@ -7,7 +7,6 @@ from threading import Lock
from typing import Any, Optional
from src.common.logger import get_logger
from src.plugin_system import BaseEventHandler
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.component_types import EventType
@@ -176,10 +175,10 @@ class EventManager:
# 处理init_subscribe缓存失败的订阅
if self._event_handlers[handler_name].init_subscribe:
failed_subscriptions = []
for event_name in self._event_handlers[handler_name].init_subscribe:
if not self.subscribe_handler_to_event(handler_name, event_name):
failed_subscriptions.append(event_name)
failed_subscriptions = [
event_name for event_name in self._event_handlers[handler_name].init_subscribe
if not self.subscribe_handler_to_event(handler_name, event_name)
]
# 缓存失败的订阅
if failed_subscriptions:

View File

@@ -4,7 +4,7 @@ MCP Tool Adapter
将 MCP 工具适配为 BaseTool使其能够被插件系统识别和调用
"""
from typing import Any, ClassVar
from typing import Any
import mcp.types
@@ -27,9 +27,6 @@ class MCPToolAdapter(BaseTool):
3. 参与工具缓存机制
"""
# 类级别默认值,使用 ClassVar 标注
available_for_llm: ClassVar[bool] = True
def __init__(self, server_name: str, mcp_tool: mcp.types.Tool, plugin_config: dict | None = None):
"""
初始化 MCP 工具适配器
@@ -47,6 +44,7 @@ class MCPToolAdapter(BaseTool):
# 设置实例属性
self.name = f"mcp_{server_name}_{mcp_tool.name}"
self.description = mcp_tool.description or f"MCP tool from {server_name}"
self.available_for_llm = True # MCP 工具默认可供 LLM 使用
# 转换参数定义
self.parameters = self._convert_parameters(mcp_tool.inputSchema)

View File

@@ -456,8 +456,7 @@ class PermissionManager(IPermissionManager):
)
granted_users = result.scalars().all()
for user_perm in granted_users:
users.append((user_perm.platform, user_perm.user_id))
users.extend((user_perm.platform, user_perm.user_id) for user_perm in granted_users)
# 如果是默认授权的权限节点,还需要考虑没有明确设置的用户
# 但这里我们只返回明确授权的用户,避免返回所有用户

View File

@@ -94,7 +94,6 @@ class PluginManager:
if not plugin_class:
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
return False, 1
init_module = None # 预先定义,避免后续条件加载导致未绑定
try:
# 使用记录的插件目录路径
plugin_dir = self.plugin_paths.get(plugin_name)