feat:重构插件api

This commit is contained in:
SengokuCola
2025-06-10 15:28:36 +08:00
parent 2edece11ea
commit 4d32b3052f
30 changed files with 2429 additions and 471 deletions

169
src/plugin_system/README.md Normal file
View File

@@ -0,0 +1,169 @@
# MaiBot 插件系统 - 重构版
## 目录结构说明
经过重构,插件系统现在采用清晰的**系统核心**与**插件内容**分离的架构:
```
src/
├── plugin_system/ # 🔧 系统核心 - 插件框架本身
│ ├── __init__.py # 统一导出接口
│ ├── core/ # 核心管理
│ │ ├── plugin_manager.py
│ │ ├── component_registry.py
│ │ └── __init__.py
│ ├── apis/ # API接口
│ │ ├── plugin_api.py # 统一API聚合
│ │ ├── message_api.py
│ │ ├── llm_api.py
│ │ ├── database_api.py
│ │ ├── config_api.py
│ │ ├── utils_api.py
│ │ ├── stream_api.py
│ │ ├── hearflow_api.py
│ │ └── __init__.py
│ ├── base/ # 基础类
│ │ ├── base_plugin.py
│ │ ├── base_action.py
│ │ ├── base_command.py
│ │ ├── component_types.py
│ │ └── __init__.py
│ └── registry/ # 注册相关(预留)
└── plugins/ # 🔌 插件内容 - 具体的插件实现
├── built_in/ # 内置插件
│ ├── system_actions/ # 系统内置Action
│ └── system_commands/# 系统内置Command
└── examples/ # 示例插件
└── simple_plugin/
├── plugin.py
└── config.toml
```
## 架构优势
### 1. 职责清晰
- **`src/plugin_system/`** - 系统提供的框架、API和基础设施
- **`src/plugins/`** - 用户开发或使用的具体插件
### 2. 导入简化
```python
# 统一导入接口
from src.plugin_system import (
BasePlugin, register_plugin, BaseAction, BaseCommand,
ActionInfo, CommandInfo, PluginAPI
)
```
### 3. 模块化设计
- 各个子模块都有清晰的职责和接口
- 支持按需导入特定功能
- 便于维护和扩展
## 快速开始
### 创建简单插件
```python
from src.plugin_system import BasePlugin, register_plugin, BaseAction, ActionInfo
class MyAction(BaseAction):
async def execute(self):
return True, "Hello from my plugin!"
@register_plugin
class MyPlugin(BasePlugin):
plugin_name = "my_plugin"
plugin_description = "我的第一个插件"
def get_plugin_components(self):
return [(
ActionInfo(name="my_action", description="我的动作"),
MyAction
)]
```
### 使用系统API
```python
class MyAction(BaseAction):
async def execute(self):
# 发送消息
await self.api.send_text_to_group(
self.api.get_service("chat_stream"),
"Hello World!"
)
# 数据库操作
data = await self.api.db_get("table", "key")
# LLM调用
response = await self.api.llm_text_request("你好")
return True, response
```
## 兼容性迁移
### 现有Action迁移
```python
# 旧方式
from src.chat.actions.base_action import BaseAction, register_action
# 新方式
from src.plugin_system import BaseAction, register_plugin
from src.plugin_system.base.component_types import ActionInfo
# 将Action封装到Plugin中
@register_plugin
class MyActionPlugin(BasePlugin):
plugin_name = "my_action_plugin"
def get_plugin_components(self):
return [(ActionInfo(...), MyAction)]
```
### 现有Command迁移
```python
# 旧方式
from src.chat.command.command_handler import BaseCommand, register_command
# 新方式
from src.plugin_system import BaseCommand, register_plugin
from src.plugin_system.base.component_types import CommandInfo
# 将Command封装到Plugin中
@register_plugin
class MyCommandPlugin(BasePlugin):
plugin_name = "my_command_plugin"
def get_plugin_components(self):
return [(CommandInfo(...), MyCommand)]
```
## 扩展指南
### 添加新的组件类型
1.`component_types.py` 中定义新的组件类型
2.`component_registry.py` 中添加对应的注册逻辑
3. 创建对应的基类
### 添加新的API
1.`apis/` 目录下创建新的API模块
2.`plugin_api.py` 中集成新API
3. 更新 `__init__.py` 导出接口
## 最佳实践
1. **单一插件包含相关组件** - 一个插件可以包含多个相关的Action和Command
2. **使用配置文件** - 通过TOML配置文件管理插件行为
3. **合理的组件命名** - 使用描述性的组件名称
4. **充分的错误处理** - 在组件中妥善处理异常
5. **详细的文档** - 为插件和组件编写清晰的文档
## 内置插件规划
- **系统核心插件** - 将现有的内置Action/Command迁移为系统插件
- **工具插件** - 常用的工具和实用功能
- **示例插件** - 帮助开发者学习的示例代码
这个重构保持了向后兼容性,同时提供了更清晰、更易维护的架构。

View File

@@ -0,0 +1,47 @@
"""
MaiBot 插件系统
提供统一的插件开发和管理框架
"""
# 导出主要的公共接口
from src.plugin_system.base.base_plugin import BasePlugin, register_plugin
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.component_types import (
ComponentType, ActionActivationType, ChatMode,
ComponentInfo, ActionInfo, CommandInfo, PluginInfo
)
from src.plugin_system.apis.plugin_api import PluginAPI, create_plugin_api, create_command_api
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
__version__ = "1.0.0"
__all__ = [
# 基础类
'BasePlugin',
'BaseAction',
'BaseCommand',
# 类型定义
'ComponentType',
'ActionActivationType',
'ChatMode',
'ComponentInfo',
'ActionInfo',
'CommandInfo',
'PluginInfo',
# API接口
'PluginAPI',
'create_plugin_api',
'create_command_api',
# 管理器
'plugin_manager',
'component_registry',
# 装饰器
'register_plugin',
]

View File

@@ -0,0 +1,172 @@
# API使用指南
插件系统提供了多种API访问方式根据使用场景选择合适的API类。
## 📊 API分类
### 🔗 ActionAPI - 需要Action依赖
**适用场景**在Action组件中使用需要访问聊天上下文
```python
from src.plugin_system.apis import ActionAPI
class MyAction(BaseAction):
async def execute(self):
# Action已内置ActionAPI可以直接使用
await self.api.send_message("text", "Hello")
await self.api.store_action_info(action_prompt_display="执行了动作")
```
**包含功能**
- ✅ 发送消息需要chat_stream、expressor等
- ✅ 数据库操作需要thinking_id、action_data等
### 🔧 IndependentAPI - 独立功能
**适用场景**在Command组件中使用或需要独立工具功能
```python
from src.plugin_system.apis import IndependentAPI
class MyCommand(BaseCommand):
async def execute(self):
# 创建独立API实例
api = IndependentAPI(log_prefix="[MyCommand]")
# 使用独立功能
models = api.get_available_models()
config = api.get_global_config("some_key")
timestamp = api.get_timestamp()
```
**包含功能**
- ✅ LLM模型调用
- ✅ 配置读取
- ✅ 工具函数时间、文件、ID生成等
- ✅ 聊天流查询
- ✅ 心流状态控制
### ⚡ StaticAPI - 静态访问
**适用场景**:简单工具调用,不需要实例化
```python
from src.plugin_system.apis import StaticAPI
# 直接调用静态方法
models = StaticAPI.get_available_models()
config = StaticAPI.get_global_config("bot.nickname")
timestamp = StaticAPI.get_timestamp()
unique_id = StaticAPI.generate_unique_id()
# 异步方法
result = await StaticAPI.generate_with_model(prompt, model_config)
chat_stream = StaticAPI.get_chat_stream_by_group_id("123456")
```
## 🎯 使用建议
### Action组件开发
```python
class MyAction(BaseAction):
# 激活条件直接在类中定义
focus_activation_type = ActionActivationType.KEYWORD
activation_keywords = ["测试"]
async def execute(self):
# 使用内置的ActionAPI
success = await self.api.send_message("text", "处理中...")
# 存储执行记录
await self.api.store_action_info(
action_prompt_display="执行了测试动作"
)
return True, "完成"
```
### Command组件开发
```python
class MyCommand(BaseCommand):
# 命令模式直接在类中定义
command_pattern = r"^/test\s+(?P<param>\w+)$"
command_help = "测试命令"
async def execute(self):
# 使用独立API
api = IndependentAPI(log_prefix="[TestCommand]")
# 获取配置
max_length = api.get_global_config("test.max_length", 100)
# 生成内容(如果需要)
if api.get_available_models():
models = api.get_available_models()
first_model = list(models.values())[0]
success, response, _, _ = await api.generate_with_model(
"生成测试回复", first_model
)
if success:
await self.send_reply(response)
```
### 独立工具使用
```python
# 不在插件环境中的独立使用
from src.plugin_system.apis import StaticAPI
def some_utility_function():
# 获取配置
bot_name = StaticAPI.get_global_config("bot.nickname", "Bot")
# 生成ID
request_id = StaticAPI.generate_unique_id()
# 格式化时间
current_time = StaticAPI.format_time()
return f"{bot_name}_{request_id}_{current_time}"
```
## 🔄 迁移指南
### 从原PluginAPI迁移
**原来的用法**
```python
# 原来需要导入完整PluginAPI
from src.plugin_system.apis import PluginAPI
api = PluginAPI(chat_stream=..., expressor=...)
await api.send_message("text", "Hello")
config = api.get_global_config("key")
```
**新的用法**
```python
# 方式1继续使用原PluginAPI不变
from src.plugin_system.apis import PluginAPI
# 方式2使用分类API推荐
from src.plugin_system.apis import ActionAPI, IndependentAPI
# Action相关功能
action_api = ActionAPI(chat_stream=..., expressor=...)
await action_api.send_message("text", "Hello")
# 独立功能
config = IndependentAPI().get_global_config("key")
# 或者
config = StaticAPI.get_global_config("key")
```
## 📋 API对照表
| 功能类别 | 原PluginAPI | ActionAPI | IndependentAPI | StaticAPI |
|---------|-------------|-----------|----------------|-----------|
| 发送消息 | ✅ | ✅ | ❌ | ❌ |
| 数据库操作 | ✅ | ✅ | ❌ | ❌ |
| LLM调用 | ✅ | ❌ | ✅ | ✅ |
| 配置读取 | ✅ | ❌ | ✅ | ✅ |
| 工具函数 | ✅ | ❌ | ✅ | ✅ |
| 聊天流查询 | ✅ | ❌ | ✅ | ✅ |
| 心流控制 | ✅ | ❌ | ✅ | ✅ |
这样的分类让插件开发者可以更明确地知道需要什么样的API避免不必要的依赖注入。

