style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
e7aaafde2f
commit
00ba07e0e1
@@ -19,7 +19,7 @@ from src.plugin_system.apis import (
|
||||
send_api,
|
||||
tool_api,
|
||||
permission_api,
|
||||
schedule_api
|
||||
schedule_api,
|
||||
)
|
||||
from src.plugin_system.apis.chat_api import ChatManager as context_api
|
||||
from .logging_api import get_logger
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
readable_text = message_api.build_readable_messages(messages)
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional, Coroutine
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
@@ -181,9 +181,7 @@ async def get_messages_by_time_in_chat_for_users(
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
return await get_raw_msg_by_timestamp_with_chat_users(
|
||||
chat_id, start_time, end_time, person_ids, limit, limit_mode
|
||||
)
|
||||
return await get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode)
|
||||
|
||||
|
||||
async def get_random_chat_messages(
|
||||
@@ -384,9 +382,7 @@ async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Op
|
||||
return await num_new_messages_since(chat_id, start_time, end_time)
|
||||
|
||||
|
||||
async def count_new_messages_for_users(
|
||||
chat_id: str, start_time: float, end_time: float, person_ids: List[str]
|
||||
) -> int:
|
||||
async def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
|
||||
"""
|
||||
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
|
||||
|
||||
|
||||
@@ -61,8 +61,7 @@ class PermissionAPI:
|
||||
def __init__(self):
|
||||
self._permission_manager: Optional[IPermissionManager] = None
|
||||
# 需要保留的前缀(视为绝对节点名,不再自动加 plugins.<plugin>. 前缀)
|
||||
self.RESERVED_PREFIXES: tuple[str, ...] = (
|
||||
"system.")
|
||||
self.RESERVED_PREFIXES: tuple[str, ...] = "system."
|
||||
# 系统节点列表 (name, description, default_granted)
|
||||
self._SYSTEM_NODES: list[tuple[str, str, bool]] = [
|
||||
("system.superuser", "系统超级管理员:拥有所有权限", False),
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
@@ -176,4 +177,4 @@ async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool:
|
||||
|
||||
async def archive_monthly_plans(target_month: Optional[str] = None) -> bool:
|
||||
"""(异步) 归档指定月份的月度计划的便捷函数"""
|
||||
return await ScheduleAPI.archive_monthly_plans(target_month)
|
||||
return await ScheduleAPI.archive_monthly_plans(target_month)
|
||||
|
||||
@@ -80,7 +80,9 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
|
||||
|
||||
message_info = {
|
||||
"platform": message_dict.get("chat_info_platform", ""),
|
||||
"message_id": message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id"),
|
||||
"message_id": message_dict.get("message_id")
|
||||
or message_dict.get("chat_info_message_id")
|
||||
or message_dict.get("id"),
|
||||
"time": message_dict.get("time"),
|
||||
"group_info": group_info,
|
||||
"user_info": user_info,
|
||||
|
||||
@@ -2,7 +2,7 @@ import time
|
||||
import asyncio
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
from typing import Tuple, Optional, List, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
@@ -25,7 +25,7 @@ class BaseAction(ABC):
|
||||
- parallel_action: 是否允许并行执行
|
||||
- random_activation_probability: 随机激活概率
|
||||
- llm_judge_prompt: LLM判断提示词
|
||||
|
||||
|
||||
二步Action相关属性:
|
||||
- is_two_step_action: 是否为二步Action
|
||||
- step_one_description: 第一步的描述
|
||||
@@ -435,7 +435,9 @@ class BaseAction(ABC):
|
||||
|
||||
# 确保获取的是Action组件
|
||||
if component_info.component_type != ComponentType.ACTION:
|
||||
logger.error(f"{log_prefix} 尝试调用的组件 '{action_name}' 不是一个Action,而是一个 '{component_info.component_type.value}'")
|
||||
logger.error(
|
||||
f"{log_prefix} 尝试调用的组件 '{action_name}' 不是一个Action,而是一个 '{component_info.component_type.value}'"
|
||||
)
|
||||
return False, f"组件 '{action_name}' 不是一个有效的Action"
|
||||
|
||||
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
|
||||
@@ -528,20 +530,20 @@ class BaseAction(ABC):
|
||||
# 第一步:展示可用的子Action
|
||||
available_actions = [sub_action[0] for sub_action in self.sub_actions]
|
||||
description = self.step_one_description or f"{self.action_name}支持以下操作"
|
||||
|
||||
|
||||
actions_list = "\n".join([f"- {action}: {desc}" for action, desc, _ in self.sub_actions])
|
||||
response = f"{description}\n\n可用操作:\n{actions_list}\n\n请选择要执行的操作。"
|
||||
|
||||
|
||||
return True, response
|
||||
else:
|
||||
# 验证选择的子Action是否有效
|
||||
valid_actions = [sub_action[0] for sub_action in self.sub_actions]
|
||||
if selected_action not in valid_actions:
|
||||
return False, f"无效的操作选择: {selected_action}。可用操作: {valid_actions}"
|
||||
|
||||
|
||||
# 保存选择的子Action
|
||||
self._selected_sub_action = selected_action
|
||||
|
||||
|
||||
# 调用第二步执行
|
||||
return await self.execute_step_two(selected_action)
|
||||
|
||||
@@ -572,7 +574,7 @@ class BaseAction(ABC):
|
||||
# 如果是二步Action,自动处理第一步
|
||||
if self.is_two_step_action:
|
||||
return await self.handle_step_one()
|
||||
|
||||
|
||||
# 普通Action由子类实现
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
from typing import List, TYPE_CHECKING
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from .component_types import ChatType
|
||||
from src.plugin_system.base.component_types import ChatterInfo, ComponentType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner
|
||||
|
||||
|
||||
class BaseChatter(ABC):
|
||||
chatter_name: str = ""
|
||||
@@ -15,7 +15,7 @@ class BaseChatter(ABC):
|
||||
"""Chatter组件的描述"""
|
||||
chat_types: List[ChatType] = [ChatType.PRIVATE, ChatType.GROUP]
|
||||
|
||||
def __init__(self, stream_id: str, action_manager: 'ChatterActionManager'):
|
||||
def __init__(self, stream_id: str, action_manager: "ChatterActionManager"):
|
||||
"""
|
||||
初始化聊天处理器
|
||||
|
||||
@@ -45,11 +45,10 @@ class BaseChatter(ABC):
|
||||
Returns:
|
||||
ChatterInfo对象
|
||||
"""
|
||||
|
||||
|
||||
return ChatterInfo(
|
||||
name=cls.chatter_name,
|
||||
description=cls.chatter_description or "No description provided.",
|
||||
chat_type_allow=cls.chat_types[0],
|
||||
component_type=ComponentType.CHATTER,
|
||||
)
|
||||
|
||||
|
||||
@@ -64,7 +64,15 @@ class BaseTool(ABC):
|
||||
return {
|
||||
"name": cls.name,
|
||||
"description": cls.step_one_description or cls.description,
|
||||
"parameters": [("action", ToolParamType.STRING, "选择要执行的操作", True, [sub_tool[0] for sub_tool in cls.sub_tools])]
|
||||
"parameters": [
|
||||
(
|
||||
"action",
|
||||
ToolParamType.STRING,
|
||||
"选择要执行的操作",
|
||||
True,
|
||||
[sub_tool[0] for sub_tool in cls.sub_tools],
|
||||
)
|
||||
],
|
||||
}
|
||||
else:
|
||||
# 普通工具需要parameters
|
||||
@@ -88,12 +96,8 @@ class BaseTool(ABC):
|
||||
# 查找对应的子工具
|
||||
for sub_name, sub_desc, sub_params in cls.sub_tools:
|
||||
if sub_name == sub_tool_name:
|
||||
return {
|
||||
"name": f"{cls.name}_{sub_tool_name}",
|
||||
"description": sub_desc,
|
||||
"parameters": sub_params
|
||||
}
|
||||
|
||||
return {"name": f"{cls.name}_{sub_tool_name}", "description": sub_desc, "parameters": sub_params}
|
||||
|
||||
raise ValueError(f"未找到子工具: {sub_tool_name}")
|
||||
|
||||
@classmethod
|
||||
@@ -105,14 +109,10 @@ class BaseTool(ABC):
|
||||
"""
|
||||
if not cls.is_two_step_tool:
|
||||
return []
|
||||
|
||||
|
||||
definitions = []
|
||||
for sub_name, sub_desc, sub_params in cls.sub_tools:
|
||||
definitions.append({
|
||||
"name": f"{cls.name}_{sub_name}",
|
||||
"description": sub_desc,
|
||||
"parameters": sub_params
|
||||
})
|
||||
definitions.append({"name": f"{cls.name}_{sub_name}", "description": sub_desc, "parameters": sub_params})
|
||||
return definitions
|
||||
|
||||
@classmethod
|
||||
@@ -144,7 +144,7 @@ class BaseTool(ABC):
|
||||
# 如果是二步工具,处理第一步调用
|
||||
if self.is_two_step_tool and "action" in function_args:
|
||||
return await self._handle_step_one(function_args)
|
||||
|
||||
|
||||
raise NotImplementedError("子类必须实现execute方法")
|
||||
|
||||
async def _handle_step_one(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -174,17 +174,13 @@ class BaseTool(ABC):
|
||||
sub_name, sub_desc, sub_params = sub_tool_found
|
||||
|
||||
# 返回第二步工具定义
|
||||
step_two_definition = {
|
||||
"name": f"{self.name}_{sub_name}",
|
||||
"description": sub_desc,
|
||||
"parameters": sub_params
|
||||
}
|
||||
step_two_definition = {"name": f"{self.name}_{sub_name}", "description": sub_desc, "parameters": sub_params}
|
||||
|
||||
return {
|
||||
"type": "two_step_tool_step_one",
|
||||
"content": f"已选择操作: {action}。请使用以下工具进行具体调用:",
|
||||
"next_tool_definition": step_two_definition,
|
||||
"selected_action": action
|
||||
"selected_action": action,
|
||||
}
|
||||
|
||||
async def execute_step_two(self, sub_tool_name: str, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
@@ -40,7 +40,7 @@ class ActionActivationType(Enum):
|
||||
# 聊天模式枚举
|
||||
class ChatMode(Enum):
|
||||
"""聊天模式枚举"""
|
||||
|
||||
|
||||
FOCUS = "focus" # 专注模式
|
||||
NORMAL = "normal" # Normal聊天模式
|
||||
PROACTIVE = "proactive" # 主动思考模式
|
||||
|
||||
@@ -294,9 +294,7 @@ class PluginBase(ABC):
|
||||
changed = False
|
||||
|
||||
# 内部递归函数
|
||||
def _sync_dicts(
|
||||
schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
def _sync_dicts(schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, Any]:
|
||||
nonlocal changed
|
||||
synced_dict = schema_dict.copy()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
@@ -34,44 +34,46 @@ class ComponentRegistry:
|
||||
|
||||
def __init__(self):
|
||||
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
|
||||
self._components: Dict[str, 'ComponentInfo'] = {}
|
||||
self._components: Dict[str, "ComponentInfo"] = {}
|
||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||
self._components_by_type: Dict['ComponentType', Dict[str, 'ComponentInfo']] = {types: {} for types in ComponentType}
|
||||
self._components_by_type: Dict["ComponentType", Dict[str, "ComponentInfo"]] = {
|
||||
types: {} for types in ComponentType
|
||||
}
|
||||
"""类型 -> 组件原名称 -> 组件信息"""
|
||||
self._components_classes: Dict[
|
||||
str, Type[Union['BaseCommand', 'BaseAction', 'BaseTool', 'BaseEventHandler', 'PlusCommand', 'BaseChatter']]
|
||||
str, Type[Union["BaseCommand", "BaseAction", "BaseTool", "BaseEventHandler", "PlusCommand", "BaseChatter"]]
|
||||
] = {}
|
||||
"""命名空间式组件名 -> 组件类"""
|
||||
|
||||
# 插件注册表
|
||||
self._plugins: Dict[str, 'PluginInfo'] = {}
|
||||
self._plugins: Dict[str, "PluginInfo"] = {}
|
||||
"""插件名 -> 插件信息"""
|
||||
|
||||
# Action特定注册表
|
||||
self._action_registry: Dict[str, Type['BaseAction']] = {}
|
||||
self._action_registry: Dict[str, Type["BaseAction"]] = {}
|
||||
"""Action注册表 action名 -> action类"""
|
||||
self._default_actions: Dict[str, 'ActionInfo'] = {}
|
||||
self._default_actions: Dict[str, "ActionInfo"] = {}
|
||||
"""默认动作集,即启用的Action集,用于重置ActionManager状态"""
|
||||
|
||||
# Command特定注册表
|
||||
self._command_registry: Dict[str, Type['BaseCommand']] = {}
|
||||
self._command_registry: Dict[str, Type["BaseCommand"]] = {}
|
||||
"""Command类注册表 command名 -> command类"""
|
||||
self._command_patterns: Dict[Pattern, str] = {}
|
||||
"""编译后的正则 -> command名"""
|
||||
|
||||
# 工具特定注册表
|
||||
self._tool_registry: Dict[str, Type['BaseTool']] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: Dict[str, Type['BaseTool']] = {} # llm可用的工具名 -> 工具类
|
||||
self._tool_registry: Dict[str, Type["BaseTool"]] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: Dict[str, Type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
|
||||
|
||||
# EventHandler特定注册表
|
||||
self._event_handler_registry: Dict[str, Type['BaseEventHandler']] = {}
|
||||
self._event_handler_registry: Dict[str, Type["BaseEventHandler"]] = {}
|
||||
"""event_handler名 -> event_handler类"""
|
||||
self._enabled_event_handlers: Dict[str, Type['BaseEventHandler']] = {}
|
||||
self._enabled_event_handlers: Dict[str, Type["BaseEventHandler"]] = {}
|
||||
"""启用的事件处理器 event_handler名 -> event_handler类"""
|
||||
|
||||
self._chatter_registry: Dict[str, Type['BaseChatter']] = {}
|
||||
self._chatter_registry: Dict[str, Type["BaseChatter"]] = {}
|
||||
"""chatter名 -> chatter类"""
|
||||
self._enabled_chatter_registry: Dict[str, Type['BaseChatter']] = {}
|
||||
self._enabled_chatter_registry: Dict[str, Type["BaseChatter"]] = {}
|
||||
"""启用的chatter名 -> chatter类"""
|
||||
logger.info("组件注册中心初始化完成")
|
||||
|
||||
@@ -99,7 +101,7 @@ class ComponentRegistry:
|
||||
def register_component(
|
||||
self,
|
||||
component_info: ComponentInfo,
|
||||
component_class: Type[Union['BaseCommand', 'BaseAction', 'BaseEventHandler', 'BaseTool', 'BaseChatter']],
|
||||
component_class: Type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]],
|
||||
) -> bool:
|
||||
"""注册组件
|
||||
|
||||
@@ -172,7 +174,7 @@ class ComponentRegistry:
|
||||
)
|
||||
return True
|
||||
|
||||
def _register_action_component(self, action_info: 'ActionInfo', action_class: Type['BaseAction']) -> bool:
|
||||
def _register_action_component(self, action_info: "ActionInfo", action_class: Type["BaseAction"]) -> bool:
|
||||
"""注册Action组件到Action特定注册表"""
|
||||
if not (action_name := action_info.name):
|
||||
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
|
||||
@@ -192,7 +194,7 @@ class ComponentRegistry:
|
||||
|
||||
return True
|
||||
|
||||
def _register_command_component(self, command_info: 'CommandInfo', command_class: Type['BaseCommand']) -> bool:
|
||||
def _register_command_component(self, command_info: "CommandInfo", command_class: Type["BaseCommand"]) -> bool:
|
||||
"""注册Command组件到Command特定注册表"""
|
||||
if not (command_name := command_info.name):
|
||||
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
|
||||
@@ -219,7 +221,7 @@ class ComponentRegistry:
|
||||
return True
|
||||
|
||||
def _register_plus_command_component(
|
||||
self, plus_command_info: 'PlusCommandInfo', plus_command_class: Type['PlusCommand']
|
||||
self, plus_command_info: "PlusCommandInfo", plus_command_class: Type["PlusCommand"]
|
||||
) -> bool:
|
||||
"""注册PlusCommand组件到特定注册表"""
|
||||
plus_command_name = plus_command_info.name
|
||||
@@ -233,7 +235,7 @@ class ComponentRegistry:
|
||||
|
||||
# 创建专门的PlusCommand注册表(如果还没有)
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry: Dict[str, Type['PlusCommand']] = {}
|
||||
self._plus_command_registry: Dict[str, Type["PlusCommand"]] = {}
|
||||
|
||||
plus_command_class.plugin_name = plus_command_info.plugin_name
|
||||
# 设置插件配置
|
||||
@@ -243,7 +245,7 @@ class ComponentRegistry:
|
||||
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
|
||||
return True
|
||||
|
||||
def _register_tool_component(self, tool_info: 'ToolInfo', tool_class: Type['BaseTool']) -> bool:
|
||||
def _register_tool_component(self, tool_info: "ToolInfo", tool_class: Type["BaseTool"]) -> bool:
|
||||
"""注册Tool组件到Tool特定注册表"""
|
||||
tool_name = tool_info.name
|
||||
|
||||
@@ -259,7 +261,7 @@ class ComponentRegistry:
|
||||
return True
|
||||
|
||||
def _register_event_handler_component(
|
||||
self, handler_info: 'EventHandlerInfo', handler_class: Type['BaseEventHandler']
|
||||
self, handler_info: "EventHandlerInfo", handler_class: Type["BaseEventHandler"]
|
||||
) -> bool:
|
||||
if not (handler_name := handler_info.name):
|
||||
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
|
||||
@@ -285,7 +287,7 @@ class ComponentRegistry:
|
||||
handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
|
||||
)
|
||||
|
||||
def _register_chatter_component(self, chatter_info: 'ChatterInfo', chatter_class: Type['BaseChatter']) -> bool:
|
||||
def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: Type["BaseChatter"]) -> bool:
|
||||
"""注册Chatter组件到Chatter特定注册表"""
|
||||
chatter_name = chatter_info.name
|
||||
|
||||
@@ -312,7 +314,7 @@ class ComponentRegistry:
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def remove_component(self, component_name: str, component_type: 'ComponentType', plugin_name: str) -> bool:
|
||||
async def remove_component(self, component_name: str, component_type: "ComponentType", plugin_name: str) -> bool:
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
if not target_component_class:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
||||
@@ -362,7 +364,7 @@ class ComponentRegistry:
|
||||
|
||||
case ComponentType.CHATTER:
|
||||
# 移除Chatter注册
|
||||
if hasattr(self, '_chatter_registry'):
|
||||
if hasattr(self, "_chatter_registry"):
|
||||
self._chatter_registry.pop(component_name, None)
|
||||
logger.debug(f"已移除Chatter组件: {component_name}")
|
||||
|
||||
@@ -484,8 +486,8 @@ class ComponentRegistry:
|
||||
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
self, component_name: str, component_type: Optional['ComponentType'] = None
|
||||
) -> Optional['ComponentInfo']:
|
||||
self, component_name: str, component_type: Optional["ComponentType"] = None
|
||||
) -> Optional["ComponentInfo"]:
|
||||
# sourcery skip: class-extract-method
|
||||
"""获取组件信息,支持自动命名空间解析
|
||||
|
||||
@@ -529,8 +531,8 @@ class ComponentRegistry:
|
||||
def get_component_class(
|
||||
self,
|
||||
component_name: str,
|
||||
component_type: Optional['ComponentType'] = None,
|
||||
) -> Optional[Union[Type['BaseCommand'], Type['BaseAction'], Type['BaseEventHandler'], Type['BaseTool']]]:
|
||||
component_type: Optional["ComponentType"] = None,
|
||||
) -> Optional[Union[Type["BaseCommand"], Type["BaseAction"], Type["BaseEventHandler"], Type["BaseTool"]]]:
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
@@ -572,22 +574,22 @@ class ComponentRegistry:
|
||||
# 4. 都没找到
|
||||
return None
|
||||
|
||||
def get_components_by_type(self, component_type: 'ComponentType') -> Dict[str, 'ComponentInfo']:
|
||||
def get_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]:
|
||||
"""获取指定类型的所有组件"""
|
||||
return self._components_by_type.get(component_type, {}).copy()
|
||||
|
||||
def get_enabled_components_by_type(self, component_type: 'ComponentType') -> Dict[str, 'ComponentInfo']:
|
||||
def get_enabled_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]:
|
||||
"""获取指定类型的所有启用组件"""
|
||||
components = self.get_components_by_type(component_type)
|
||||
return {name: info for name, info in components.items() if info.enabled}
|
||||
|
||||
# === Action特定查询方法 ===
|
||||
|
||||
def get_action_registry(self) -> Dict[str, Type['BaseAction']]:
|
||||
def get_action_registry(self) -> Dict[str, Type["BaseAction"]]:
|
||||
"""获取Action注册表"""
|
||||
return self._action_registry.copy()
|
||||
|
||||
def get_registered_action_info(self, action_name: str) -> Optional['ActionInfo']:
|
||||
def get_registered_action_info(self, action_name: str) -> Optional["ActionInfo"]:
|
||||
"""获取Action信息"""
|
||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
||||
return info if isinstance(info, ActionInfo) else None
|
||||
@@ -598,11 +600,11 @@ class ComponentRegistry:
|
||||
|
||||
# === Command特定查询方法 ===
|
||||
|
||||
def get_command_registry(self) -> Dict[str, Type['BaseCommand']]:
|
||||
def get_command_registry(self) -> Dict[str, Type["BaseCommand"]]:
|
||||
"""获取Command注册表"""
|
||||
return self._command_registry.copy()
|
||||
|
||||
def get_registered_command_info(self, command_name: str) -> Optional['CommandInfo']:
|
||||
def get_registered_command_info(self, command_name: str) -> Optional["CommandInfo"]:
|
||||
"""获取Command信息"""
|
||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
||||
return info if isinstance(info, CommandInfo) else None
|
||||
@@ -611,7 +613,7 @@ class ComponentRegistry:
|
||||
"""获取Command模式注册表"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type['BaseCommand'], dict, 'CommandInfo']]:
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type["BaseCommand"], dict, "CommandInfo"]]:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
@@ -638,15 +640,15 @@ class ComponentRegistry:
|
||||
return None
|
||||
|
||||
# === Tool 特定查询方法 ===
|
||||
def get_tool_registry(self) -> Dict[str, Type['BaseTool']]:
|
||||
def get_tool_registry(self) -> Dict[str, Type["BaseTool"]]:
|
||||
"""获取Tool注册表"""
|
||||
return self._tool_registry.copy()
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, Type['BaseTool']]:
|
||||
def get_llm_available_tools(self) -> Dict[str, Type["BaseTool"]]:
|
||||
"""获取LLM可用的Tool列表"""
|
||||
return self._llm_available_tools.copy()
|
||||
|
||||
def get_registered_tool_info(self, tool_name: str) -> Optional['ToolInfo']:
|
||||
def get_registered_tool_info(self, tool_name: str) -> Optional["ToolInfo"]:
|
||||
"""获取Tool信息
|
||||
|
||||
Args:
|
||||
@@ -659,13 +661,13 @@ class ComponentRegistry:
|
||||
return info if isinstance(info, ToolInfo) else None
|
||||
|
||||
# === PlusCommand 特定查询方法 ===
|
||||
def get_plus_command_registry(self) -> Dict[str, Type['PlusCommand']]:
|
||||
def get_plus_command_registry(self) -> Dict[str, Type["PlusCommand"]]:
|
||||
"""获取PlusCommand注册表"""
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
pass
|
||||
return self._plus_command_registry.copy()
|
||||
|
||||
def get_registered_plus_command_info(self, command_name: str) -> Optional['PlusCommandInfo']:
|
||||
def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]:
|
||||
"""获取PlusCommand信息
|
||||
|
||||
Args:
|
||||
@@ -679,44 +681,44 @@ class ComponentRegistry:
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
|
||||
def get_event_handler_registry(self) -> Dict[str, Type['BaseEventHandler']]:
|
||||
def get_event_handler_registry(self) -> Dict[str, Type["BaseEventHandler"]]:
|
||||
"""获取事件处理器注册表"""
|
||||
return self._event_handler_registry.copy()
|
||||
|
||||
def get_registered_event_handler_info(self, handler_name: str) -> Optional['EventHandlerInfo']:
|
||||
def get_registered_event_handler_info(self, handler_name: str) -> Optional["EventHandlerInfo"]:
|
||||
"""获取事件处理器信息"""
|
||||
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
|
||||
return info if isinstance(info, EventHandlerInfo) else None
|
||||
|
||||
def get_enabled_event_handlers(self) -> Dict[str, Type['BaseEventHandler']]:
|
||||
def get_enabled_event_handlers(self) -> Dict[str, Type["BaseEventHandler"]]:
|
||||
"""获取启用的事件处理器"""
|
||||
return self._enabled_event_handlers.copy()
|
||||
|
||||
# === Chatter 特定查询方法 ===
|
||||
def get_chatter_registry(self) -> Dict[str, Type['BaseChatter']]:
|
||||
def get_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]:
|
||||
"""获取Chatter注册表"""
|
||||
if not hasattr(self, '_chatter_registry'):
|
||||
if not hasattr(self, "_chatter_registry"):
|
||||
self._chatter_registry: Dict[str, Type[BaseChatter]] = {}
|
||||
return self._chatter_registry.copy()
|
||||
|
||||
def get_enabled_chatter_registry(self) -> Dict[str, Type['BaseChatter']]:
|
||||
|
||||
def get_enabled_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]:
|
||||
"""获取启用的Chatter注册表"""
|
||||
if not hasattr(self, '_enabled_chatter_registry'):
|
||||
if not hasattr(self, "_enabled_chatter_registry"):
|
||||
self._enabled_chatter_registry: Dict[str, Type[BaseChatter]] = {}
|
||||
return self._enabled_chatter_registry.copy()
|
||||
|
||||
def get_registered_chatter_info(self, chatter_name: str) -> Optional['ChatterInfo']:
|
||||
|
||||
def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]:
|
||||
"""获取Chatter信息"""
|
||||
info = self.get_component_info(chatter_name, ComponentType.CHATTER)
|
||||
return info if isinstance(info, ChatterInfo) else None
|
||||
|
||||
|
||||
# === 插件查询方法 ===
|
||||
|
||||
def get_plugin_info(self, plugin_name: str) -> Optional['PluginInfo']:
|
||||
def get_plugin_info(self, plugin_name: str) -> Optional["PluginInfo"]:
|
||||
"""获取插件信息"""
|
||||
return self._plugins.get(plugin_name)
|
||||
|
||||
def get_all_plugins(self) -> Dict[str, 'PluginInfo']:
|
||||
def get_all_plugins(self) -> Dict[str, "PluginInfo"]:
|
||||
"""获取所有插件"""
|
||||
return self._plugins.copy()
|
||||
|
||||
@@ -724,7 +726,7 @@ class ComponentRegistry:
|
||||
# """获取所有启用的插件"""
|
||||
# return {name: info for name, info in self._plugins.items() if info.enabled}
|
||||
|
||||
def get_plugin_components(self, plugin_name: str) -> List['ComponentInfo']:
|
||||
def get_plugin_components(self, plugin_name: str) -> List["ComponentInfo"]:
|
||||
"""获取插件的所有组件"""
|
||||
plugin_info = self.get_plugin_info(plugin_name)
|
||||
return plugin_info.components if plugin_info else []
|
||||
|
||||
@@ -95,17 +95,16 @@ class PermissionManager(IPermissionManager):
|
||||
|
||||
# 检查用户是否有明确的权限设置
|
||||
result = await session.execute(
|
||||
select(UserPermissions)
|
||||
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
|
||||
select(UserPermissions).filter_by(
|
||||
platform=user.platform, user_id=user.user_id, permission_node=permission_node
|
||||
)
|
||||
)
|
||||
user_perm = result.scalar_one_or_none()
|
||||
|
||||
if user_perm:
|
||||
# 有明确设置,返回设置的值
|
||||
res = user_perm.granted
|
||||
logger.debug(
|
||||
f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {res}"
|
||||
)
|
||||
logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {res}")
|
||||
return res
|
||||
else:
|
||||
# 没有明确设置,使用默认值
|
||||
@@ -191,8 +190,9 @@ class PermissionManager(IPermissionManager):
|
||||
|
||||
# 检查是否已有权限记录
|
||||
result = await session.execute(
|
||||
select(UserPermissions)
|
||||
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
|
||||
select(UserPermissions).filter_by(
|
||||
platform=user.platform, user_id=user.user_id, permission_node=permission_node
|
||||
)
|
||||
)
|
||||
existing_perm = result.scalar_one_or_none()
|
||||
|
||||
@@ -244,8 +244,9 @@ class PermissionManager(IPermissionManager):
|
||||
|
||||
# 检查是否已有权限记录
|
||||
result = await session.execute(
|
||||
select(UserPermissions)
|
||||
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
|
||||
select(UserPermissions).filter_by(
|
||||
platform=user.platform, user_id=user.user_id, permission_node=permission_node
|
||||
)
|
||||
)
|
||||
existing_perm = result.scalar_one_or_none()
|
||||
|
||||
@@ -303,8 +304,9 @@ class PermissionManager(IPermissionManager):
|
||||
for node in all_nodes:
|
||||
# 检查用户是否有明确的权限设置
|
||||
result = await session.execute(
|
||||
select(UserPermissions)
|
||||
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=node.node_name)
|
||||
select(UserPermissions).filter_by(
|
||||
platform=user.platform, user_id=user.user_id, permission_node=node.node_name
|
||||
)
|
||||
)
|
||||
user_perm = result.scalar_one_or_none()
|
||||
|
||||
@@ -408,8 +410,7 @@ class PermissionManager(IPermissionManager):
|
||||
|
||||
# 删除用户权限记录
|
||||
result = await session.execute(
|
||||
delete(UserPermissions)
|
||||
.where(UserPermissions.permission_node.in_(node_names))
|
||||
delete(UserPermissions).where(UserPermissions.permission_node.in_(node_names))
|
||||
)
|
||||
deleted_user_perms = result.rowcount
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import hashlib
|
||||
import traceback
|
||||
import importlib
|
||||
|
||||
@@ -106,7 +104,6 @@ class PluginManager:
|
||||
if not plugin_dir:
|
||||
return False, 1
|
||||
|
||||
|
||||
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败)
|
||||
if not plugin_instance:
|
||||
logger.error(f"插件 {plugin_name} 实例化失败")
|
||||
@@ -545,9 +542,7 @@ class PluginManager:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
component_registry.unregister_plugin(plugin_name), loop
|
||||
)
|
||||
fut = asyncio.run_coroutine_threadsafe(component_registry.unregister_plugin(plugin_name), loop)
|
||||
fut.result(timeout=5)
|
||||
else:
|
||||
asyncio.run(component_registry.unregister_plugin(plugin_name))
|
||||
|
||||
@@ -116,17 +116,17 @@ class ToolExecutor:
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
all_tools = get_llm_available_tool_definitions()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
|
||||
|
||||
# 获取基础工具定义(包括二步工具的第一步)
|
||||
tool_definitions = [definition for name, definition in all_tools if name not in user_disabled_tools]
|
||||
|
||||
|
||||
# 检查是否有待处理的二步工具第二步调用
|
||||
pending_step_two = getattr(self, '_pending_step_two_tools', {})
|
||||
pending_step_two = getattr(self, "_pending_step_two_tools", {})
|
||||
if pending_step_two:
|
||||
# 添加第二步工具定义
|
||||
for tool_name, step_two_def in pending_step_two.items():
|
||||
tool_definitions.append(step_two_def)
|
||||
|
||||
|
||||
return tool_definitions
|
||||
|
||||
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
@@ -266,7 +266,7 @@ class ToolExecutor:
|
||||
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
|
||||
)
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
|
||||
|
||||
# 检查是否是二步工具的第二步调用
|
||||
if "_" in function_name and function_name.count("_") >= 1:
|
||||
# 可能是二步工具的第二步调用,格式为 "tool_name_sub_tool_name"
|
||||
@@ -274,14 +274,14 @@ class ToolExecutor:
|
||||
if len(parts) == 2:
|
||||
base_tool_name, sub_tool_name = parts
|
||||
base_tool_instance = get_tool_instance(base_tool_name)
|
||||
|
||||
|
||||
if base_tool_instance and base_tool_instance.is_two_step_tool:
|
||||
logger.info(f"{self.log_prefix}执行二步工具第二步: {base_tool_name}.{sub_tool_name}")
|
||||
result = await base_tool_instance.execute_step_two(sub_tool_name, function_args)
|
||||
|
||||
|
||||
# 清理待处理的第二步工具
|
||||
self._pending_step_two_tools.pop(base_tool_name, None)
|
||||
|
||||
|
||||
if result:
|
||||
logger.debug(f"{self.log_prefix}二步工具第二步 {function_name} 执行成功")
|
||||
return {
|
||||
@@ -291,7 +291,7 @@ class ToolExecutor:
|
||||
"type": "function",
|
||||
"content": result.get("content", ""),
|
||||
}
|
||||
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = tool_instance or get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
@@ -301,7 +301,7 @@ class ToolExecutor:
|
||||
# 执行工具并记录日志
|
||||
logger.debug(f"{self.log_prefix}执行工具 {function_name},参数: {function_args}")
|
||||
result = await tool_instance.execute(function_args)
|
||||
|
||||
|
||||
# 检查是否是二步工具的第一步结果
|
||||
if result and result.get("type") == "two_step_tool_step_one":
|
||||
logger.info(f"{self.log_prefix}二步工具第一步完成: {function_name}")
|
||||
@@ -310,7 +310,7 @@ class ToolExecutor:
|
||||
if next_tool_def:
|
||||
self._pending_step_two_tools[function_name] = next_tool_def
|
||||
logger.debug(f"{self.log_prefix}已保存第二步工具定义: {next_tool_def['name']}")
|
||||
|
||||
|
||||
if result:
|
||||
logger.debug(f"{self.log_prefix}工具 {function_name} 执行成功,结果: {result}")
|
||||
return {
|
||||
|
||||
@@ -79,9 +79,11 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None)
|
||||
|
||||
if not iscoroutinefunction(func):
|
||||
logger.warning(f"函数 {func.__name__} 使用 require_permission 但非异步,已强制阻止执行")
|
||||
|
||||
async def blocked(*_a, **_k):
|
||||
logger.error("同步函数不再支持权限装饰器,请改为 async def")
|
||||
return None
|
||||
|
||||
return blocked
|
||||
return async_wrapper
|
||||
|
||||
@@ -146,9 +148,11 @@ def require_master(deny_message: Optional[str] = None):
|
||||
|
||||
if not iscoroutinefunction(func):
|
||||
logger.warning(f"函数 {func.__name__} 使用 require_master 但非异步,已强制阻止执行")
|
||||
|
||||
async def blocked(*_a, **_k):
|
||||
logger.error("同步函数不再支持 require_master,请改为 async def")
|
||||
return None
|
||||
|
||||
return blocked
|
||||
return async_wrapper
|
||||
|
||||
@@ -164,7 +168,9 @@ class PermissionChecker:
|
||||
|
||||
@staticmethod
|
||||
def check_permission(chat_stream: ChatStream, permission_node: str) -> bool:
|
||||
raise RuntimeError("PermissionChecker.check_permission 已移除同步支持,请直接 await permission_api.check_permission")
|
||||
raise RuntimeError(
|
||||
"PermissionChecker.check_permission 已移除同步支持,请直接 await permission_api.check_permission"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_master(chat_stream: ChatStream) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user