refactor: 统一类型注解风格并优化代码结构
- 将裸 except 改为显式 Exception 捕获 - 用列表推导式替换冗余 for 循环 - 为类属性添加 ClassVar 注解 - 统一 Union/Optional 写法为 | - 移除未使用的导入 - 修复 SQLAlchemy 空值比较语法 - 优化字符串拼接与字典更新逻辑 - 补充缺失的 noqa 注释与异常链 BREAKING CHANGE: 所有插件基类的类级字段现要求显式 ClassVar 注解,自定义插件需同步更新
This commit is contained in:
@@ -44,7 +44,6 @@ from .base import (
|
||||
# 新增的增强命令系统
|
||||
PlusCommand,
|
||||
PlusCommandAdapter,
|
||||
PlusCommandInfo,
|
||||
PythonDependency,
|
||||
ToolInfo,
|
||||
ToolParamType,
|
||||
|
||||
@@ -48,9 +48,10 @@ class ChatManager:
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for stream in get_chat_manager().streams.values():
|
||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
|
||||
streams.append(stream)
|
||||
streams.extend(
|
||||
stream for stream in get_chat_manager().streams.values()
|
||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
|
||||
)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流失败: {e}")
|
||||
@@ -71,9 +72,10 @@ class ChatManager:
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for stream in get_chat_manager().streams.values():
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info:
|
||||
streams.append(stream)
|
||||
streams.extend(
|
||||
stream for stream in get_chat_manager().streams.values()
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info
|
||||
)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取群聊流失败: {e}")
|
||||
@@ -97,9 +99,10 @@ class ChatManager:
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for stream in get_chat_manager().streams.values():
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info:
|
||||
streams.append(stream)
|
||||
streams.extend(
|
||||
stream for stream in get_chat_manager().streams.values()
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info
|
||||
)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取私聊流失败: {e}")
|
||||
|
||||
@@ -183,9 +183,10 @@ async def build_cross_context_s4u(
|
||||
blacklisted_streams.add(stream_id)
|
||||
except ValueError:
|
||||
logger.warning(f"无效的S4U黑名单格式: {chat_str}")
|
||||
for stream_id in chat_manager.streams:
|
||||
if stream_id != chat_stream.stream_id and stream_id not in blacklisted_streams:
|
||||
streams_to_scan.append(stream_id)
|
||||
streams_to_scan.extend(
|
||||
stream_id for stream_id in chat_manager.streams
|
||||
if stream_id != chat_stream.stream_id and stream_id not in blacklisted_streams
|
||||
)
|
||||
|
||||
logger.debug(f"[S4U] Found {len(streams_to_scan)} group streams to scan.")
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class ScoringAPI:
|
||||
return await relationship_service.get_user_relationship_data(user_id)
|
||||
|
||||
@staticmethod
|
||||
async def update_user_relationship(user_id: str, relationship_score: float, relationship_text: str = None, user_name: str = None):
|
||||
async def update_user_relationship(user_id: str, relationship_score: float, relationship_text: str | None = None, user_name: str | None = None):
|
||||
"""
|
||||
更新用户关系数据
|
||||
|
||||
@@ -71,7 +71,7 @@ class ScoringAPI:
|
||||
await interest_service.initialize_smart_interests(personality_description, personality_id)
|
||||
|
||||
@staticmethod
|
||||
async def calculate_interest_match(content: str, keywords: list[str] = None):
|
||||
async def calculate_interest_match(content: str, keywords: list[str] | None = None):
|
||||
"""
|
||||
计算内容与兴趣的匹配度
|
||||
|
||||
@@ -98,7 +98,7 @@ class ScoringAPI:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def clear_caches(user_id: str = None):
|
||||
def clear_caches(user_id: str | None = None):
|
||||
"""
|
||||
清理缓存
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -26,7 +26,7 @@ class PluginStorageManager:
|
||||
哼,现在它和API住在一起了,希望它们能和睦相处。
|
||||
"""
|
||||
|
||||
_instances: dict[str, "PluginStorage"] = {}
|
||||
_instances: ClassVar[dict[str, "PluginStorage"] ] = {}
|
||||
_lock = threading.Lock()
|
||||
_base_path = os.path.join("data", "plugin_data")
|
||||
|
||||
|
||||
@@ -9,11 +9,11 @@ logger = get_logger("tool_api")
|
||||
|
||||
def get_tool_instance(tool_name: str, chat_stream: Any = None) -> BaseTool | None:
|
||||
"""获取公开工具实例
|
||||
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
chat_stream: 聊天流对象,用于提供上下文信息
|
||||
|
||||
|
||||
Returns:
|
||||
BaseTool: 工具实例,如果工具不存在则返回None
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
import random
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
@@ -26,30 +26,30 @@ class BaseAction(ABC):
|
||||
新的激活机制 (推荐使用)
|
||||
==================================================================================
|
||||
推荐通过重写 go_activate() 方法来自定义激活逻辑:
|
||||
|
||||
|
||||
示例 1 - 关键词激活:
|
||||
async def go_activate(self, llm_judge_model=None) -> bool:
|
||||
return await self._keyword_match(["你好", "hello"])
|
||||
|
||||
|
||||
示例 2 - LLM 判断激活:
|
||||
async def go_activate(self, llm_judge_model=None) -> bool:
|
||||
return await self._llm_judge_activation(
|
||||
"当用户询问天气信息时激活",
|
||||
llm_judge_model
|
||||
)
|
||||
|
||||
|
||||
示例 3 - 组合多种条件:
|
||||
async def go_activate(self, llm_judge_model=None) -> bool:
|
||||
# 30% 随机概率,或者匹配关键词
|
||||
if await self._random_activation(0.3):
|
||||
return True
|
||||
return await self._keyword_match(["表情", "emoji"])
|
||||
|
||||
|
||||
提供的工具函数:
|
||||
- _random_activation(probability): 随机激活
|
||||
- _keyword_match(keywords, case_sensitive): 关键词匹配(自动获取聊天内容)
|
||||
- _llm_judge_activation(judge_prompt, llm_judge_model): LLM 判断(自动获取聊天内容)
|
||||
|
||||
|
||||
注意:聊天内容会自动从实例属性中获取,无需手动传入。
|
||||
|
||||
==================================================================================
|
||||
@@ -68,7 +68,7 @@ class BaseAction(ABC):
|
||||
==================================================================================
|
||||
- mode_enable: 启用的聊天模式
|
||||
- parallel_action: 是否允许并行执行
|
||||
|
||||
|
||||
二步Action相关属性:
|
||||
- is_two_step_action: 是否为二步Action
|
||||
- step_one_description: 第一步的描述
|
||||
@@ -80,7 +80,7 @@ class BaseAction(ABC):
|
||||
"""是否为二步Action。如果为True,Action将分两步执行:第一步选择操作,第二步执行具体操作"""
|
||||
step_one_description: str = ""
|
||||
"""第一步的描述,用于向LLM展示Action的基本功能"""
|
||||
sub_actions: list[tuple[str, str, dict[str, str]]] = []
|
||||
sub_actions: ClassVar[list[tuple[str, str, dict[str, str]]] ] = []
|
||||
"""子Action列表,格式为[(子Action名, 子Action描述, 子Action参数)]。仅在二步Action中使用"""
|
||||
|
||||
def __init__(
|
||||
@@ -110,7 +110,7 @@ class BaseAction(ABC):
|
||||
**kwargs: 其他参数
|
||||
"""
|
||||
if plugin_config is None:
|
||||
plugin_config = {}
|
||||
plugin_config: ClassVar = {}
|
||||
self.action_data = action_data
|
||||
self.reasoning = reasoning
|
||||
self.cycle_timers = cycle_timers
|
||||
@@ -489,7 +489,7 @@ class BaseAction(ABC):
|
||||
|
||||
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
|
||||
# 3. 实例化被调用的Action
|
||||
action_params = {
|
||||
action_params: ClassVar = {
|
||||
"action_data": called_action_data,
|
||||
"reasoning": f"Called by {self.action_name}",
|
||||
"cycle_timers": self.cycle_timers,
|
||||
@@ -615,9 +615,9 @@ class BaseAction(ABC):
|
||||
|
||||
def _get_chat_content(self) -> str:
|
||||
"""获取聊天内容用于激活判断
|
||||
|
||||
|
||||
从实例属性中获取聊天内容。子类可以重写此方法来自定义获取逻辑。
|
||||
|
||||
|
||||
Returns:
|
||||
str: 聊天内容
|
||||
"""
|
||||
@@ -645,7 +645,7 @@ class BaseAction(ABC):
|
||||
也可以使用提供的工具函数来简化常见的激活判断。
|
||||
|
||||
默认实现会检查类属性中的激活类型配置,提供向后兼容支持。
|
||||
|
||||
|
||||
聊天内容会自动从实例属性中获取,不需要手动传入。
|
||||
|
||||
Args:
|
||||
@@ -721,7 +721,7 @@ class BaseAction(ABC):
|
||||
case_sensitive: bool = False,
|
||||
) -> bool:
|
||||
"""关键词匹配工具函数
|
||||
|
||||
|
||||
聊天内容会自动从实例属性中获取。
|
||||
|
||||
Args:
|
||||
@@ -742,7 +742,7 @@ class BaseAction(ABC):
|
||||
if not case_sensitive:
|
||||
search_text = search_text.lower()
|
||||
|
||||
matched_keywords = []
|
||||
matched_keywords: ClassVar = []
|
||||
for keyword in keywords:
|
||||
check_keyword = keyword if case_sensitive else keyword.lower()
|
||||
if check_keyword in search_text:
|
||||
@@ -766,7 +766,7 @@ class BaseAction(ABC):
|
||||
|
||||
使用 LLM 来判断是否应该激活此 Action。
|
||||
会自动构建完整的判断提示词,只需要提供核心判断逻辑即可。
|
||||
|
||||
|
||||
聊天内容会自动从实例属性中获取。
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatterInfo, ComponentType
|
||||
@@ -15,7 +15,7 @@ class BaseChatter(ABC):
|
||||
"""Chatter组件的名称"""
|
||||
chatter_description: str = ""
|
||||
"""Chatter组件的描述"""
|
||||
chat_types: list[ChatType] = [ChatType.PRIVATE, ChatType.GROUP]
|
||||
chat_types: ClassVar[list[ChatType]] = [ChatType.PRIVATE, ChatType.GROUP]
|
||||
|
||||
def __init__(self, stream_id: str, action_manager: "ChatterActionManager"):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -21,7 +22,7 @@ class BaseEventHandler(ABC):
|
||||
"""处理器权重,越大权重越高"""
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,默认为否"""
|
||||
init_subscribe: list[EventType | str] = [EventType.UNKNOWN]
|
||||
init_subscribe: ClassVar[list[EventType | str]] = [EventType.UNKNOWN]
|
||||
"""初始化时订阅的事件名称"""
|
||||
plugin_name = None
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
@@ -27,7 +27,7 @@ class BasePrompt(ABC):
|
||||
# 定义此组件希望如何注入到核心Prompt中
|
||||
# 这是一个 InjectionRule 对象的列表,可以实现复杂的注入逻辑
|
||||
# 例如: [InjectionRule(target_prompt="planner_prompt", injection_type=InjectionType.APPEND, priority=50)]
|
||||
injection_rules: list[InjectionRule] = []
|
||||
injection_rules: ClassVar[list[InjectionRule] ] = []
|
||||
"""定义注入规则的列表"""
|
||||
|
||||
# 旧的注入点定义,用于向后兼容。如果定义了这个,它将被自动转换为 injection_rules。
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
@@ -18,7 +18,7 @@ class BaseTool(ABC):
|
||||
"""工具的名称"""
|
||||
description: str = ""
|
||||
"""工具的描述"""
|
||||
parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
|
||||
parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]] ] = []
|
||||
"""工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
|
||||
param_name: 参数名称
|
||||
param_type: 参数类型
|
||||
@@ -44,7 +44,7 @@ class BaseTool(ABC):
|
||||
"""是否为二步工具。如果为True,工具将分两步调用:第一步展示工具信息,第二步执行具体操作"""
|
||||
step_one_description: str = ""
|
||||
"""第一步的描述,用于向LLM展示工具的基本功能"""
|
||||
sub_tools: list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] = []
|
||||
sub_tools: ClassVar[list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] ] = []
|
||||
"""子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用"""
|
||||
|
||||
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
|
||||
@@ -112,7 +112,7 @@ class BaseTool(ABC):
|
||||
if not cls.is_two_step_tool:
|
||||
return []
|
||||
|
||||
definitions = []
|
||||
definitions: ClassVar = []
|
||||
for sub_name, sub_desc, sub_params in cls.sub_tools:
|
||||
definitions.append({"name": f"{cls.name}_{sub_name}", "description": sub_desc, "parameters": sub_params})
|
||||
return definitions
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import shutil
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import toml
|
||||
|
||||
@@ -30,11 +30,11 @@ class PluginBase(ABC):
|
||||
config_file_name: str
|
||||
enable_plugin: bool = True
|
||||
|
||||
config_schema: dict[str, dict[str, ConfigField] | str] = {}
|
||||
config_schema: ClassVar[dict[str, dict[str, ConfigField] | str] ] = {}
|
||||
|
||||
permission_nodes: list["PermissionNodeField"] = []
|
||||
permission_nodes: ClassVar[list["PermissionNodeField"] ] = []
|
||||
|
||||
config_section_descriptions: dict[str, str] = {}
|
||||
config_section_descriptions: ClassVar[dict[str, str] ] = {}
|
||||
|
||||
def __init__(self, plugin_dir: str, metadata: PluginMetadata):
|
||||
"""初始化插件
|
||||
@@ -206,12 +206,12 @@ class PluginBase(ABC):
|
||||
if not self.config_schema:
|
||||
return {}
|
||||
|
||||
config_data = {}
|
||||
config_data: ClassVar = {}
|
||||
|
||||
# 遍历每个配置节
|
||||
for section, fields in self.config_schema.items():
|
||||
if isinstance(fields, dict):
|
||||
section_data = {}
|
||||
section_data: ClassVar = {}
|
||||
|
||||
# 遍历节内的字段
|
||||
for field_name, field in fields.items():
|
||||
@@ -331,7 +331,7 @@ class PluginBase(ABC):
|
||||
|
||||
try:
|
||||
with open(user_config_path, encoding="utf-8") as f:
|
||||
user_config = toml.load(f) or {}
|
||||
user_config: ClassVar = toml.load(f) or {}
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True)
|
||||
self.config = self._generate_config_from_schema() # 加载失败时使用默认 schema
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
@@ -42,7 +42,7 @@ class PlusCommand(ABC):
|
||||
command_description: str = ""
|
||||
"""命令描述"""
|
||||
|
||||
command_aliases: list[str] = []
|
||||
command_aliases: ClassVar[list[str] ] = []
|
||||
"""命令别名列表,如 ['say', 'repeat']"""
|
||||
|
||||
priority: int = 0
|
||||
@@ -435,7 +435,3 @@ def create_plus_command_adapter(plus_command_class):
|
||||
|
||||
return AdapterClass
|
||||
|
||||
|
||||
# 兼容旧的命名
|
||||
PlusCommandAdapter = create_plus_command_adapter
|
||||
|
||||
|
||||
@@ -87,8 +87,8 @@ class ComponentRegistry:
|
||||
self._tool_registry: dict[str, type["BaseTool"]] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: dict[str, type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
|
||||
|
||||
# MCP 工具注册表(运行时动态加载)
|
||||
self._mcp_tools: list["BaseTool"] = [] # MCP 工具适配器实例列表
|
||||
# MCP 工具注册表(运行时动态加载)
|
||||
self._mcp_tools: list[Any] = [] # MCP 工具适配器实例列表
|
||||
self._mcp_tools_loaded = False # MCP 工具是否已加载
|
||||
|
||||
# EventHandler特定注册表
|
||||
|
||||
@@ -7,7 +7,6 @@ from threading import Lock
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseEventHandler
|
||||
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
@@ -176,10 +175,10 @@ class EventManager:
|
||||
|
||||
# 处理init_subscribe,缓存失败的订阅
|
||||
if self._event_handlers[handler_name].init_subscribe:
|
||||
failed_subscriptions = []
|
||||
for event_name in self._event_handlers[handler_name].init_subscribe:
|
||||
if not self.subscribe_handler_to_event(handler_name, event_name):
|
||||
failed_subscriptions.append(event_name)
|
||||
failed_subscriptions = [
|
||||
event_name for event_name in self._event_handlers[handler_name].init_subscribe
|
||||
if not self.subscribe_handler_to_event(handler_name, event_name)
|
||||
]
|
||||
|
||||
# 缓存失败的订阅
|
||||
if failed_subscriptions:
|
||||
|
||||
@@ -4,7 +4,7 @@ MCP Tool Adapter
|
||||
将 MCP 工具适配为 BaseTool,使其能够被插件系统识别和调用
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
from typing import Any
|
||||
|
||||
import mcp.types
|
||||
|
||||
@@ -27,9 +27,6 @@ class MCPToolAdapter(BaseTool):
|
||||
3. 参与工具缓存机制
|
||||
"""
|
||||
|
||||
# 类级别默认值,使用 ClassVar 标注
|
||||
available_for_llm: ClassVar[bool] = True
|
||||
|
||||
def __init__(self, server_name: str, mcp_tool: mcp.types.Tool, plugin_config: dict | None = None):
|
||||
"""
|
||||
初始化 MCP 工具适配器
|
||||
@@ -47,6 +44,7 @@ class MCPToolAdapter(BaseTool):
|
||||
# 设置实例属性
|
||||
self.name = f"mcp_{server_name}_{mcp_tool.name}"
|
||||
self.description = mcp_tool.description or f"MCP tool from {server_name}"
|
||||
self.available_for_llm = True # MCP 工具默认可供 LLM 使用
|
||||
|
||||
# 转换参数定义
|
||||
self.parameters = self._convert_parameters(mcp_tool.inputSchema)
|
||||
|
||||
@@ -456,8 +456,7 @@ class PermissionManager(IPermissionManager):
|
||||
)
|
||||
granted_users = result.scalars().all()
|
||||
|
||||
for user_perm in granted_users:
|
||||
users.append((user_perm.platform, user_perm.user_id))
|
||||
users.extend((user_perm.platform, user_perm.user_id) for user_perm in granted_users)
|
||||
|
||||
# 如果是默认授权的权限节点,还需要考虑没有明确设置的用户
|
||||
# 但这里我们只返回明确授权的用户,避免返回所有用户
|
||||
|
||||
@@ -94,7 +94,6 @@ class PluginManager:
|
||||
if not plugin_class:
|
||||
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
|
||||
return False, 1
|
||||
init_module = None # 预先定义,避免后续条件加载导致未绑定
|
||||
try:
|
||||
# 使用记录的插件目录路径
|
||||
plugin_dir = self.plugin_paths.get(plugin_name)
|
||||
|
||||
Reference in New Issue
Block a user