View File

@@ -0,0 +1,37 @@
"""
插件API模块
提供插件可以使用的各种API接口
"""
from src.plugin_system.apis.plugin_api import PluginAPI, create_plugin_api, create_command_api
from src.plugin_system.apis.message_api import MessageAPI
from src.plugin_system.apis.llm_api import LLMAPI
from src.plugin_system.apis.database_api import DatabaseAPI
from src.plugin_system.apis.config_api import ConfigAPI
from src.plugin_system.apis.utils_api import UtilsAPI
from src.plugin_system.apis.stream_api import StreamAPI
from src.plugin_system.apis.hearflow_api import HearflowAPI
# 新增分类的API聚合
from src.plugin_system.apis.action_apis import ActionAPI
from src.plugin_system.apis.independent_apis import IndependentAPI, StaticAPI
__all__ = [
# 原有统一API
'PluginAPI',
'create_plugin_api',
'create_command_api',
# 原有单独API
'MessageAPI',
'LLMAPI',
'DatabaseAPI',
'ConfigAPI',
'UtilsAPI',
'StreamAPI',
'HearflowAPI',
# 新增分类API
'ActionAPI', # 需要Action依赖的API
'IndependentAPI', # 独立API
'StaticAPI', # 静态API
]

View File

@@ -0,0 +1,85 @@
"""
Action相关API聚合模块
聚合了需要Action组件依赖的API这些API需要通过Action初始化时注入的服务对象才能正常工作。
包括MessageAPI、DatabaseAPI等需要chat_stream、expressor等服务的API。
"""
from src.plugin_system.apis.message_api import MessageAPI
from src.plugin_system.apis.database_api import DatabaseAPI
from src.common.logger_manager import get_logger
logger = get_logger("action_apis")
class ActionAPI(MessageAPI, DatabaseAPI):
"""
Action相关API聚合类
聚合了需要Action组件依赖的API功能。这些API需要以下依赖
- _services: 包含chat_stream、expressor、replyer、observations等服务对象
- log_prefix: 日志前缀
- thinking_id: 思考ID
- cycle_timers: 计时器
- action_data: Action数据
使用场景:
- 在Action组件中使用需要发送消息、存储数据等功能
- 需要访问聊天上下文和执行环境的操作
"""
def __init__(self,
chat_stream=None,
expressor=None,
replyer=None,
observations=None,
log_prefix: str = "[ActionAPI]",
thinking_id: str = "",
cycle_timers: dict = None,
action_data: dict = None):
"""
初始化Action相关API
Args:
chat_stream: 聊天流对象
expressor: 表达器对象
replyer: 回复器对象
observations: 观察列表
log_prefix: 日志前缀
thinking_id: 思考ID
cycle_timers: 计时器字典
action_data: Action数据
"""
# 存储依赖对象
self._services = {
"chat_stream": chat_stream,
"expressor": expressor,
"replyer": replyer,
"observations": observations or []
}
self.log_prefix = log_prefix
self.thinking_id = thinking_id
self.cycle_timers = cycle_timers or {}
self.action_data = action_data or {}
logger.debug(f"{self.log_prefix} ActionAPI 初始化完成")
def set_chat_stream(self, chat_stream):
"""设置聊天流对象"""
self._services["chat_stream"] = chat_stream
logger.debug(f"{self.log_prefix} 设置聊天流")
def set_expressor(self, expressor):
"""设置表达器对象"""
self._services["expressor"] = expressor
logger.debug(f"{self.log_prefix} 设置表达器")
def set_replyer(self, replyer):
"""设置回复器对象"""
self._services["replyer"] = replyer
logger.debug(f"{self.log_prefix} 设置回复器")
def set_observations(self, observations):
"""设置观察列表"""
self._services["observations"] = observations or []
logger.debug(f"{self.log_prefix} 设置观察列表")

View File

@@ -0,0 +1,54 @@
from typing import Any
from src.common.logger_manager import get_logger
from src.config.config import global_config
from src.person_info.person_info import person_info_manager
logger = get_logger("config_api")
class ConfigAPI:
"""配置API模块
提供了配置读取和用户信息获取等功能
"""
def get_global_config(self, key: str, default: Any = None) -> Any:
"""
安全地从全局配置中获取一个值。
插件应使用此方法读取全局配置,以保证只读和隔离性。
Args:
key: 配置键名
default: 如果配置不存在时返回的默认值
Returns:
Any: 配置值或默认值
"""
return global_config.get(key, default)
async def get_user_id_by_person_name(self, person_name: str) -> tuple[str, str]:
"""根据用户名获取用户ID
Args:
person_name: 用户名
Returns:
tuple[str, str]: (平台, 用户ID)
"""
person_id = person_info_manager.get_person_id_by_person_name(person_name)
user_id = await person_info_manager.get_value(person_id, "user_id")
platform = await person_info_manager.get_value(person_id, "platform")
return platform, user_id
async def get_person_info(self, person_id: str, key: str, default: Any = None) -> Any:
"""获取用户信息
Args:
person_id: 用户ID
key: 信息键名
default: 默认值
Returns:
Any: 用户信息值或默认值
"""
return await person_info_manager.get_value(person_id, key, default)

View File

