feat: 添加插件配置支持,优化适配器和组件初始化
This commit is contained in:
@@ -109,14 +109,15 @@ class BaseAction(ABC):
|
||||
action_message: 消息数据
|
||||
**kwargs: 其他参数
|
||||
"""
|
||||
if plugin_config is None:
|
||||
plugin_config: ClassVar = {}
|
||||
self.action_data = action_data
|
||||
self.reasoning = reasoning
|
||||
self.cycle_timers = cycle_timers
|
||||
self.thinking_id = thinking_id
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
if plugin_config is None:
|
||||
plugin_config = getattr(self.__class__, "plugin_config", {})
|
||||
|
||||
self.plugin_config = plugin_config or {}
|
||||
"""对应的插件配置"""
|
||||
|
||||
|
||||
@@ -75,6 +75,17 @@ class BaseAdapter(MoFoxAdapterBase, ABC):
|
||||
"""设置适配器配置"""
|
||||
self._config = value
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""获取适配器配置,优先使用插件配置,其次使用内部配置。"""
|
||||
current = self.config or {}
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动适配器"""
|
||||
logger.info(f"启动适配器: {self.adapter_name} v{self.adapter_version}")
|
||||
|
||||
@@ -12,29 +12,34 @@ if TYPE_CHECKING:
|
||||
|
||||
class BaseChatter(ABC):
|
||||
chatter_name: str = ""
|
||||
"""Chatter组件的名称"""
|
||||
"""Chatter组件名称"""
|
||||
chatter_description: str = ""
|
||||
"""Chatter组件的描述"""
|
||||
"""Chatter组件描述"""
|
||||
chat_types: ClassVar[list[ChatType]] = [ChatType.PRIVATE, ChatType.GROUP]
|
||||
|
||||
def __init__(self, stream_id: str, action_manager: "ChatterActionManager"):
|
||||
def __init__(self, stream_id: str, action_manager: "ChatterActionManager", plugin_config: dict | None = None):
|
||||
"""
|
||||
初始化聊天处理器
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
action_manager: 动作管理器
|
||||
plugin_config: 插件配置字典
|
||||
"""
|
||||
self.stream_id = stream_id
|
||||
self.action_manager = action_manager
|
||||
if plugin_config is None:
|
||||
plugin_config = getattr(self.__class__, "plugin_config", {})
|
||||
|
||||
self.plugin_config = plugin_config or {}
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, context: StreamContext) -> dict:
|
||||
"""
|
||||
执行聊天处理流程
|
||||
执行聊天处理逻辑
|
||||
|
||||
Args:
|
||||
context: StreamContext对象,包含聊天流的所有消息信息
|
||||
context: StreamContext对象,包含聊天上下文信息
|
||||
|
||||
Returns:
|
||||
处理结果字典
|
||||
@@ -43,9 +48,9 @@ class BaseChatter(ABC):
|
||||
|
||||
@classmethod
|
||||
def get_chatter_info(cls) -> "ChatterInfo":
|
||||
"""从类属性生成ChatterInfo
|
||||
"""构造并返回ChatterInfo
|
||||
Returns:
|
||||
ChatterInfo对象
|
||||
ChatterInfo实例
|
||||
"""
|
||||
|
||||
return ChatterInfo(
|
||||
@@ -54,3 +59,16 @@ class BaseChatter(ABC):
|
||||
chat_type_allow=cls.chat_types[0],
|
||||
component_type=ComponentType.CHATTER,
|
||||
)
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置,支持嵌套键"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
current = self.plugin_config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
@@ -7,34 +7,52 @@ from .component_types import ComponentType, RouterInfo
|
||||
|
||||
class BaseRouterComponent(ABC):
|
||||
"""
|
||||
用于暴露HTTP端点的组件基类。
|
||||
插件开发者应继承此类,并实现 register_endpoints 方法来定义API路由。
|
||||
对外暴露HTTP接口的基类。
|
||||
插件路由类应继承本类,并实现 register_endpoints 方法注册API路由。
|
||||
"""
|
||||
# 组件元数据,由插件管理器读取
|
||||
|
||||
# 基本元数据,可由插件类读取
|
||||
component_name: str
|
||||
component_description: str
|
||||
component_version: str = "1.0.0"
|
||||
|
||||
# 每个组件实例都会管理自己的APIRouter
|
||||
# 每个路由实例都拥有自己的 APIRouter
|
||||
router: APIRouter
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, plugin_config: dict | None = None):
|
||||
if plugin_config is None:
|
||||
plugin_config = getattr(self.__class__, "plugin_config", {})
|
||||
self.plugin_config = plugin_config or {}
|
||||
|
||||
self.router = APIRouter()
|
||||
self.register_endpoints()
|
||||
|
||||
@abstractmethod
|
||||
def register_endpoints(self) -> None:
|
||||
"""
|
||||
【开发者必须实现】
|
||||
在此方法中定义所有HTTP端点。
|
||||
子类需要实现的方法。
|
||||
在此方法中定义插件的HTTP接口。
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_router_info(cls) -> "RouterInfo":
|
||||
"""从类属性生成RouterInfo"""
|
||||
"""构造 RouterInfo"""
|
||||
return RouterInfo(
|
||||
name=cls.component_name,
|
||||
description=getattr(cls, "component_description", "路由组件"),
|
||||
component_type=ComponentType.ROUTER,
|
||||
)
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,支持嵌套键"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
current = self.plugin_config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
@@ -79,12 +79,16 @@ class BaseInterestCalculator(ABC):
|
||||
component_description: str = ""
|
||||
enabled_by_default: bool = True # 是否默认启用
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, plugin_config: dict | None = None):
|
||||
self._enabled = False
|
||||
self._last_calculation_time = 0.0
|
||||
self._total_calculations = 0
|
||||
self._failed_calculations = 0
|
||||
self._average_calculation_time = 0.0
|
||||
if plugin_config is None:
|
||||
plugin_config = getattr(self.__class__, "plugin_config", {})
|
||||
|
||||
self.plugin_config = plugin_config or {}
|
||||
|
||||
# 验证必须定义的属性
|
||||
if not self.component_name:
|
||||
@@ -193,6 +197,19 @@ class BaseInterestCalculator(ABC):
|
||||
self._update_statistics(result)
|
||||
return result
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""获取插件配置,支持嵌套键访问"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
current = self.plugin_config
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
@classmethod
|
||||
def get_interest_calculator_info(cls) -> "InterestCalculatorInfo":
|
||||
"""从类属性生成InterestCalculatorInfo
|
||||
|
||||
@@ -34,7 +34,9 @@ class BasePrompt(ABC):
|
||||
injection_point: str | list[str] | None = None
|
||||
"""[已废弃] 要注入的目标Prompt名称或列表,请使用 injection_rules"""
|
||||
|
||||
def __init__(self, params: PromptParameters, plugin_config: dict | None = None, target_prompt_name: str | None = None):
|
||||
def __init__(
|
||||
self, params: PromptParameters, plugin_config: dict | None = None, target_prompt_name: str | None = None
|
||||
):
|
||||
"""初始化Prompt组件
|
||||
|
||||
Args:
|
||||
@@ -43,6 +45,9 @@ class BasePrompt(ABC):
|
||||
target_prompt_name: 在应用注入时,当前注入的目标提示词名称。
|
||||
"""
|
||||
self.params = params
|
||||
if plugin_config is None:
|
||||
plugin_config = getattr(self.__class__, "plugin_config", {})
|
||||
|
||||
self.plugin_config = plugin_config or {}
|
||||
self.target_prompt_name = target_prompt_name
|
||||
self.log_prefix = "[PromptComponent]"
|
||||
|
||||
@@ -48,6 +48,9 @@ class BaseTool(ABC):
|
||||
"""子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用"""
|
||||
|
||||
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
|
||||
if plugin_config is None:
|
||||
plugin_config = getattr(self.__class__, "plugin_config", {})
|
||||
|
||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||
self.chat_stream = chat_stream # 存储聊天流信息,可用于获取上下文
|
||||
|
||||
@@ -205,7 +208,7 @@ class BaseTool(ABC):
|
||||
"""直接执行工具函数(供插件调用)
|
||||
通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数
|
||||
插件可以直接调用此方法,用更加明了的方式传入参数
|
||||
示例: result = await tool.direct_execute(arg1="参数",arg2="参数2")
|
||||
示例: result = await tool.direct_execute(arg1=\"参数\",arg2=\"参数2\")
|
||||
|
||||
工具开发者可以重写此方法以实现与llm调用差异化的执行逻辑
|
||||
|
||||
@@ -226,7 +229,7 @@ class BaseTool(ABC):
|
||||
"""获取插件配置值,使用嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
||||
key: 配置键名,使用嵌套访问如 \"section.subsection.key\"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -60,6 +60,9 @@ class PlusCommand(ABC):
|
||||
message: 接收到的消息对象(DatabaseMessages)
|
||||
plugin_config: 插件配置字典
|
||||
"""
|
||||
if plugin_config is None:
|
||||
plugin_config = getattr(self.__class__, "plugin_config", {})
|
||||
|
||||
self.message = message
|
||||
self.plugin_config = plugin_config or {}
|
||||
self.log_prefix = "[PlusCommand]"
|
||||
|
||||
@@ -341,11 +341,9 @@ class ComponentRegistry:
|
||||
if not hasattr(self, "_enabled_interest_calculator_registry"):
|
||||
self._enabled_interest_calculator_registry: dict[str, type["BaseInterestCalculator"]] = {}
|
||||
|
||||
setattr(interest_calculator_class, "plugin_name", interest_calculator_info.plugin_name)
|
||||
# 设置插件配置
|
||||
setattr(
|
||||
_assign_plugin_attrs(
|
||||
interest_calculator_class,
|
||||
"plugin_config",
|
||||
interest_calculator_info.plugin_name,
|
||||
self.get_plugin_config(interest_calculator_info.plugin_name) or {},
|
||||
)
|
||||
self._interest_calculator_registry[calculator_name] = interest_calculator_class
|
||||
@@ -394,6 +392,8 @@ class ComponentRegistry:
|
||||
|
||||
router_name = router_info.name
|
||||
plugin_name = router_info.plugin_name
|
||||
plugin_config = self.get_plugin_config(plugin_name) or {}
|
||||
_assign_plugin_attrs(router_class, plugin_name, plugin_config)
|
||||
|
||||
# 2. 实例化组件以触发其 __init__ 和 register_endpoints
|
||||
component_instance = router_class()
|
||||
|
||||
Reference in New Issue
Block a user