🤖 自动格式化代码 [skip ci]
This commit is contained in:
@@ -16,15 +16,15 @@ class DatabaseAPI:
|
||||
"""
|
||||
|
||||
async def store_action_info(
|
||||
self,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
self,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: dict = None
|
||||
action_data: dict = None,
|
||||
) -> None:
|
||||
"""存储action信息到数据库
|
||||
|
||||
|
||||
Args:
|
||||
action_build_into_prompt: 是否构建到提示中
|
||||
action_prompt_display: 显示的action提示信息
|
||||
|
||||
@@ -54,10 +54,10 @@ class PluginAPI(MessageAPI, LLMAPI, DatabaseAPI, ConfigAPI, UtilsAPI, StreamAPI,
|
||||
}
|
||||
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
|
||||
# 存储action上下文信息
|
||||
self._action_context = {}
|
||||
|
||||
|
||||
# 调用所有父类的初始化
|
||||
super().__init__()
|
||||
|
||||
@@ -97,7 +97,7 @@ class PluginAPI(MessageAPI, LLMAPI, DatabaseAPI, ConfigAPI, UtilsAPI, StreamAPI,
|
||||
self._action_context["thinking_id"] = thinking_id
|
||||
self._action_context["shutting_down"] = shutting_down
|
||||
self._action_context.update(kwargs)
|
||||
|
||||
|
||||
def get_action_context(self, key: str, default=None):
|
||||
"""获取action上下文信息"""
|
||||
return self._action_context.get(key, default)
|
||||
|
||||
@@ -162,10 +162,10 @@ class StreamAPI:
|
||||
|
||||
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
|
||||
"""等待新消息或超时
|
||||
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒),默认1200秒
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否收到新消息, 空字符串)
|
||||
"""
|
||||
@@ -175,21 +175,21 @@ class StreamAPI:
|
||||
if not observations:
|
||||
logger.warning(f"{self.log_prefix} 无法获取observations服务,无法等待新消息")
|
||||
return False, ""
|
||||
|
||||
|
||||
# 获取第一个观察对象(通常是ChattingObservation)
|
||||
observation = observations[0] if observations else None
|
||||
if not observation:
|
||||
logger.warning(f"{self.log_prefix} 无观察对象,无法等待新消息")
|
||||
return False, ""
|
||||
|
||||
|
||||
# 从action上下文获取thinking_id
|
||||
thinking_id = self.get_action_context("thinking_id")
|
||||
if not thinking_id:
|
||||
logger.warning(f"{self.log_prefix} 无thinking_id,无法等待新消息")
|
||||
return False, ""
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 开始等待新消息... (超时: {timeout}秒)")
|
||||
|
||||
|
||||
wait_start_time = asyncio.get_event_loop().time()
|
||||
while True:
|
||||
# 检查关闭标志
|
||||
@@ -197,21 +197,21 @@ class StreamAPI:
|
||||
if shutting_down:
|
||||
logger.info(f"{self.log_prefix} 等待新消息时检测到关闭信号,中断等待")
|
||||
return False, ""
|
||||
|
||||
|
||||
# 检查新消息
|
||||
thinking_id_timestamp = parse_thinking_id_to_timestamp(thinking_id)
|
||||
if await observation.has_new_messages_since(thinking_id_timestamp):
|
||||
logger.info(f"{self.log_prefix} 检测到新消息")
|
||||
return True, ""
|
||||
|
||||
|
||||
# 检查超时
|
||||
if asyncio.get_event_loop().time() - wait_start_time > timeout:
|
||||
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒)")
|
||||
return False, ""
|
||||
|
||||
|
||||
# 短暂休眠
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
|
||||
return False, ""
|
||||
|
||||
@@ -11,7 +11,7 @@ class BaseAction(ABC):
|
||||
"""Action组件基类
|
||||
|
||||
Action是插件的一种组件类型,用于处理聊天中的动作逻辑
|
||||
|
||||
|
||||
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
|
||||
- focus_activation_type: 专注模式激活类型
|
||||
- normal_activation_type: 普通模式激活类型
|
||||
@@ -22,19 +22,21 @@ class BaseAction(ABC):
|
||||
- random_activation_probability: 随机激活概率
|
||||
- llm_judge_prompt: LLM判断提示词
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
observations: list = None,
|
||||
expressor = None,
|
||||
replyer = None,
|
||||
chat_stream = None,
|
||||
log_prefix: str = "",
|
||||
shutting_down: bool = False,
|
||||
**kwargs):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
observations: list = None,
|
||||
expressor=None,
|
||||
replyer=None,
|
||||
chat_stream=None,
|
||||
log_prefix: str = "",
|
||||
shutting_down: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化Action组件
|
||||
|
||||
Args:
|
||||
@@ -56,60 +58,57 @@ class BaseAction(ABC):
|
||||
self.thinking_id = thinking_id
|
||||
self.log_prefix = log_prefix
|
||||
self.shutting_down = shutting_down
|
||||
|
||||
|
||||
# 设置动作基本信息实例属性(兼容旧系统)
|
||||
self.action_name: str = getattr(self, 'action_name', self.__class__.__name__.lower().replace('action', ''))
|
||||
self.action_description: str = getattr(self, 'action_description', self.__doc__ or "Action组件")
|
||||
self.action_parameters: dict = getattr(self.__class__, 'action_parameters', {}).copy()
|
||||
self.action_require: list[str] = getattr(self.__class__, 'action_require', []).copy()
|
||||
|
||||
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
||||
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
||||
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||||
|
||||
# 设置激活类型实例属性(从类属性复制,提供默认值)
|
||||
self.focus_activation_type: str = self._get_activation_type_value('focus_activation_type', 'never')
|
||||
self.normal_activation_type: str = self._get_activation_type_value('normal_activation_type', 'never')
|
||||
self.random_activation_probability: float = getattr(self.__class__, 'random_activation_probability', 0.0)
|
||||
self.llm_judge_prompt: str = getattr(self.__class__, 'llm_judge_prompt', "")
|
||||
self.activation_keywords: list[str] = getattr(self.__class__, 'activation_keywords', []).copy()
|
||||
self.keyword_case_sensitive: bool = getattr(self.__class__, 'keyword_case_sensitive', False)
|
||||
self.mode_enable: str = self._get_mode_value('mode_enable', 'all')
|
||||
self.parallel_action: bool = getattr(self.__class__, 'parallel_action', True)
|
||||
self.associated_types: list[str] = getattr(self.__class__, 'associated_types', []).copy()
|
||||
self.focus_activation_type: str = self._get_activation_type_value("focus_activation_type", "never")
|
||||
self.normal_activation_type: str = self._get_activation_type_value("normal_activation_type", "never")
|
||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||||
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
|
||||
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
||||
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
||||
self.mode_enable: str = self._get_mode_value("mode_enable", "all")
|
||||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
||||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
||||
self.enable_plugin: bool = True # 默认启用
|
||||
|
||||
|
||||
# 创建API实例,传递所有服务对象
|
||||
self.api = PluginAPI(
|
||||
chat_stream=chat_stream or kwargs.get("chat_stream"),
|
||||
expressor=expressor or kwargs.get("expressor"),
|
||||
expressor=expressor or kwargs.get("expressor"),
|
||||
replyer=replyer or kwargs.get("replyer"),
|
||||
observations=observations or kwargs.get("observations", []),
|
||||
log_prefix=log_prefix
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
|
||||
|
||||
# 设置API的action上下文
|
||||
self.api.set_action_context(
|
||||
thinking_id=thinking_id,
|
||||
shutting_down=shutting_down
|
||||
)
|
||||
|
||||
self.api.set_action_context(thinking_id=thinking_id, shutting_down=shutting_down)
|
||||
|
||||
logger.debug(f"{self.log_prefix} Action组件初始化完成")
|
||||
|
||||
|
||||
def _get_activation_type_value(self, attr_name: str, default: str) -> str:
|
||||
"""获取激活类型的字符串值"""
|
||||
attr = getattr(self.__class__, attr_name, None)
|
||||
if attr is None:
|
||||
return default
|
||||
if hasattr(attr, 'value'):
|
||||
if hasattr(attr, "value"):
|
||||
return attr.value
|
||||
return str(attr)
|
||||
|
||||
|
||||
def _get_mode_value(self, attr_name: str, default: str) -> str:
|
||||
"""获取模式的字符串值"""
|
||||
attr = getattr(self.__class__, attr_name, None)
|
||||
if attr is None:
|
||||
return default
|
||||
if hasattr(attr, 'value'):
|
||||
if hasattr(attr, "value"):
|
||||
return attr.value
|
||||
return str(attr)
|
||||
|
||||
|
||||
async def send_reply(self, content: str) -> bool:
|
||||
"""发送回复消息
|
||||
|
||||
@@ -138,8 +137,8 @@ class BaseAction(ABC):
|
||||
name = cls.__name__.lower().replace("action", "")
|
||||
if description is None:
|
||||
description = cls.__doc__ or f"{cls.__name__} Action组件"
|
||||
description = description.strip().split('\n')[0] # 取第一行作为描述
|
||||
|
||||
description = description.strip().split("\n")[0] # 取第一行作为描述
|
||||
|
||||
# 安全获取激活类型值
|
||||
def get_enum_value(attr_name, default):
|
||||
attr = getattr(cls, attr_name, None)
|
||||
@@ -147,29 +146,29 @@ class BaseAction(ABC):
|
||||
# 如果没有定义,返回默认的枚举值
|
||||
return getattr(ActionActivationType, default.upper(), ActionActivationType.NEVER)
|
||||
return attr
|
||||
|
||||
|
||||
def get_mode_value(attr_name, default):
|
||||
attr = getattr(cls, attr_name, None)
|
||||
if attr is None:
|
||||
return getattr(ChatMode, default.upper(), ChatMode.ALL)
|
||||
return attr
|
||||
|
||||
|
||||
return ActionInfo(
|
||||
name=name,
|
||||
component_type=ComponentType.ACTION,
|
||||
description=description,
|
||||
focus_activation_type=get_enum_value('focus_activation_type', 'never'),
|
||||
normal_activation_type=get_enum_value('normal_activation_type', 'never'),
|
||||
activation_keywords=getattr(cls, 'activation_keywords', []).copy(),
|
||||
keyword_case_sensitive=getattr(cls, 'keyword_case_sensitive', False),
|
||||
mode_enable=get_mode_value('mode_enable', 'all'),
|
||||
parallel_action=getattr(cls, 'parallel_action', True),
|
||||
random_activation_probability=getattr(cls, 'random_activation_probability', 0.0),
|
||||
llm_judge_prompt=getattr(cls, 'llm_judge_prompt', ""),
|
||||
focus_activation_type=get_enum_value("focus_activation_type", "never"),
|
||||
normal_activation_type=get_enum_value("normal_activation_type", "never"),
|
||||
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
|
||||
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
|
||||
mode_enable=get_mode_value("mode_enable", "all"),
|
||||
parallel_action=getattr(cls, "parallel_action", True),
|
||||
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
|
||||
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
|
||||
# 使用正确的字段名
|
||||
action_parameters=getattr(cls, 'action_parameters', {}).copy(),
|
||||
action_require=getattr(cls, 'action_require', []).copy(),
|
||||
associated_types=getattr(cls, 'associated_types', []).copy()
|
||||
action_parameters=getattr(cls, "action_parameters", {}).copy(),
|
||||
action_require=getattr(cls, "action_require", []).copy(),
|
||||
associated_types=getattr(cls, "associated_types", []).copy(),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@@ -180,14 +179,14 @@ class BaseAction(ABC):
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""兼容旧系统的handle_action接口,委托给execute方法
|
||||
|
||||
|
||||
为了保持向后兼容性,旧系统的代码可能会调用handle_action方法。
|
||||
此方法将调用委托给新的execute方法。
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
return await self.execute()
|
||||
return await self.execute()
|
||||
|
||||
@@ -69,7 +69,7 @@ class ComponentRegistry:
|
||||
self._register_action_component(component_info, component_class)
|
||||
elif component_type == ComponentType.COMMAND:
|
||||
self._register_command_component(component_info, component_class)
|
||||
|
||||
|
||||
logger.debug(f"已注册{component_type.value}组件: {component_name} ({component_class.__name__})")
|
||||
return True
|
||||
|
||||
|
||||
@@ -22,10 +22,10 @@ class PluginManager:
|
||||
|
||||
def __init__(self):
|
||||
self.plugin_directories: List[str] = []
|
||||
self.loaded_plugins: Dict[str, 'BasePlugin'] = {}
|
||||
self.loaded_plugins: Dict[str, "BasePlugin"] = {}
|
||||
self.failed_plugins: Dict[str, str] = {}
|
||||
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射
|
||||
|
||||
|
||||
logger.info("插件管理器初始化完成")
|
||||
|
||||
def add_plugin_directory(self, directory: str):
|
||||
@@ -43,7 +43,7 @@ class PluginManager:
|
||||
tuple[int, int]: (插件数量, 组件数量)
|
||||
"""
|
||||
logger.debug("开始加载所有插件...")
|
||||
|
||||
|
||||
# 第一阶段:加载所有插件模块(注册插件类)
|
||||
total_loaded_modules = 0
|
||||
total_failed_modules = 0
|
||||
@@ -52,9 +52,9 @@ class PluginManager:
|
||||
loaded, failed = self._load_plugin_modules_from_directory(directory)
|
||||
total_loaded_modules += loaded
|
||||
total_failed_modules += failed
|
||||
|
||||
|
||||
logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
|
||||
|
||||
|
||||
# 第二阶段:实例化所有已注册的插件类
|
||||
from src.plugin_system.base.base_plugin import get_registered_plugin_classes, instantiate_and_register_plugin
|
||||
|
||||
@@ -65,17 +65,17 @@ class PluginManager:
|
||||
for plugin_name, plugin_class in plugin_classes.items():
|
||||
# 使用记录的插件目录路径
|
||||
plugin_dir = self.plugin_paths.get(plugin_name)
|
||||
|
||||
|
||||
# 如果没有记录,则尝试查找(fallback)
|
||||
if not plugin_dir:
|
||||
plugin_dir = self._find_plugin_directory(plugin_class)
|
||||
if plugin_dir:
|
||||
self.plugin_paths[plugin_name] = plugin_dir
|
||||
|
||||
|
||||
if instantiate_and_register_plugin(plugin_class, plugin_dir):
|
||||
total_registered += 1
|
||||
self.loaded_plugins[plugin_name] = plugin_class
|
||||
|
||||
|
||||
# 📊 显示插件详细信息
|
||||
plugin_info = component_registry.get_plugin_info(plugin_name)
|
||||
if plugin_info:
|
||||
@@ -83,28 +83,32 @@ class PluginManager:
|
||||
for comp in plugin_info.components:
|
||||
comp_type = comp.component_type.name
|
||||
component_types[comp_type] = component_types.get(comp_type, 0) + 1
|
||||
|
||||
|
||||
components_str = ", ".join([f"{count}个{ctype}" for ctype, count in component_types.items()])
|
||||
logger.info(f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}) - {plugin_info.description}")
|
||||
logger.info(
|
||||
f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}) - {plugin_info.description}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"✅ 插件加载成功: {plugin_name}")
|
||||
else:
|
||||
total_failed_registration += 1
|
||||
self.failed_plugins[plugin_name] = "插件注册失败"
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name}")
|
||||
|
||||
|
||||
# 获取组件统计信息
|
||||
stats = component_registry.get_registry_stats()
|
||||
|
||||
|
||||
# 📋 显示插件加载总览
|
||||
if total_registered > 0:
|
||||
action_count = stats.get('action_components', 0)
|
||||
command_count = stats.get('command_components', 0)
|
||||
total_components = stats.get('total_components', 0)
|
||||
|
||||
action_count = stats.get("action_components", 0)
|
||||
command_count = stats.get("command_components", 0)
|
||||
total_components = stats.get("total_components", 0)
|
||||
|
||||
logger.info("🎉 插件系统加载完成!")
|
||||
logger.info(f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count})")
|
||||
|
||||
logger.info(
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count})"
|
||||
)
|
||||
|
||||
# 显示详细的插件列表
|
||||
logger.info("📋 已加载插件详情:")
|
||||
for plugin_name, _plugin_class in self.loaded_plugins.items():
|
||||
@@ -115,31 +119,31 @@ class PluginManager:
|
||||
author_info = f"by {plugin_info.author}" if plugin_info.author else ""
|
||||
info_parts = [part for part in [version_info, author_info] if part]
|
||||
extra_info = f" ({', '.join(info_parts)})" if info_parts else ""
|
||||
|
||||
|
||||
logger.info(f" 📦 {plugin_name}{extra_info}")
|
||||
|
||||
|
||||
# 组件列表
|
||||
if plugin_info.components:
|
||||
action_components = [c for c in plugin_info.components if c.component_type.name == 'ACTION']
|
||||
command_components = [c for c in plugin_info.components if c.component_type.name == 'COMMAND']
|
||||
|
||||
action_components = [c for c in plugin_info.components if c.component_type.name == "ACTION"]
|
||||
command_components = [c for c in plugin_info.components if c.component_type.name == "COMMAND"]
|
||||
|
||||
if action_components:
|
||||
action_names = [c.name for c in action_components]
|
||||
logger.info(f" 🎯 Action组件: {', '.join(action_names)}")
|
||||
|
||||
|
||||
if command_components:
|
||||
command_names = [c.name for c in command_components]
|
||||
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
||||
|
||||
|
||||
# 依赖信息
|
||||
if plugin_info.dependencies:
|
||||
logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}")
|
||||
|
||||
|
||||
# 配置文件信息
|
||||
if plugin_info.config_file:
|
||||
config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌"
|
||||
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
|
||||
|
||||
|
||||
# 显示目录统计
|
||||
logger.info("📂 加载目录统计:")
|
||||
for directory in self.plugin_directories:
|
||||
@@ -149,12 +153,12 @@ class PluginManager:
|
||||
plugin_path = self.plugin_paths.get(plugin_name, "")
|
||||
if plugin_path.startswith(directory):
|
||||
plugins_in_dir.append(plugin_name)
|
||||
|
||||
|
||||
if plugins_in_dir:
|
||||
logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})")
|
||||
else:
|
||||
logger.info(f" 📁 {directory}: 0个插件")
|
||||
|
||||
|
||||
# 失败信息
|
||||
if total_failed_registration > 0:
|
||||
logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败")
|
||||
@@ -162,10 +166,10 @@ class PluginManager:
|
||||
logger.info(f" ❌ {failed_plugin}: {error}")
|
||||
else:
|
||||
logger.warning("😕 没有成功加载任何插件")
|
||||
|
||||
|
||||
# 返回插件数量和组件数量
|
||||
return total_registered, total_components
|
||||
|
||||
|
||||
def _find_plugin_directory(self, plugin_class) -> Optional[str]:
|
||||
"""查找插件类对应的目录路径"""
|
||||
try:
|
||||
@@ -186,9 +190,9 @@ class PluginManager:
|
||||
if not os.path.exists(directory):
|
||||
logger.warning(f"插件目录不存在: {directory}")
|
||||
return loaded_count, failed_count
|
||||
|
||||
|
||||
logger.debug(f"正在扫描插件目录: {directory}")
|
||||
|
||||
|
||||
# 遍历目录中的所有Python文件和包
|
||||
for item in os.listdir(directory):
|
||||
item_path = os.path.join(directory, item)
|
||||
@@ -212,10 +216,10 @@ class PluginManager:
|
||||
failed_count += 1
|
||||
|
||||
return loaded_count, failed_count
|
||||
|
||||
|
||||
def _load_plugin_module_file(self, plugin_file: str, plugin_name: str, plugin_dir: str) -> bool:
|
||||
"""加载单个插件模块文件
|
||||
|
||||
|
||||
Args:
|
||||
plugin_file: 插件文件路径
|
||||
plugin_name: 插件名称
|
||||
@@ -239,10 +243,10 @@ class PluginManager:
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
|
||||
# 记录插件名和目录路径的映射
|
||||
self.plugin_paths[plugin_name] = plugin_dir
|
||||
|
||||
|
||||
logger.debug(f"插件模块加载成功: {plugin_file}")
|
||||
return True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user