@@ -0,0 +1,372 @@
import traceback
import time
from typing import Dict, List, Any, Union, Type
from src.common.logger_manager import get_logger
from src.common.database.database_model import ActionRecords
from src.common.database.database import db
from peewee import Model, DoesNotExist
logger = get_logger("database_api")
class DatabaseAPI:
"""数据库API模块
提供了数据库操作相关的功能
"""
async def store_action_info(
self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True
) -> None:
"""存储action执行信息到数据库
Args:
action_build_into_prompt: 是否构建到提示中
action_prompt_display: 动作显示内容
action_done: 动作是否已完成
"""
try:
chat_stream = self._services.get("chat_stream")
if not chat_stream:
logger.error(f"{self.log_prefix} 无法存储action信息缺少chat_stream服务")
return
action_time = time.time()
action_id = f"{action_time}_{self.thinking_id}"
ActionRecords.create(
action_id=action_id,
time=action_time,
action_name=self.__class__.__name__,
action_data=str(self.action_data),
action_done=action_done,
action_build_into_prompt=action_build_into_prompt,
action_prompt_display=action_prompt_display,
chat_id=chat_stream.stream_id,
chat_info_stream_id=chat_stream.stream_id,
chat_info_platform=chat_stream.platform,
user_id=chat_stream.user_info.user_id if chat_stream.user_info else "",
user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "",
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "",
)
logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}")
except Exception as e:
logger.error(f"{self.log_prefix} 存储action信息时出错: {e}")
traceback.print_exc()
async def db_query(
self,
model_class: Type[Model],
query_type: str = "get",
filters: Dict[str, Any] = None,
data: Dict[str, Any] = None,
limit: int = None,
order_by: List[str] = None,
single_result: bool = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
Args:
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
filters: 过滤条件字典,键为字段名,值为要匹配的值
data: 用于创建或更新的数据字典
limit: 限制结果数量
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回不同的结果:
- "get": 返回查询结果列表或单个结果(如果 single_result=True
- "create": 返回创建的记录
- "update": 返回受影响的行数
- "delete": 返回受影响的行数
- "count": 返回记录数量
示例:
# 查询最近10条消息
messages = await self.db_query(
Messages,
query_type="get",
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by=["-time"]
)
# 创建一条记录
new_record = await self.db_query(
ActionRecords,
query_type="create",
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
)
# 更新记录
updated_count = await self.db_query(
ActionRecords,
query_type="update",
filters={"action_id": "123"},
data={"action_done": True}
)
# 删除记录
deleted_count = await self.db_query(
ActionRecords,
query_type="delete",
filters={"action_id": "123"}
)
# 计数
count = await self.db_query(
Messages,
query_type="count",
filters={"chat_id": chat_stream.stream_id}
)
"""
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
# 构建基本查询
if query_type in ["get", "update", "delete", "count"]:
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 执行查询
if query_type == "get":
# 应用排序
if order_by:
for field in order_by:
if field.startswith("-"):
query = query.order_by(getattr(model_class, field[1:]).desc())
else:
query = query.order_by(getattr(model_class, field))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if single_result:
return results[0] if results else None
return results
elif query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建记录
record = model_class.create(**data)
# 返回创建的记录
return model_class.select().where(model_class.id == record.id).dicts().get()
elif query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
# 更新记录
return query.update(**data).execute()
elif query_type == "delete":
# 删除记录
return query.delete().execute()
elif query_type == "count":
# 计数
return query.count()
else:
raise ValueError(f"不支持的查询类型: {query_type}")
except DoesNotExist:
# 记录不存在
if query_type == "get" and single_result:
return None
return []
except Exception as e:
logger.error(f"{self.log_prefix} 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None
async def db_raw_query(
self, sql: str, params: List[Any] = None, fetch_results: bool = True
) -> Union[List[Dict[str, Any]], int, None]:
"""执行原始SQL查询
警告: 使用此方法需要小心确保SQL语句已正确构造以避免SQL注入风险。
Args:
sql: 原始SQL查询字符串
params: 查询参数列表用于替换SQL中的占位符
fetch_results: 是否获取查询结果对于SELECT查询设为True对于
UPDATE/INSERT/DELETE等操作设为False
Returns:
如果fetch_results为True返回查询结果列表
如果fetch_results为False返回受影响的行数
如果出错返回None
"""
try:
cursor = db.execute_sql(sql, params or [])
if fetch_results:
# 获取列名
columns = [col[0] for col in cursor.description]
# 构建结果字典列表
results = []
for row in cursor.fetchall():
results.append(dict(zip(columns, row)))
return results
else:
# 返回受影响的行数
return cursor.rowcount
except Exception as e:
logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}")
traceback.print_exc()
return None
async def db_save(
self, model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
) -> Union[Dict[str, Any], None]:
"""保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
如果没有找到匹配记录或未提供key_field和key_value则创建新记录。
Args:
model_class: Peewee模型类如ActionRecords, Messages等
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名,例如"action_id"
key_value: 用于查找现有记录的字段值
Returns:
Dict[str, Any]: 保存后的记录数据
None: 如果操作失败
示例:
# 创建或更新一条记录
record = await self.db_save(
ActionRecords,
{
"action_id": "123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
},
key_field="action_id",
key_value="123"
)
"""
try:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
# 查找现有记录
existing_records = list(
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
)
if existing_records:
# 更新现有记录
existing_record = existing_records[0]
for field, value in data.items():
setattr(existing_record, field, value)
existing_record.save()
# 返回更新后的记录
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
return updated_record
# 如果没有找到现有记录或未提供key_field和key_value创建新记录
new_record = model_class.create(**data)
# 返回创建的记录
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
return created_record
except Exception as e:
logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}")
traceback.print_exc()
return None
async def db_get(
self, model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
这是db_query方法的简化版本专注于数据检索操作。
Args:
model_class: Peewee模型类
filters: 过滤条件,字段名和值的字典
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
limit: 结果数量限制如果为1则返回单个记录而不是列表
Returns:
如果limit=1返回单个记录字典或None
否则返回记录字典列表或空列表。
示例:
# 获取单个记录
record = await self.db_get(
ActionRecords,
filters={"action_id": "123"},
limit=1
)
# 获取最近10条记录
records = await self.db_get(
Messages,
filters={"chat_id": chat_stream.stream_id},
order_by="-time",
limit=10
)
"""
try:
# 构建查询
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 应用排序
if order_by:
if order_by.startswith("-"):
query = query.order_by(getattr(model_class, order_by[1:]).desc())
else:
query = query.order_by(getattr(model_class, order_by))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if limit == 1:
return results[0] if results else None
return results
except Exception as e:
logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}")
traceback.print_exc()
return None if limit == 1 else []

View File

@@ -0,0 +1,133 @@
from typing import Optional, List, Any
from src.common.logger_manager import get_logger
from src.chat.heart_flow.heartflow import heartflow
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
logger = get_logger("hearflow_api")
class HearflowAPI:
"""心流API模块
提供与心流和子心流相关的操作接口
"""
async def get_sub_hearflow_by_chat_id(self, chat_id: str) -> Optional[SubHeartflow]:
"""根据chat_id获取指定的sub_hearflow实例
Args:
chat_id: 聊天ID与sub_hearflow的subheartflow_id相同
Returns:
Optional[SubHeartflow]: sub_hearflow实例如果不存在则返回None
"""
try:
# 直接从subheartflow_manager获取已存在的子心流
# 使用锁来确保线程安全
async with heartflow.subheartflow_manager._lock:
subflow = heartflow.subheartflow_manager.subheartflows.get(chat_id)
if subflow and not subflow.should_stop:
logger.debug(f"{self.log_prefix} 成功获取子心流实例: {chat_id}")
return subflow
else:
logger.debug(f"{self.log_prefix} 子心流不存在或已停止: {chat_id}")
return None
except Exception as e:
logger.error(f"{self.log_prefix} 获取子心流实例时出错: {e}")
return None
def get_all_sub_hearflow_ids(self) -> List[str]:
"""获取所有子心流的ID列表
Returns:
List[str]: 所有子心流的ID列表
"""
try:
all_subflows = heartflow.subheartflow_manager.get_all_subheartflows()
chat_ids = [subflow.chat_id for subflow in all_subflows if not subflow.should_stop]
logger.debug(f"{self.log_prefix} 获取到 {len(chat_ids)} 个活跃的子心流ID")
return chat_ids
except Exception as e:
logger.error(f"{self.log_prefix} 获取子心流ID列表时出错: {e}")
return []
def get_all_sub_hearflows(self) -> List[SubHeartflow]:
"""获取所有子心流实例
Returns:
List[SubHeartflow]: 所有活跃的子心流实例列表
"""
try:
all_subflows = heartflow.subheartflow_manager.get_all_subheartflows()
active_subflows = [subflow for subflow in all_subflows if not subflow.should_stop]
logger.debug(f"{self.log_prefix} 获取到 {len(active_subflows)} 个活跃的子心流实例")
return active_subflows
except Exception as e:
logger.error(f"{self.log_prefix} 获取子心流实例列表时出错: {e}")
return []
async def get_sub_hearflow_chat_state(self, chat_id: str) -> Optional[ChatState]:
"""获取指定子心流的聊天状态
Args:
chat_id: 聊天ID
Returns:
Optional[ChatState]: 聊天状态如果子心流不存在则返回None
"""
try:
subflow = await self.get_sub_hearflow_by_chat_id(chat_id)
if subflow:
return subflow.chat_state.chat_status
return None
except Exception as e:
logger.error(f"{self.log_prefix} 获取子心流聊天状态时出错: {e}")
return None
async def set_sub_hearflow_chat_state(self, chat_id: str, target_state: ChatState) -> bool:
"""设置指定子心流的聊天状态
Args:
chat_id: 聊天ID
target_state: 目标状态
Returns:
bool: 是否设置成功
"""
try:
return await heartflow.subheartflow_manager.force_change_state(chat_id, target_state)
except Exception as e:
logger.error(f"{self.log_prefix} 设置子心流聊天状态时出错: {e}")
return False
async def get_sub_hearflow_replyer(self, chat_id: str) -> Optional[Any]:
"""根据chat_id获取指定子心流的replyer实例
Args:
chat_id: 聊天ID
Returns:
Optional[Any]: replyer实例如果不存在则返回None
"""
try:
replyer, _ = await self.get_sub_hearflow_replyer_and_expressor(chat_id)
return replyer
except Exception as e:
logger.error(f"{self.log_prefix} 获取子心流replyer时出错: {e}")
return None
async def get_sub_hearflow_expressor(self, chat_id: str) -> Optional[Any]:
"""根据chat_id获取指定子心流的expressor实例
Args:
chat_id: 聊天ID
Returns:
Optional[Any]: expressor实例如果不存在则返回None
"""
try:
_, expressor = await self.get_sub_hearflow_replyer_and_expressor(chat_id)
return expressor
except Exception as e:
logger.error(f"{self.log_prefix} 获取子心流expressor时出错: {e}")
return None

View File

@@ -0,0 +1,132 @@
"""
独立API聚合模块
聚合了不需要Action组件依赖的API这些API可以独立使用不需要注入服务对象。
包括LLMAPI、ConfigAPI、UtilsAPI、StreamAPI、HearflowAPI等独立功能的API。
"""
from src.plugin_system.apis.llm_api import LLMAPI
from src.plugin_system.apis.config_api import ConfigAPI
from src.plugin_system.apis.utils_api import UtilsAPI
from src.plugin_system.apis.stream_api import StreamAPI
from src.plugin_system.apis.hearflow_api import HearflowAPI
from src.common.logger_manager import get_logger
logger = get_logger("independent_apis")
class IndependentAPI(LLMAPI, ConfigAPI, UtilsAPI, StreamAPI, HearflowAPI):
"""
独立API聚合类
聚合了不需要Action组件依赖的API功能。这些API的特点
- 不需要chat_stream、expressor等服务对象
- 可以独立调用不依赖Action执行上下文
- 主要是工具类方法和配置查询方法
包含的API
- LLMAPI: LLM模型调用仅需要全局配置
- ConfigAPI: 配置读取(使用全局配置)
- UtilsAPI: 工具方法(文件操作、时间处理等)
- StreamAPI: 聊天流查询使用ChatManager
- HearflowAPI: 心流状态控制使用heartflow
使用场景:
- 在Command组件中使用
- 独立的工具函数调用
- 配置查询和系统状态检查
"""
def __init__(self, log_prefix: str = "[IndependentAPI]"):
"""
初始化独立API
Args:
log_prefix: 日志前缀,用于区分不同的调用来源
"""
self.log_prefix = log_prefix
logger.debug(f"{self.log_prefix} IndependentAPI 初始化完成")
# 提供便捷的静态访问方式
class StaticAPI:
"""
静态API类
提供完全静态的API访问方式不需要实例化适合简单的工具调用。
"""
# LLM相关
@staticmethod
def get_available_models():
"""获取可用的LLM模型"""
api = LLMAPI()
return api.get_available_models()
@staticmethod
async def generate_with_model(prompt: str, model_config: dict, **kwargs):
"""使用LLM生成内容"""
api = LLMAPI()
api.log_prefix = "[StaticAPI]"
return await api.generate_with_model(prompt, model_config, **kwargs)
# 配置相关
@staticmethod
def get_global_config(key: str, default=None):
"""获取全局配置"""
api = ConfigAPI()
return api.get_global_config(key, default)
@staticmethod
async def get_user_id_by_name(person_name: str):
"""根据用户名获取用户ID"""
api = ConfigAPI()
return await api.get_user_id_by_person_name(person_name)
# 工具相关
@staticmethod
def get_timestamp():
"""获取当前时间戳"""
api = UtilsAPI()
return api.get_timestamp()
@staticmethod
def format_time(timestamp=None, format_str="%Y-%m-%d %H:%M:%S"):
"""格式化时间"""
api = UtilsAPI()
return api.format_time(timestamp, format_str)
@staticmethod
def generate_unique_id():
"""生成唯一ID"""
api = UtilsAPI()
return api.generate_unique_id()
# 聊天流相关
@staticmethod
def get_chat_stream_by_group_id(group_id: str, platform: str = "qq"):
"""通过群ID获取聊天流"""
api = StreamAPI()
api.log_prefix = "[StaticAPI]"
return api.get_chat_stream_by_group_id(group_id, platform)
@staticmethod
def get_all_group_chat_streams(platform: str = "qq"):
"""获取所有群聊聊天流"""
api = StreamAPI()
api.log_prefix = "[StaticAPI]"
return api.get_all_group_chat_streams(platform)
# 心流相关
@staticmethod
async def get_sub_hearflow_by_chat_id(chat_id: str):
"""获取子心流"""
api = HearflowAPI()
api.log_prefix = "[StaticAPI]"
return await api.get_sub_hearflow_by_chat_id(chat_id)
@staticmethod
async def set_sub_hearflow_chat_state(chat_id: str, target_state):
"""设置子心流状态"""
api = HearflowAPI()
api.log_prefix = "[StaticAPI]"
return await api.set_sub_hearflow_chat_state(chat_id, target_state)

View File

@@ -0,0 +1,54 @@
from typing import Tuple, Dict, Any
from src.common.logger_manager import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
logger = get_logger("llm_api")
class LLMAPI:
"""LLM API模块
提供了与LLM模型交互的功能
"""
def get_available_models(self) -> Dict[str, Any]:
"""获取所有可用的模型配置
Returns:
Dict[str, Any]: 模型配置字典key为模型名称value为模型配置
"""
if not hasattr(global_config, "model"):
logger.error(f"{self.log_prefix} 无法获取模型列表:全局配置中未找到 model 配置")
return {}
models = global_config.model
return models
async def generate_with_model(
self, prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容
Args:
prompt: 提示词
model_config: 模型配置(从 get_available_models 获取的模型配置)
request_type: 请求类型标识
**kwargs: 其他模型特定参数如temperature、max_tokens等
Returns:
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
"""
try:
logger.info(f"{self.log_prefix} 使用模型生成内容,提示词: {prompt[:100]}...")
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
response, (reasoning, model_name) = await llm_request.generate_response_async(prompt)
return True, response, reasoning, model_name
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"{self.log_prefix} {error_msg}")
return False, error_msg, "", ""

View File

@@ -0,0 +1,398 @@
import traceback
import time
from typing import Optional, List, Dict, Any
from src.common.logger_manager import get_logger
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
# 以下为类型注解需要
from src.chat.message_receive.chat_stream import ChatStream, chat_manager
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
from src.chat.focus_chat.info.obs_info import ObsInfo
# 新增导入
from src.chat.focus_chat.heartFC_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending
from maim_message import Seg, UserInfo
from src.config.config import global_config
logger = get_logger("message_api")
class MessageAPI:
"""消息API模块
提供了发送消息、获取消息历史等功能
"""
async def send_message_to_target(
self,
message_type: str,
content: str,
platform: str,
target_id: str,
is_group: bool = True,
display_message: str = "",
) -> bool:
"""直接向指定目标发送消息
Args:
message_type: 消息类型,如"text""image""emoji"
content: 消息内容
platform: 目标平台,如"qq"
target_id: 目标ID群ID或用户ID
is_group: 是否为群聊True为群聊False为私聊
display_message: 显示消息(可选)
Returns:
bool: 是否发送成功
"""
try:
# 构建目标聊天流ID
if is_group:
# 群聊:从数据库查找对应的聊天流
target_stream = None
for _, stream in chat_manager.streams.items():
if (
stream.group_info
and str(stream.group_info.group_id) == str(target_id)
and stream.platform == platform
):
target_stream = stream
break
if not target_stream:
logger.error(f"{getattr(self, 'log_prefix', '')} 未找到群ID为 {target_id} 的聊天流")
return False
else:
# 私聊:从数据库查找对应的聊天流
target_stream = None
for _, stream in chat_manager.streams.items():
if (
not stream.group_info
and str(stream.user_info.user_id) == str(target_id)
and stream.platform == platform
):
target_stream = stream
break
if not target_stream:
logger.error(f"{getattr(self, 'log_prefix', '')} 未找到用户ID为 {target_id} 的私聊流")
return False
# 创建HeartFCSender实例
heart_fc_sender = HeartFCSender()
# 生成消息ID和thinking_id
current_time = time.time()
message_id = f"plugin_msg_{int(current_time * 1000)}"
# 构建机器人用户信息
bot_user_info = UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=platform,
)
# 创建消息段
message_segment = Seg(type=message_type, data=content)
# 创建空锚点消息(用于回复)
anchor_message = await create_empty_anchor_message(platform, target_stream.group_info, target_stream)
# 构建发送消息对象
bot_message = MessageSending(
message_id=message_id,
chat_stream=target_stream,
bot_user_info=bot_user_info,
sender_info=target_stream.user_info, # 目标用户信息
message_segment=message_segment,
display_message=display_message,
reply=anchor_message,
is_head=True,
is_emoji=(message_type == "emoji"),
thinking_start_time=current_time,
)
# 发送消息
sent_msg = await heart_fc_sender.send_message(bot_message, has_thinking=True, typing=False, set_reply=False)
if sent_msg:
logger.info(f"{getattr(self, 'log_prefix', '')} 成功发送消息到 {platform}:{target_id}")
return True
else:
logger.error(f"{getattr(self, 'log_prefix', '')} 发送消息失败")
return False
except Exception as e:
logger.error(f"{getattr(self, 'log_prefix', '')} 向目标发送消息时出错: {e}")
traceback.print_exc()
return False
async def send_text_to_group(self, text: str, group_id: str, platform: str = "qq") -> bool:
"""便捷方法:向指定群聊发送文本消息
Args:
text: 要发送的文本内容
group_id: 群聊ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
return await self.send_message_to_target(
message_type="text", content=text, platform=platform, target_id=group_id, is_group=True
)
async def send_text_to_user(self, text: str, user_id: str, platform: str = "qq") -> bool:
"""便捷方法:向指定用户发送私聊文本消息
Args:
text: 要发送的文本内容
user_id: 用户ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
return await self.send_message_to_target(
message_type="text", content=text, platform=platform, target_id=user_id, is_group=False
)
async def send_message(self, type: str, data: str, target: Optional[str] = "", display_message: str = "") -> bool:
"""发送消息的简化方法
Args:
type: 消息类型,如"text""image"
data: 消息内容
target: 目标消息(可选)
display_message: 显示的消息内容(可选)
Returns:
bool: 是否发送成功
"""
try:
# 安全获取服务和日志前缀
services = getattr(self, '_services', {})
log_prefix = getattr(self, 'log_prefix', '[MessageAPI]')
expressor: DefaultExpressor = services.get("expressor")
chat_stream: ChatStream = services.get("chat_stream")
if not expressor or not chat_stream:
logger.error(f"{log_prefix} 无法发送消息:缺少必要的内部服务")
return False
# 获取锚定消息(如果有)
observations = services.get("observations", [])
if len(observations) > 0:
chatting_observation: ChattingObservation = next(
(obs for obs in observations if isinstance(obs, ChattingObservation)), None
)
if chatting_observation:
anchor_message = chatting_observation.search_message_by_text(target)
else:
anchor_message = None
else:
anchor_message = None
# 如果没有找到锚点消息,创建一个占位符
if not anchor_message:
logger.info(f"{log_prefix} 未找到锚点消息,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message.update_chat_stream(chat_stream)
response_set = [
(type, data),
]
# 调用内部方法发送消息
success = await expressor.send_response_messages(
anchor_message=anchor_message,
response_set=response_set,
display_message=display_message,
)
return success
except Exception as e:
log_prefix = getattr(self, 'log_prefix', '[MessageAPI]')
logger.error(f"{log_prefix} 发送消息时出错: {e}")
traceback.print_exc()
return False
async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool:
"""通过expressor发送文本消息的简化方法
Args:
text: 要发送的消息文本
target: 目标消息(可选)
Returns:
bool: 是否发送成功
"""
# 安全获取服务和日志前缀
services = getattr(self, '_services', {})
log_prefix = getattr(self, 'log_prefix', '[MessageAPI]')
expressor: DefaultExpressor = services.get("expressor")
chat_stream: ChatStream = services.get("chat_stream")
if not expressor or not chat_stream:
logger.error(f"{log_prefix} 无法发送消息:缺少必要的内部服务")
return False
# 构造简化的动作数据
reply_data = {"text": text, "target": target or "", "emojis": []}
# 获取锚定消息(如果有)
observations = services.get("observations", [])
# 查找 ChattingObservation 实例
chatting_observation = None
for obs in observations:
if isinstance(obs, ChattingObservation):
chatting_observation = obs
break
if not chatting_observation:
logger.warning(f"{log_prefix} 未找到 ChattingObservation 实例,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
if not anchor_message:
logger.info(f"{log_prefix} 未找到锚点消息,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message.update_chat_stream(chat_stream)
# 调用内部方法发送消息
cycle_timers = getattr(self, 'cycle_timers', {})
reasoning = getattr(self, 'reasoning', '插件生成')
thinking_id = getattr(self, 'thinking_id', 'plugin_thinking')
success, _ = await expressor.deal_reply(
cycle_timers=cycle_timers,
action_data=reply_data,
anchor_message=anchor_message,
reasoning=reasoning,
thinking_id=thinking_id,
)
return success
async def send_message_by_replyer(
self, target: Optional[str] = None, extra_info_block: Optional[str] = None
) -> bool:
"""通过replyer发送消息的简化方法
Args:
target: 目标消息(可选)
extra_info_block: 额外信息块(可选)
Returns:
bool: 是否发送成功
"""
# 安全获取服务和日志前缀
services = getattr(self, '_services', {})
log_prefix = getattr(self, 'log_prefix', '[MessageAPI]')
replyer: DefaultReplyer = services.get("replyer")
chat_stream: ChatStream = services.get("chat_stream")
if not replyer or not chat_stream:
logger.error(f"{log_prefix} 无法发送消息:缺少必要的内部服务")
return False
# 构造简化的动作数据
reply_data = {"target": target or "", "extra_info_block": extra_info_block}
# 获取锚定消息(如果有)
observations = services.get("observations", [])
# 查找 ChattingObservation 实例
chatting_observation = None
for obs in observations:
if isinstance(obs, ChattingObservation):
chatting_observation = obs
break
if not chatting_observation:
logger.warning(f"{log_prefix} 未找到 ChattingObservation 实例,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
if not anchor_message:
logger.info(f"{log_prefix} 未找到锚点消息,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message.update_chat_stream(chat_stream)
# 调用内部方法发送消息
cycle_timers = getattr(self, 'cycle_timers', {})
reasoning = getattr(self, 'reasoning', '插件生成')
thinking_id = getattr(self, 'thinking_id', 'plugin_thinking')
success, _ = await replyer.deal_reply(
cycle_timers=cycle_timers,
action_data=reply_data,
anchor_message=anchor_message,
reasoning=reasoning,
thinking_id=thinking_id,
)
return success
def get_chat_type(self) -> str:
"""获取当前聊天类型
Returns:
str: 聊天类型 ("group""private")
"""
services = getattr(self, '_services', {})
chat_stream: ChatStream = services.get("chat_stream")
if chat_stream and hasattr(chat_stream, "group_info"):
return "group" if chat_stream.group_info else "private"
return "unknown"
def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]:
"""获取最近的消息
Args:
count: 要获取的消息数量
Returns:
List[Dict]: 消息列表,每个消息包含发送者、内容等信息
"""
messages = []
services = getattr(self, '_services', {})
observations = services.get("observations", [])
if observations and len(observations) > 0:
obs = observations[0]
if hasattr(obs, "get_talking_message"):
obs: ObsInfo
raw_messages = obs.get_talking_message()
# 转换为简化格式
for msg in raw_messages[-count:]:
simple_msg = {
"sender": msg.get("sender", "未知"),
"content": msg.get("content", ""),
"timestamp": msg.get("timestamp", 0),
}
messages.append(simple_msg)
return messages

View File

@@ -0,0 +1,158 @@
# -*- coding: utf-8 -*-
"""
统一的插件API聚合模块
提供所有插件API功能的统一访问入口
"""
from typing import Dict, Any, Optional
from src.common.logger_manager import get_logger
# 导入所有API模块
from src.plugin_system.apis.message_api import MessageAPI
from src.plugin_system.apis.llm_api import LLMAPI
from src.plugin_system.apis.database_api import DatabaseAPI
from src.plugin_system.apis.config_api import ConfigAPI
from src.plugin_system.apis.utils_api import UtilsAPI
from src.plugin_system.apis.stream_api import StreamAPI
from src.plugin_system.apis.hearflow_api import HearflowAPI
logger = get_logger("plugin_api")
class PluginAPI(MessageAPI, LLMAPI, DatabaseAPI, ConfigAPI, UtilsAPI, StreamAPI, HearflowAPI):
"""
插件API聚合类
集成了所有可供插件使用的API功能提供统一的访问接口。
插件组件可以直接使用此API实例来访问各种功能。
特性:
- 聚合所有API模块的功能
- 支持依赖注入和配置
- 提供统一的错误处理和日志记录
"""
def __init__(self,
chat_stream=None,
expressor=None,
replyer=None,
observations=None,
log_prefix: str = "[PluginAPI]"):
"""
初始化插件API
Args:
chat_stream: 聊天流对象
expressor: 表达器对象
replyer: 回复器对象
observations: 观察列表
log_prefix: 日志前缀
"""
# 存储依赖对象
self._services = {
"chat_stream": chat_stream,
"expressor": expressor,
"replyer": replyer,
"observations": observations or []
}
self.log_prefix = log_prefix
# 调用所有父类的初始化
super().__init__()
logger.debug(f"{self.log_prefix} PluginAPI 初始化完成")
def set_chat_stream(self, chat_stream):
"""设置聊天流对象"""
self._services["chat_stream"] = chat_stream
logger.debug(f"{self.log_prefix} 设置聊天流: {getattr(chat_stream, 'stream_id', 'Unknown')}")
def set_expressor(self, expressor):
"""设置表达器对象"""
self._services["expressor"] = expressor
logger.debug(f"{self.log_prefix} 设置表达器")
def set_replyer(self, replyer):
"""设置回复器对象"""
self._services["replyer"] = replyer
logger.debug(f"{self.log_prefix} 设置回复器")
def set_observations(self, observations):
"""设置观察列表"""
self._services["observations"] = observations or []
logger.debug(f"{self.log_prefix} 设置观察列表,数量: {len(observations or [])}")
def get_service(self, service_name: str):
"""获取指定的服务对象"""
return self._services.get(service_name)
def has_service(self, service_name: str) -> bool:
"""检查是否有指定的服务对象"""
return service_name in self._services and self._services[service_name] is not None
# 便捷的工厂函数
def create_plugin_api(chat_stream=None,
expressor=None,
replyer=None,
observations=None,
log_prefix: str = "[Plugin]") -> PluginAPI:
"""
创建插件API实例的便捷函数
Args:
chat_stream: 聊天流对象
expressor: 表达器对象
replyer: 回复器对象
observations: 观察列表
log_prefix: 日志前缀
Returns:
PluginAPI: 配置好的插件API实例
"""
return PluginAPI(
chat_stream=chat_stream,
expressor=expressor,
replyer=replyer,
observations=observations,
log_prefix=log_prefix
)
def create_command_api(message, log_prefix: str = "[Command]") -> PluginAPI:
"""
为命令创建插件API实例的便捷函数
Args:
message: 消息对象,应该包含 chat_stream 等信息
log_prefix: 日志前缀
Returns:
PluginAPI: 配置好的插件API实例
"""
chat_stream = getattr(message, 'chat_stream', None)
api = PluginAPI(
chat_stream=chat_stream,
log_prefix=log_prefix
)
return api
# 导出主要接口
__all__ = [
'PluginAPI',
'create_plugin_api',
'create_command_api',
# 也可以导出各个API类供单独使用
'MessageAPI',
'LLMAPI',
'DatabaseAPI',
'ConfigAPI',
'UtilsAPI',
'StreamAPI',
'HearflowAPI'
]

View File

@@ -0,0 +1,159 @@
from typing import Optional, List, Dict, Any
from src.common.logger_manager import get_logger
from src.chat.message_receive.chat_stream import ChatManager, ChatStream
logger = get_logger("stream_api")
class StreamAPI:
"""聊天流API模块
提供了获取聊天流、通过群ID查找聊天流等功能
"""
def get_chat_stream_by_group_id(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""通过QQ群ID获取聊天流
Args:
group_id: QQ群ID
platform: 平台标识,默认为"qq"
Returns:
Optional[ChatStream]: 找到的聊天流对象如果未找到则返回None
"""
try:
chat_manager = ChatManager()
# 遍历所有已加载的聊天流查找匹配的群ID
for stream_id, stream in chat_manager.streams.items():
if (
stream.group_info
and str(stream.group_info.group_id) == str(group_id)
and stream.platform == platform
):
logger.info(f"{self.log_prefix} 通过群ID {group_id} 找到聊天流: {stream_id}")
return stream
logger.warning(f"{self.log_prefix} 未找到群ID为 {group_id} 的聊天流")
return None
except Exception as e:
logger.error(f"{self.log_prefix} 通过群ID获取聊天流时出错: {e}")
return None
def get_all_group_chat_streams(self, platform: str = "qq") -> List[ChatStream]:
"""获取所有群聊的聊天流
Args:
platform: 平台标识,默认为"qq"
Returns:
List[ChatStream]: 所有群聊的聊天流列表
"""
try:
chat_manager = ChatManager()
group_streams = []
for stream in chat_manager.streams.values():
if stream.group_info and stream.platform == platform:
group_streams.append(stream)
logger.info(f"{self.log_prefix} 找到 {len(group_streams)} 个群聊聊天流")
return group_streams
except Exception as e:
logger.error(f"{self.log_prefix} 获取所有群聊聊天流时出错: {e}")
return []
def get_chat_stream_by_user_id(self, user_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""通过用户ID获取私聊聊天流
Args:
user_id: 用户ID
platform: 平台标识,默认为"qq"
Returns:
Optional[ChatStream]: 找到的私聊聊天流对象如果未找到则返回None
"""
try:
chat_manager = ChatManager()
# 遍历所有已加载的聊天流查找匹配的用户ID私聊
for stream_id, stream in chat_manager.streams.items():
if (
not stream.group_info # 私聊没有群信息
and stream.user_info
and str(stream.user_info.user_id) == str(user_id)
and stream.platform == platform
):
logger.info(f"{self.log_prefix} 通过用户ID {user_id} 找到私聊聊天流: {stream_id}")
return stream
logger.warning(f"{self.log_prefix} 未找到用户ID为 {user_id} 的私聊聊天流")
return None
except Exception as e:
logger.error(f"{self.log_prefix} 通过用户ID获取私聊聊天流时出错: {e}")
return None
def get_chat_streams_info(self) -> List[Dict[str, Any]]:
"""获取所有聊天流的基本信息
Returns:
List[Dict[str, Any]]: 包含聊天流基本信息的字典列表
"""
try:
chat_manager = ChatManager()
streams_info = []
for stream_id, stream in chat_manager.streams.items():
info = {
"stream_id": stream_id,
"platform": stream.platform,
"chat_type": "group" if stream.group_info else "private",
"create_time": stream.create_time,
"last_active_time": stream.last_active_time,
}
if stream.group_info:
info.update({"group_id": stream.group_info.group_id, "group_name": stream.group_info.group_name})
if stream.user_info:
info.update({"user_id": stream.user_info.user_id, "user_nickname": stream.user_info.user_nickname})
streams_info.append(info)
logger.info(f"{self.log_prefix} 获取到 {len(streams_info)} 个聊天流信息")
return streams_info
except Exception as e:
logger.error(f"{self.log_prefix} 获取聊天流信息时出错: {e}")
return []
async def get_chat_stream_by_group_id_async(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""异步通过QQ群ID获取聊天流包括从数据库搜索
Args:
group_id: QQ群ID
platform: 平台标识,默认为"qq"
Returns:
Optional[ChatStream]: 找到的聊天流对象如果未找到则返回None
"""
try:
# 首先尝试从内存中查找
stream = self.get_chat_stream_by_group_id(group_id, platform)
if stream:
return stream
# 如果内存中没有,尝试从数据库加载所有聊天流后再查找
chat_manager = ChatManager()
await chat_manager.load_all_streams()
# 再次尝试从内存中查找
stream = self.get_chat_stream_by_group_id(group_id, platform)
return stream
except Exception as e:
logger.error(f"{self.log_prefix} 异步通过群ID获取聊天流时出错: {e}")
return None

View File

@@ -0,0 +1,126 @@
import os
import json
import time
from typing import Any, Optional
from src.common.logger_manager import get_logger
logger = get_logger("utils_api")
class UtilsAPI:
"""工具类API模块
提供了各种辅助功能
"""
def get_plugin_path(self) -> str:
"""获取当前插件的路径
Returns:
str: 插件目录的绝对路径
"""
import inspect
plugin_module_path = inspect.getfile(self.__class__)
plugin_dir = os.path.dirname(plugin_module_path)
return plugin_dir
def read_json_file(self, file_path: str, default: Any = None) -> Any:
"""读取JSON文件
Args:
file_path: 文件路径,可以是相对于插件目录的路径
default: 如果文件不存在或读取失败时返回的默认值
Returns:
Any: JSON数据或默认值
"""
try:
# 如果是相对路径,则相对于插件目录
if not os.path.isabs(file_path):
file_path = os.path.join(self.get_plugin_path(), file_path)
if not os.path.exists(file_path):
logger.warning(f"{self.log_prefix} 文件不存在: {file_path}")
return default
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error(f"{self.log_prefix} 读取JSON文件出错: {e}")
return default
def write_json_file(self, file_path: str, data: Any, indent: int = 2) -> bool:
"""写入JSON文件
Args:
file_path: 文件路径,可以是相对于插件目录的路径
data: 要写入的数据
indent: JSON缩进
Returns:
bool: 是否写入成功
"""
try:
# 如果是相对路径,则相对于插件目录
if not os.path.isabs(file_path):
file_path = os.path.join(self.get_plugin_path(), file_path)
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=indent)
return True
except Exception as e:
logger.error(f"{self.log_prefix} 写入JSON文件出错: {e}")
return False
def get_timestamp(self) -> int:
"""获取当前时间戳
Returns:
int: 当前时间戳(秒)
"""
return int(time.time())
def format_time(self, timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
"""格式化时间
Args:
timestamp: 时间戳如果为None则使用当前时间
format_str: 时间格式字符串
Returns:
str: 格式化后的时间字符串
"""
import datetime
if timestamp is None:
timestamp = time.time()
return datetime.datetime.fromtimestamp(timestamp).strftime(format_str)
def parse_time(self, time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int:
"""解析时间字符串为时间戳
Args:
time_str: 时间字符串
format_str: 时间格式字符串
Returns:
int: 时间戳(秒)
"""
import datetime
dt = datetime.datetime.strptime(time_str, format_str)
return int(dt.timestamp())
def generate_unique_id(self) -> str:
"""生成唯一ID
Returns:
str: 唯一ID
"""
import uuid
return str(uuid.uuid4())

View File

@@ -0,0 +1,27 @@
"""
插件基础类模块
提供插件开发的基础类和类型定义
"""
from src.plugin_system.base.base_plugin import BasePlugin, register_plugin
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.component_types import (
ComponentType, ActionActivationType, ChatMode,
ComponentInfo, ActionInfo, CommandInfo, PluginInfo
)
__all__ = [
'BasePlugin',
'BaseAction',
'BaseCommand',
'register_plugin',
'ComponentType',
'ActionActivationType',
'ChatMode',
'ComponentInfo',
'ActionInfo',
'CommandInfo',
'PluginInfo',
]

View File

@@ -0,0 +1,120 @@
from abc import ABC, abstractmethod
from typing import Tuple, Dict, Any, Optional
from src.common.logger_manager import get_logger
from src.plugin_system.apis.plugin_api import PluginAPI
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
logger = get_logger("base_action")
class BaseAction(ABC):
"""Action组件基类
Action是插件的一种组件类型用于处理聊天中的动作逻辑
子类可以通过类属性定义激活条件:
- focus_activation_type: 专注模式激活类型
- normal_activation_type: 普通模式激活类型
- activation_keywords: 激活关键词列表
- keyword_case_sensitive: 关键词是否区分大小写
- mode_enable: 启用的聊天模式
- parallel_action: 是否允许并行执行
- random_activation_probability: 随机激活概率
- llm_judge_prompt: LLM判断提示词
"""
# 默认激活设置(子类可以覆盖)
focus_activation_type: ActionActivationType = ActionActivationType.NEVER
normal_activation_type: ActionActivationType = ActionActivationType.NEVER
activation_keywords: list = []
keyword_case_sensitive: bool = False
mode_enable: ChatMode = ChatMode.ALL
parallel_action: bool = True
random_activation_probability: float = 0.0
llm_judge_prompt: str = ""
def __init__(self,
action_data: dict,
reasoning: str,
cycle_timers: dict,
thinking_id: str,
**kwargs):
"""初始化Action组件
Args:
action_data: 动作数据
reasoning: 执行该动作的理由
cycle_timers: 计时器字典
thinking_id: 思考ID
**kwargs: 其他参数(包含服务对象)
"""
self.action_data = action_data
self.reasoning = reasoning
self.cycle_timers = cycle_timers
self.thinking_id = thinking_id
# 创建API实例
self.api = PluginAPI(
chat_stream=kwargs.get("chat_stream"),
expressor=kwargs.get("expressor"),
replyer=kwargs.get("replyer"),
observations=kwargs.get("observations"),
log_prefix=kwargs.get("log_prefix", "")
)
self.log_prefix = kwargs.get("log_prefix", "")
logger.debug(f"{self.log_prefix} Action组件初始化完成")
async def send_reply(self, content: str) -> bool:
"""发送回复消息
Args:
content: 回复内容
Returns:
bool: 是否发送成功
"""
return await self.api.send_message("text", content)
@classmethod
def get_action_info(cls, name: str = None, description: str = None) -> 'ActionInfo':
"""从类属性生成ActionInfo
Args:
name: Action名称如果不提供则使用类名
description: Action描述如果不提供则使用类文档字符串
Returns:
ActionInfo: 生成的Action信息对象
"""
# 自动生成名称和描述
if name is None:
name = cls.__name__.lower().replace('action', '')
if description is None:
description = cls.__doc__ or f"{cls.__name__} Action组件"
description = description.strip().split('\n')[0] # 取第一行作为描述
return ActionInfo(
name=name,
component_type=ComponentType.ACTION,
description=description,
focus_activation_type=cls.focus_activation_type,
normal_activation_type=cls.normal_activation_type,
activation_keywords=cls.activation_keywords.copy() if cls.activation_keywords else [],
keyword_case_sensitive=cls.keyword_case_sensitive,
mode_enable=cls.mode_enable,
parallel_action=cls.parallel_action,
random_activation_probability=cls.random_activation_probability,
llm_judge_prompt=cls.llm_judge_prompt
)
@abstractmethod
async def execute(self) -> Tuple[bool, str]:
"""执行Action的抽象方法子类必须实现
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
pass

View File

@@ -0,0 +1,113 @@
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional, List
from src.common.logger_manager import get_logger
from src.plugin_system.apis.plugin_api import PluginAPI
from src.plugin_system.base.component_types import CommandInfo, ComponentType
from src.chat.message_receive.message import MessageRecv
logger = get_logger("base_command")
class BaseCommand(ABC):
"""Command组件基类
Command是插件的一种组件类型用于处理命令请求
子类可以通过类属性定义命令模式:
- command_pattern: 命令匹配的正则表达式
- command_help: 命令帮助信息
- command_examples: 命令使用示例列表
"""
# 默认命令设置(子类可以覆盖)
command_pattern: str = ""
command_help: str = ""
command_examples: List[str] = []
def __init__(self, message: MessageRecv):
"""初始化Command组件
Args:
message: 接收到的消息对象
"""
self.message = message
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
# 创建API实例
self.api = PluginAPI(
chat_stream=message.chat_stream,
log_prefix=f"[Command]"
)
self.log_prefix = f"[Command]"
logger.debug(f"{self.log_prefix} Command组件初始化完成")
def set_matched_groups(self, groups: Dict[str, str]) -> None:
"""设置正则表达式匹配的命名组
Args:
groups: 正则表达式匹配的命名组
"""
self.matched_groups = groups
@abstractmethod
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行Command的抽象方法子类必须实现
Returns:
Tuple[bool, Optional[str]]: (是否执行成功, 可选的回复消息)
"""
pass
async def send_reply(self, content: str) -> None:
"""发送回复消息
Args:
content: 回复内容
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if chat_stream.group_info:
# 群聊
await self.api.send_text_to_group(
text=content,
group_id=str(chat_stream.group_info.group_id),
platform=chat_stream.platform
)
else:
# 私聊
await self.api.send_text_to_user(
text=content,
user_id=str(chat_stream.user_info.user_id),
platform=chat_stream.platform
)
@classmethod
def get_command_info(cls, name: str = None, description: str = None) -> 'CommandInfo':
"""从类属性生成CommandInfo
Args:
name: Command名称如果不提供则使用类名
description: Command描述如果不提供则使用类文档字符串
Returns:
CommandInfo: 生成的Command信息对象
"""
# 自动生成名称和描述
if name is None:
name = cls.__name__.lower().replace('command', '')
if description is None:
description = cls.__doc__ or f"{cls.__name__} Command组件"
description = description.strip().split('\n')[0] # 取第一行作为描述
return CommandInfo(
name=name,
component_type=ComponentType.COMMAND,
description=description,
command_pattern=cls.command_pattern,
command_help=cls.command_help,
command_examples=cls.command_examples.copy() if cls.command_examples else []
)

View File

@@ -0,0 +1,226 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Type, Optional, Any
import os
import inspect
import toml
from src.common.logger_manager import get_logger
from src.plugin_system.base.component_types import (
PluginInfo, ComponentInfo, ActionInfo, CommandInfo,
ComponentType, ActionActivationType, ChatMode
)
from src.plugin_system.core.component_registry import component_registry
logger = get_logger("base_plugin")
# 全局插件类注册表
_plugin_classes: Dict[str, Type['BasePlugin']] = {}
class BasePlugin(ABC):
"""插件基类
所有插件都应该继承这个基类,一个插件可以包含多种组件:
- Action组件处理聊天中的动作
- Command组件处理命令请求
- 未来可扩展Scheduler、Listener等
"""
# 插件基本信息(子类必须定义)
plugin_name: str = "" # 插件名称
plugin_description: str = "" # 插件描述
plugin_version: str = "1.0.0" # 插件版本
plugin_author: str = "" # 插件作者
enable_plugin: bool = True # 是否启用插件
dependencies: List[str] = [] # 依赖的其他插件
config_file_name: Optional[str] = None # 配置文件名
def __init__(self, plugin_dir: str = None):
"""初始化插件
Args:
plugin_dir: 插件目录路径,由插件管理器传递
"""
self.config: Dict[str, Any] = {} # 插件配置
self.plugin_dir = plugin_dir # 插件目录路径
self.log_prefix = f"[Plugin:{self.plugin_name}]"
# 验证插件信息
self._validate_plugin_info()
# 加载插件配置
self._load_plugin_config()
# 创建插件信息对象
self.plugin_info = PluginInfo(
name=self.plugin_name,
description=self.plugin_description,
version=self.plugin_version,
author=self.plugin_author,
enabled=self.enable_plugin,
is_built_in=False,
config_file=self.config_file_name or "",
dependencies=self.dependencies.copy()
)
logger.debug(f"{self.log_prefix} 插件基类初始化完成")
def _validate_plugin_info(self):
"""验证插件基本信息"""
if not self.plugin_name:
raise ValueError(f"插件类 {self.__class__.__name__} 必须定义 plugin_name")
if not self.plugin_description:
raise ValueError(f"插件 {self.plugin_name} 必须定义 plugin_description")
def _load_plugin_config(self):
"""加载插件配置文件"""
if not self.config_file_name:
logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载")
return
# 优先使用传入的插件目录路径
if self.plugin_dir:
plugin_dir = self.plugin_dir
else:
# fallback尝试从类的模块信息获取路径
try:
plugin_module_path = inspect.getfile(self.__class__)
plugin_dir = os.path.dirname(plugin_module_path)
except (TypeError, OSError):
# 最后的fallback从模块的__file__属性获取
module = inspect.getmodule(self.__class__)
if module and hasattr(module, '__file__') and module.__file__:
plugin_dir = os.path.dirname(module.__file__)
else:
logger.warning(f"{self.log_prefix} 无法获取插件目录路径,跳过配置加载")
return
config_file_path = os.path.join(plugin_dir, self.config_file_name)
if not os.path.exists(config_file_path):
logger.warning(f"{self.log_prefix} 配置文件 {config_file_path} 不存在")
return
file_ext = os.path.splitext(self.config_file_name)[1].lower()
if file_ext == ".toml":
with open(config_file_path, "r", encoding="utf-8") as f:
self.config = toml.load(f) or {}
logger.info(f"{self.log_prefix} 配置已从 {config_file_path} 加载")
else:
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
self.config = {}
@abstractmethod
def get_plugin_components(self) -> List[tuple[ComponentInfo, Type]]:
"""获取插件包含的组件列表
子类必须实现此方法,返回组件信息和组件类的列表
Returns:
List[tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...]
"""
pass
def register_plugin(self) -> bool:
"""注册插件及其所有组件"""
if not self.enable_plugin:
logger.info(f"{self.log_prefix} 插件已禁用,跳过注册")
return False
components = self.get_plugin_components()
# 检查依赖
if not self._check_dependencies():
logger.error(f"{self.log_prefix} 依赖检查失败,跳过注册")
return False
# 注册所有组件
registered_components = []
for component_info, component_class in components:
component_info.plugin_name = self.plugin_name
if component_registry.register_component(component_info, component_class):
registered_components.append(component_info)
else:
logger.warning(f"{self.log_prefix} 组件 {component_info.name} 注册失败")
# 更新插件信息中的组件列表
self.plugin_info.components = registered_components
# 注册插件
if component_registry.register_plugin(self.plugin_info):
logger.info(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个组件")
return True
else:
logger.error(f"{self.log_prefix} 插件注册失败")
return False
def _check_dependencies(self) -> bool:
"""检查插件依赖"""
if not self.dependencies:
return True
for dep in self.dependencies:
if not component_registry.get_plugin_info(dep):
logger.error(f"{self.log_prefix} 缺少依赖插件: {dep}")
return False
return True
def get_config(self, key: str, default: Any = None) -> Any:
"""获取插件配置值
Args:
key: 配置键名
default: 默认值
Returns:
Any: 配置值或默认值
"""
return self.config.get(key, default)
def register_plugin(cls):
"""插件注册装饰器
用法:
@register_plugin
class MyPlugin(BasePlugin):
plugin_name = "my_plugin"
plugin_description = "我的插件"
...
"""
if not issubclass(cls, BasePlugin):
logger.error(f"{cls.__name__} 不是 BasePlugin 的子类")
return cls
# 只是注册插件类,不立即实例化
# 插件管理器会负责实例化和注册
plugin_name = cls.plugin_name or cls.__name__
_plugin_classes[plugin_name] = cls
logger.debug(f"插件类已注册: {plugin_name}")
return cls
def get_registered_plugin_classes() -> Dict[str, Type['BasePlugin']]:
"""获取所有已注册的插件类"""
return _plugin_classes.copy()
def instantiate_and_register_plugin(plugin_class: Type['BasePlugin'], plugin_dir: str = None) -> bool:
"""实例化并注册插件
Args:
plugin_class: 插件类
plugin_dir: 插件目录路径
Returns:
bool: 是否成功
"""
try:
plugin_instance = plugin_class(plugin_dir=plugin_dir)
return plugin_instance.register_plugin()
except Exception as e:
logger.error(f"注册插件 {plugin_class.__name__} 时出错: {e}")
import traceback
logger.error(traceback.format_exc())
return False

View File

@@ -0,0 +1,104 @@
from enum import Enum
from typing import Dict, Any, List
from dataclasses import dataclass
# 组件类型枚举
class ComponentType(Enum):
"""组件类型枚举"""
ACTION = "action" # 动作组件
COMMAND = "command" # 命令组件
SCHEDULER = "scheduler" # 定时任务组件(预留)
LISTENER = "listener" # 事件监听组件(预留)
# 动作激活类型枚举
class ActionActivationType(Enum):
"""动作激活类型枚举"""
NEVER = "never" # 从不激活(默认关闭)
ALWAYS = "always" # 默认参与到planner
LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner
RANDOM = "random" # 随机启用action到planner
KEYWORD = "keyword" # 关键词触发启用action到planner
# 聊天模式枚举
class ChatMode(Enum):
"""聊天模式枚举"""
FOCUS = "focus" # Focus聊天模式
NORMAL = "normal" # Normal聊天模式
ALL = "all" # 所有聊天模式
@dataclass
class ComponentInfo:
"""组件信息"""
name: str # 组件名称
component_type: ComponentType # 组件类型
description: str # 组件描述
enabled: bool = True # 是否启用
plugin_name: str = "" # 所属插件名称
is_built_in: bool = False # 是否为内置组件
metadata: Dict[str, Any] = None # 额外元数据
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
@dataclass
class ActionInfo(ComponentInfo):
"""动作组件信息"""
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS
random_activation_probability: float = 0.3
llm_judge_prompt: str = ""
activation_keywords: List[str] = None
keyword_case_sensitive: bool = False
mode_enable: ChatMode = ChatMode.ALL
parallel_action: bool = False
action_parameters: Dict[str, Any] = None
action_require: List[str] = None
associated_types: List[str] = None
def __post_init__(self):
super().__post_init__()
if self.activation_keywords is None:
self.activation_keywords = []
if self.action_parameters is None:
self.action_parameters = {}
if self.action_require is None:
self.action_require = []
if self.associated_types is None:
self.associated_types = []
self.component_type = ComponentType.ACTION
@dataclass
class CommandInfo(ComponentInfo):
"""命令组件信息"""
command_pattern: str = "" # 命令匹配模式(正则表达式)
command_help: str = "" # 命令帮助信息
command_examples: List[str] = None # 命令使用示例
def __post_init__(self):
super().__post_init__()
if self.command_examples is None:
self.command_examples = []
self.component_type = ComponentType.COMMAND
@dataclass
class PluginInfo:
"""插件信息"""
name: str # 插件名称
description: str # 插件描述
version: str = "1.0.0" # 插件版本
author: str = "" # 插件作者
enabled: bool = True # 是否启用
is_built_in: bool = False # 是否为内置插件
components: List[ComponentInfo] = None # 包含的组件列表
dependencies: List[str] = None # 依赖的其他插件
config_file: str = "" # 配置文件路径
metadata: Dict[str, Any] = None # 额外元数据
def __post_init__(self):
if self.components is None:
self.components = []
if self.dependencies is None:
self.dependencies = []
if self.metadata is None:
self.metadata = {}

View File

@@ -0,0 +1,13 @@
"""
插件核心管理模块
提供插件的加载、注册和管理功能
"""
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
__all__ = [
'plugin_manager',
'component_registry',
]

View File

@@ -0,0 +1,245 @@
from typing import Dict, List, Type, Optional, Any, Pattern
from abc import ABC
import re
from src.common.logger_manager import get_logger
from src.plugin_system.base.component_types import (
ComponentInfo, ActionInfo, CommandInfo, PluginInfo,
ComponentType, ActionActivationType, ChatMode
)
logger = get_logger("component_registry")
class ComponentRegistry:
"""统一的组件注册中心
负责管理所有插件组件的注册、查询和生命周期管理
"""
def __init__(self):
# 组件注册表
self._components: Dict[str, ComponentInfo] = {} # 组件名 -> 组件信息
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {
ComponentType.ACTION: {},
ComponentType.COMMAND: {},
}
self._component_classes: Dict[str, Type] = {} # 组件名 -> 组件类
# 插件注册表
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
# Action特定注册表
self._action_registry: Dict[str, Type] = {} # action名 -> action类
self._default_actions: Dict[str, str] = {} # 启用的action名 -> 描述
# Command特定注册表
self._command_registry: Dict[str, Type] = {} # command名 -> command类
self._command_patterns: Dict[Pattern, Type] = {} # 编译后的正则 -> command类
logger.info("组件注册中心初始化完成")
# === 通用组件注册方法 ===
def register_component(self, component_info: ComponentInfo, component_class: Type) -> bool:
"""注册组件
Args:
component_info: 组件信息
component_class: 组件类
Returns:
bool: 是否注册成功
"""
component_name = component_info.name
component_type = component_info.component_type
if component_name in self._components:
logger.warning(f"组件 {component_name} 已存在,跳过注册")
return False
# 注册到通用注册表
self._components[component_name] = component_info
self._components_by_type[component_type][component_name] = component_info
self._component_classes[component_name] = component_class
# 根据组件类型进行特定注册
if component_type == ComponentType.ACTION:
self._register_action_component(component_info, component_class)
elif component_type == ComponentType.COMMAND:
self._register_command_component(component_info, component_class)
logger.info(f"已注册{component_type.value}组件: {component_name} ({component_class.__name__})")
return True
def _register_action_component(self, action_info: ActionInfo, action_class: Type):
"""注册Action组件到Action特定注册表"""
action_name = action_info.name
self._action_registry[action_name] = action_class
# 如果启用,添加到默认动作集
if action_info.enabled:
self._default_actions[action_name] = action_info.description
def _register_command_component(self, command_info: CommandInfo, command_class: Type):
"""注册Command组件到Command特定注册表"""
command_name = command_info.name
self._command_registry[command_name] = command_class
# 编译正则表达式并注册
if command_info.command_pattern:
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
self._command_patterns[pattern] = command_class
# === 组件查询方法 ===
def get_component_info(self, component_name: str) -> Optional[ComponentInfo]:
"""获取组件信息"""
return self._components.get(component_name)
def get_component_class(self, component_name: str) -> Optional[Type]:
"""获取组件类"""
return self._component_classes.get(component_name)
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取指定类型的所有组件"""
return self._components_by_type.get(component_type, {}).copy()
def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取指定类型的所有启用组件"""
components = self.get_components_by_type(component_type)
return {name: info for name, info in components.items() if info.enabled}
# === Action特定查询方法 ===
def get_action_registry(self) -> Dict[str, Type]:
"""获取Action注册表用于兼容现有系统"""
return self._action_registry.copy()
def get_default_actions(self) -> Dict[str, str]:
"""获取默认启用的Action列表用于兼容现有系统"""
return self._default_actions.copy()
def get_action_info(self, action_name: str) -> Optional[ActionInfo]:
"""获取Action信息"""
info = self.get_component_info(action_name)
return info if isinstance(info, ActionInfo) else None
# === Command特定查询方法 ===
def get_command_registry(self) -> Dict[str, Type]:
"""获取Command注册表用于兼容现有系统"""
return self._command_registry.copy()
def get_command_patterns(self) -> Dict[Pattern, Type]:
"""获取Command模式注册表用于兼容现有系统"""
return self._command_patterns.copy()
def get_command_info(self, command_name: str) -> Optional[CommandInfo]:
"""获取Command信息"""
info = self.get_component_info(command_name)
return info if isinstance(info, CommandInfo) else None
def find_command_by_text(self, text: str) -> Optional[tuple[Type, dict]]:
"""根据文本查找匹配的命令
Args:
text: 输入文本
Returns:
Optional[tuple[Type, dict]]: (命令类, 匹配的命名组) 或 None
"""
for pattern, command_class in self._command_patterns.items():
match = pattern.match(text)
if match:
command_name = None
# 查找对应的组件信息
for name, cls in self._command_registry.items():
if cls == command_class:
command_name = name
break
# 检查命令是否启用
if command_name:
command_info = self.get_command_info(command_name)
if command_info and command_info.enabled:
return command_class, match.groupdict()
return None
# === 插件管理方法 ===
def register_plugin(self, plugin_info: PluginInfo) -> bool:
"""注册插件
Args:
plugin_info: 插件信息
Returns:
bool: 是否注册成功
"""
plugin_name = plugin_info.name
if plugin_name in self._plugins:
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
return False
self._plugins[plugin_name] = plugin_info
logger.info(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
return True
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
"""获取插件信息"""
return self._plugins.get(plugin_name)
def get_all_plugins(self) -> Dict[str, PluginInfo]:
"""获取所有插件"""
return self._plugins.copy()
def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
"""获取所有启用的插件"""
return {name: info for name, info in self._plugins.items() if info.enabled}
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
"""获取插件的所有组件"""
plugin_info = self.get_plugin_info(plugin_name)
return plugin_info.components if plugin_info else []
# === 状态管理方法 ===
def enable_component(self, component_name: str) -> bool:
"""启用组件"""
if component_name in self._components:
self._components[component_name].enabled = True
# 如果是Action更新默认动作集
component_info = self._components[component_name]
if isinstance(component_info, ActionInfo):
self._default_actions[component_name] = component_info.description
logger.info(f"已启用组件: {component_name}")
return True
return False
def disable_component(self, component_name: str) -> bool:
"""禁用组件"""
if component_name in self._components:
self._components[component_name].enabled = False
# 如果是Action从默认动作集中移除
if component_name in self._default_actions:
del self._default_actions[component_name]
logger.info(f"已禁用组件: {component_name}")
return True
return False
def get_registry_stats(self) -> Dict[str, Any]:
"""获取注册中心统计信息"""
return {
"total_components": len(self._components),
"total_plugins": len(self._plugins),
"components_by_type": {
component_type.value: len(components)
for component_type, components in self._components_by_type.items()
},
"enabled_components": len([c for c in self._components.values() if c.enabled]),
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
}
# 全局组件注册中心实例
component_registry = ComponentRegistry()

View File

@@ -0,0 +1,223 @@
from typing import Dict, List, Optional, Any
import os
import importlib
import importlib.util
from pathlib import Path
from src.common.logger_manager import get_logger
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import PluginInfo, ComponentType
logger = get_logger("plugin_manager")
class PluginManager:
"""插件管理器
负责加载、初始化和管理所有插件及其组件
"""
def __init__(self):
self.plugin_directories: List[str] = []
self.loaded_plugins: Dict[str, Any] = {}
self.failed_plugins: Dict[str, str] = {}
logger.info("插件管理器初始化完成")
def add_plugin_directory(self, directory: str):
"""添加插件目录"""
if os.path.exists(directory):
self.plugin_directories.append(directory)
logger.info(f"已添加插件目录: {directory}")
else:
logger.warning(f"插件目录不存在: {directory}")
def load_all_plugins(self) -> tuple[int, int]:
"""加载所有插件目录中的插件
Returns:
tuple[int, int]: (插件数量, 组件数量)
"""
logger.info("开始加载所有插件...")
# 第一阶段:加载所有插件模块(注册插件类)
total_loaded_modules = 0
total_failed_modules = 0
for directory in self.plugin_directories:
loaded, failed = self._load_plugin_modules_from_directory(directory)
total_loaded_modules += loaded
total_failed_modules += failed
logger.info(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
# 第二阶段:实例化所有已注册的插件类
from src.plugin_system.base.base_plugin import get_registered_plugin_classes, instantiate_and_register_plugin
plugin_classes = get_registered_plugin_classes()
total_registered = 0
total_failed_registration = 0
for plugin_name, plugin_class in plugin_classes.items():
# 尝试找到插件对应的目录
plugin_dir = self._find_plugin_directory(plugin_class)
if instantiate_and_register_plugin(plugin_class, plugin_dir):
total_registered += 1
self.loaded_plugins[plugin_name] = plugin_class
else:
total_failed_registration += 1
self.failed_plugins[plugin_name] = "插件注册失败"
logger.info(f"插件注册完成 - 成功: {total_registered}, 失败: {total_failed_registration}")
# 获取组件统计信息
stats = component_registry.get_registry_stats()
logger.info(f"组件注册统计: {stats}")
# 返回插件数量和组件数量
return total_registered, stats.get('total_components', 0)
def _find_plugin_directory(self, plugin_class) -> Optional[str]:
"""查找插件类对应的目录路径"""
try:
import inspect
module = inspect.getmodule(plugin_class)
if module and hasattr(module, '__file__') and module.__file__:
return os.path.dirname(module.__file__)
except Exception:
pass
return None
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
"""从指定目录加载插件模块"""
loaded_count = 0
failed_count = 0
if not os.path.exists(directory):
logger.warning(f"插件目录不存在: {directory}")
return loaded_count, failed_count
logger.info(f"正在扫描插件目录: {directory}")
# 遍历目录中的所有Python文件和包
for item in os.listdir(directory):
item_path = os.path.join(directory, item)
if os.path.isfile(item_path) and item.endswith('.py') and item != '__init__.py':
# 单文件插件
if self._load_plugin_module_file(item_path):
loaded_count += 1
else:
failed_count += 1
elif os.path.isdir(item_path) and not item.startswith('.') and not item.startswith('__'):
# 插件包
plugin_file = os.path.join(item_path, 'plugin.py')
if os.path.exists(plugin_file):
if self._load_plugin_module_file(plugin_file):
loaded_count += 1
else:
failed_count += 1
return loaded_count, failed_count
def _load_plugin_module_file(self, plugin_file: str) -> bool:
"""加载单个插件模块文件"""
plugin_name = None
# 生成模块名
plugin_path = Path(plugin_file)
if plugin_path.parent.name != 'plugins':
# 插件包格式parent_dir.plugin
module_name = f"plugins.{plugin_path.parent.name}.plugin"
else:
# 单文件格式plugins.filename
module_name = f"plugins.{plugin_path.stem}"
try:
# 动态导入插件模块
spec = importlib.util.spec_from_file_location(module_name, plugin_file)
if spec is None or spec.loader is None:
logger.error(f"无法创建模块规范: {plugin_file}")
return False
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# 模块加载成功,插件类会自动通过装饰器注册
plugin_name = plugin_path.parent.name if plugin_path.parent.name != 'plugins' else plugin_path.stem
logger.debug(f"插件模块加载成功: {plugin_file}")
return True
except Exception as e:
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
logger.error(error_msg)
if plugin_name:
self.failed_plugins[plugin_name] = error_msg
return False
def get_loaded_plugins(self) -> List[PluginInfo]:
"""获取所有已加载的插件信息"""
return list(component_registry.get_all_plugins().values())
def get_enabled_plugins(self) -> List[PluginInfo]:
"""获取所有启用的插件信息"""
return list(component_registry.get_enabled_plugins().values())
def enable_plugin(self, plugin_name: str) -> bool:
"""启用插件"""
plugin_info = component_registry.get_plugin_info(plugin_name)
if plugin_info:
plugin_info.enabled = True
# 启用插件的所有组件
for component in plugin_info.components:
component_registry.enable_component(component.name)
logger.info(f"已启用插件: {plugin_name}")
return True
return False
def disable_plugin(self, plugin_name: str) -> bool:
"""禁用插件"""
plugin_info = component_registry.get_plugin_info(plugin_name)
if plugin_info:
plugin_info.enabled = False
# 禁用插件的所有组件
for component in plugin_info.components:
component_registry.disable_component(component.name)
logger.info(f"已禁用插件: {plugin_name}")
return True
return False
def get_plugin_stats(self) -> Dict[str, Any]:
"""获取插件统计信息"""
all_plugins = component_registry.get_all_plugins()
enabled_plugins = component_registry.get_enabled_plugins()
action_components = component_registry.get_components_by_type(ComponentType.ACTION)
command_components = component_registry.get_components_by_type(ComponentType.COMMAND)
return {
"total_plugins": len(all_plugins),
"enabled_plugins": len(enabled_plugins),
"failed_plugins": len(self.failed_plugins),
"total_components": len(action_components) + len(command_components),
"action_components": len(action_components),
"command_components": len(command_components),
"loaded_plugin_files": len(self.loaded_plugins),
"failed_plugin_details": self.failed_plugins.copy()
}
def reload_plugin(self, plugin_name: str) -> bool:
"""重新加载插件(高级功能,需要谨慎使用)"""
# TODO: 实现插件热重载功能
logger.warning("插件热重载功能尚未实现")
return False
# 全局插件管理器实例
plugin_manager = PluginManager()
# 默认插件目录
plugin_manager.add_plugin_directory("src/plugins/built_in")
plugin_manager.add_plugin_directory("src/plugins/examples")
plugin_manager.add_plugin_directory("plugins") # 用户插件目录