refc:重构插件api,补全文档,合并expressor和replyer,分离reply和sender,新log浏览器

This commit is contained in:
SengokuCola
2025-06-19 20:20:34 +08:00
parent 7e05ede846
commit ab28b94e33
63 changed files with 5285 additions and 8316 deletions

View File

@@ -8,6 +8,7 @@ MaiBot 插件系统
from src.plugin_system.base.base_plugin import BasePlugin, register_plugin
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.config_types import ConfigField
from src.plugin_system.base.component_types import (
ComponentType,
ActionActivationType,
@@ -18,11 +19,11 @@ from src.plugin_system.base.component_types import (
PluginInfo,
PythonDependency,
)
from src.plugin_system.apis.plugin_api import PluginAPI, create_plugin_api, create_command_api
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.dependency_manager import dependency_manager
__version__ = "1.0.0"
__all__ = [
@@ -39,14 +40,11 @@ __all__ = [
"CommandInfo",
"PluginInfo",
"PythonDependency",
# API接口
"PluginAPI",
"create_plugin_api",
"create_command_api",
# 管理器
"plugin_manager",
"component_registry",
"dependency_manager",
# 装饰器
"register_plugin",
"ConfigField",
]

View File

@@ -1,37 +1,33 @@
"""
插件API模块
插件系统API模块
提供插件可以使用的各种API接口
提供插件开发所需的各种API
"""
from src.plugin_system.apis.plugin_api import PluginAPI, create_plugin_api, create_command_api
from src.plugin_system.apis.message_api import MessageAPI
from src.plugin_system.apis.llm_api import LLMAPI
from src.plugin_system.apis.database_api import DatabaseAPI
from src.plugin_system.apis.config_api import ConfigAPI
from src.plugin_system.apis.utils_api import UtilsAPI
from src.plugin_system.apis.stream_api import StreamAPI
from src.plugin_system.apis.hearflow_api import HearflowAPI
# 新增分类的API聚合
from src.plugin_system.apis.action_apis import ActionAPI
from src.plugin_system.apis.independent_apis import IndependentAPI, StaticAPI
# 导入所有API模块
from src.plugin_system.apis import (
chat_api,
config_api,
database_api,
emoji_api,
generator_api,
llm_api,
message_api,
person_api,
send_api,
utils_api
)
# 导出所有API模块使它们可以通过 apis.xxx 方式访问
__all__ = [
# 原有统一API
"PluginAPI",
"create_plugin_api",
"create_command_api",
# 原有单独API
"MessageAPI",
"LLMAPI",
"DatabaseAPI",
"ConfigAPI",
"UtilsAPI",
"StreamAPI",
"HearflowAPI",
# 新增分类API
"ActionAPI", # 需要Action依赖的API
"IndependentAPI", # 独立API
"StaticAPI", # 静态API
"chat_api",
"config_api",
"database_api",
"emoji_api",
"generator_api",
"llm_api",
"message_api",
"person_api",
"send_api",
"utils_api"
]

View File

@@ -1,88 +0,0 @@
"""
Action相关API聚合模块
聚合了需要Action组件依赖的API这些API需要通过Action初始化时注入的服务对象才能正常工作。
包括MessageAPI、DatabaseAPI等需要chat_stream、expressor等服务的API。
"""
from src.plugin_system.apis.message_api import MessageAPI
from src.plugin_system.apis.database_api import DatabaseAPI
from src.common.logger import get_logger
logger = get_logger("action_apis")
class ActionAPI(MessageAPI, DatabaseAPI):
"""
Action相关API聚合类
聚合了需要Action组件依赖的API功能。这些API需要以下依赖
- _services: 包含chat_stream、expressor、replyer、observations等服务对象
- log_prefix: 日志前缀
- thinking_id: 思考ID
- cycle_timers: 计时器
- action_data: Action数据
使用场景:
- 在Action组件中使用需要发送消息、存储数据等功能
- 需要访问聊天上下文和执行环境的操作
"""
def __init__(
self,
chat_stream=None,
expressor=None,
replyer=None,
observations=None,
log_prefix: str = "[ActionAPI]",
thinking_id: str = "",
cycle_timers: dict = None,
action_data: dict = None,
):
"""
初始化Action相关API
Args:
chat_stream: 聊天流对象
expressor: 表达器对象
replyer: 回复器对象
observations: 观察列表
log_prefix: 日志前缀
thinking_id: 思考ID
cycle_timers: 计时器字典
action_data: Action数据
"""
# 存储依赖对象
self._services = {
"chat_stream": chat_stream,
"expressor": expressor,
"replyer": replyer,
"observations": observations or [],
}
self.log_prefix = log_prefix
self.thinking_id = thinking_id
self.cycle_timers = cycle_timers or {}
self.action_data = action_data or {}
logger.debug(f"{self.log_prefix} ActionAPI 初始化完成")
def set_chat_stream(self, chat_stream):
"""设置聊天流对象"""
self._services["chat_stream"] = chat_stream
logger.debug(f"{self.log_prefix} 设置聊天流")
def set_expressor(self, expressor):
"""设置表达器对象"""
self._services["expressor"] = expressor
logger.debug(f"{self.log_prefix} 设置表达器")
def set_replyer(self, replyer):
"""设置回复器对象"""
self._services["replyer"] = replyer
logger.debug(f"{self.log_prefix} 设置回复器")
def set_observations(self, observations):
"""设置观察列表"""
self._services["observations"] = observations or []
logger.debug(f"{self.log_prefix} 设置观察列表")

View File

@@ -0,0 +1,292 @@
"""
聊天API模块
专门负责聊天信息的查询和管理采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import chat_api
streams = chat_api.get_all_group_streams()
chat_type = chat_api.get_stream_type(stream)
或者:
from src.plugin_system.apis.chat_api import ChatManager as chat
streams = chat.get_all_group_streams()
"""
from typing import List, Dict, Any, Optional
from src.common.logger import get_logger
# 导入依赖
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.focus_chat.info.obs_info import ObsInfo
logger = get_logger("chat_api")
class ChatManager:
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: str = "qq") -> List[ChatStream]:
"""获取所有聊天流
Args:
platform: 平台筛选,默认为"qq"
Returns:
List[ChatStream]: 聊天流列表
"""
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if stream.platform == platform:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的聊天流")
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流失败: {e}")
return streams
@staticmethod
def get_group_streams(platform: str = "qq") -> List[ChatStream]:
"""获取所有群聊聊天流
Args:
platform: 平台筛选,默认为"qq"
Returns:
List[ChatStream]: 群聊聊天流列表
"""
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if stream.platform == platform and stream.group_info:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
logger.error(f"[ChatAPI] 获取群聊流失败: {e}")
return streams
@staticmethod
def get_private_streams(platform: str = "qq") -> List[ChatStream]:
"""获取所有私聊聊天流
Args:
platform: 平台筛选,默认为"qq"
Returns:
List[ChatStream]: 私聊聊天流列表
"""
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if stream.platform == platform and not stream.group_info:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
logger.error(f"[ChatAPI] 获取私聊流失败: {e}")
return streams
@staticmethod
def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""根据群ID获取聊天流
Args:
group_id: 群聊ID
platform: 平台,默认为"qq"
Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None
"""
try:
for _, stream in get_chat_manager().streams.items():
if (
stream.group_info
and str(stream.group_info.group_id) == str(group_id)
and stream.platform == platform
):
logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流")
return stream
logger.warning(f"[ChatAPI] 未找到群ID {group_id} 的聊天流")
except Exception as e:
logger.error(f"[ChatAPI] 查找群聊流失败: {e}")
return None
@staticmethod
def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""根据用户ID获取私聊流
Args:
user_id: 用户ID
platform: 平台,默认为"qq"
Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None
"""
try:
for _, stream in get_chat_manager().streams.items():
if (
not stream.group_info
and str(stream.user_info.user_id) == str(user_id)
and stream.platform == platform
):
logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流")
return stream
logger.warning(f"[ChatAPI] 未找到用户ID {user_id} 的私聊流")
except Exception as e:
logger.error(f"[ChatAPI] 查找私聊流失败: {e}")
return None
@staticmethod
def get_stream_type(chat_stream: ChatStream) -> str:
"""获取聊天流类型
Args:
chat_stream: 聊天流对象
Returns:
str: 聊天类型 ("group", "private", "unknown")
"""
if not chat_stream:
return "unknown"
if hasattr(chat_stream, "group_info"):
return "group" if chat_stream.group_info else "private"
return "unknown"
@staticmethod
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
"""获取聊天流详细信息
Args:
chat_stream: 聊天流对象
Returns:
Dict[str, Any]: 聊天流信息字典
"""
if not chat_stream:
return {}
try:
info = {
"stream_id": chat_stream.stream_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
}
if chat_stream.group_info:
info.update({
"group_id": chat_stream.group_info.group_id,
"group_name": getattr(chat_stream.group_info, "group_name", "未知群聊"),
})
if chat_stream.user_info:
info.update({
"user_id": chat_stream.user_info.user_id,
"user_name": chat_stream.user_info.user_nickname,
})
return info
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流信息失败: {e}")
return {}
@staticmethod
def get_recent_messages_from_obs(observations: List[Any], count: int = 5) -> List[Dict[str, Any]]:
"""从观察对象获取最近的消息
Args:
observations: 观察对象列表
count: 要获取的消息数量
Returns:
List[Dict]: 消息列表,每个消息包含发送者、内容等信息
"""
messages = []
try:
if observations and len(observations) > 0:
obs = observations[0]
if hasattr(obs, "get_talking_message"):
obs: ObsInfo
raw_messages = obs.get_talking_message()
# 转换为简化格式
for msg in raw_messages[-count:]:
simple_msg = {
"sender": msg.get("sender", "未知"),
"content": msg.get("content", ""),
"timestamp": msg.get("timestamp", 0),
}
messages.append(simple_msg)
logger.debug(f"[ChatAPI] 获取到 {len(messages)} 条最近消息")
except Exception as e:
logger.error(f"[ChatAPI] 获取最近消息失败: {e}")
return messages
@staticmethod
def get_streams_summary() -> Dict[str, int]:
"""获取聊天流统计摘要
Returns:
Dict[str, int]: 包含各种统计信息的字典
"""
try:
all_streams = ChatManager.get_all_streams()
group_streams = ChatManager.get_group_streams()
private_streams = ChatManager.get_private_streams()
summary = {
"total_streams": len(all_streams),
"group_streams": len(group_streams),
"private_streams": len(private_streams),
"qq_streams": len([s for s in all_streams if s.platform == "qq"]),
}
logger.debug(f"[ChatAPI] 聊天流统计: {summary}")
return summary
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}")
return {"total_streams": 0, "group_streams": 0, "private_streams": 0, "qq_streams": 0}
# =============================================================================
# 模块级别的便捷函数 - 类似 requests.get(), requests.post() 的设计
# =============================================================================
def get_all_streams(platform: str = "qq") -> List[ChatStream]:
"""获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform)
def get_group_streams(platform: str = "qq") -> List[ChatStream]:
"""获取群聊聊天流的便捷函数"""
return ChatManager.get_group_streams(platform)
def get_private_streams(platform: str = "qq") -> List[ChatStream]:
"""获取私聊聊天流的便捷函数"""
return ChatManager.get_private_streams(platform)
def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""根据群ID获取聊天流的便捷函数"""
return ChatManager.get_stream_by_group_id(group_id, platform)
def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""根据用户ID获取私聊流的便捷函数"""
return ChatManager.get_stream_by_user_id(user_id, platform)
def get_stream_type(chat_stream: ChatStream) -> str:
"""获取聊天流类型的便捷函数"""
return ChatManager.get_stream_type(chat_stream)
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
"""获取聊天流信息的便捷函数"""
return ChatManager.get_stream_info(chat_stream)
def get_streams_summary() -> Dict[str, int]:
"""获取聊天流统计摘要的便捷函数"""
return ChatManager.get_streams_summary()

View File

@@ -1,3 +1,12 @@
"""配置API模块
提供了配置读取和用户信息获取等功能
使用方式:
from src.plugin_system.apis import config_api
value = config_api.get_global_config("section.key")
platform, user_id = await config_api.get_user_id_by_person_name("用户名")
"""
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config
@@ -6,92 +15,104 @@ from src.person_info.person_info import get_person_info_manager
logger = get_logger("config_api")
class ConfigAPI:
"""配置API模块
# =============================================================================
# 配置访问API函数
# =============================================================================
提供了配置读取和用户信息获取等功能
def get_global_config(key: str, default: Any = None) -> Any:
"""
安全地从全局配置中获取一个值。
插件应使用此方法读取全局配置,以保证只读和隔离性。
def get_global_config(self, key: str, default: Any = None) -> Any:
"""
安全地从全局配置中获取一个值。
插件应使用此方法读取全局配置,以保证只读和隔离性。
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 如果配置不存在时返回的默认值
Returns:
Any: 配置值或默认值
"""
# 支持嵌套键访问
keys = key.split(".")
current = global_config
try:
for k in keys:
if hasattr(current, k):
current = getattr(current, k)
else:
return default
return current
except Exception as e:
logger.warning(f"获取全局配置 {key} 失败: {e}")
return default
def get_config(self, key: str, default: Any = None) -> Any:
"""
从插件配置中获取值,支持嵌套键访问
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 如果配置不存在时返回的默认值
Returns:
Any: 配置值或默认值
"""
# 获取插件配置
plugin_config = getattr(self, "_plugin_config", {})
if not plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = plugin_config
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 如果配置不存在时返回的默认值
Returns:
Any: 配置值或默认值
"""
# 支持嵌套键访问
keys = key.split(".")
current = global_config
try:
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
if hasattr(current, k):
current = getattr(current, k)
else:
return default
return current
except Exception as e:
logger.warning(f"[ConfigAPI] 获取全局配置 {key} 失败: {e}")
return default
async def get_user_id_by_person_name(self, person_name: str) -> tuple[str, str]:
"""根据用户名获取用户ID
Args:
person_name: 用户名
def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any:
"""
从插件配置中获取值,支持嵌套键访问
Returns:
tuple[str, str]: (平台, 用户ID)
"""
Args:
plugin_config: 插件配置字典
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 如果配置不存在时返回的默认值
Returns:
Any: 配置值或默认值
"""
if not plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
# =============================================================================
# 用户信息API函数
# =============================================================================
async def get_user_id_by_person_name(person_name: str) -> tuple[str, str]:
"""根据用户名获取用户ID
Args:
person_name: 用户名
Returns:
tuple[str, str]: (平台, 用户ID)
"""
try:
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(person_name)
user_id = await person_info_manager.get_value(person_id, "user_id")
platform = await person_info_manager.get_value(person_id, "platform")
return platform, user_id
except Exception as e:
logger.error(f"[ConfigAPI] 根据用户名获取用户ID失败: {e}")
return "", ""
async def get_person_info(self, person_id: str, key: str, default: Any = None) -> Any:
"""获取用户信息
Args:
person_id: 用户ID
key: 信息键名
default: 默认值
async def get_person_info(person_id: str, key: str, default: Any = None) -> Any:
"""获取用户信息
Returns:
Any: 用户信息值或默认值
"""
Args:
person_id: 用户ID
key: 信息键名
default: 默认值
Returns:
Any: 用户信息值或默认值
"""
try:
person_info_manager = get_person_info_manager()
return await person_info_manager.get_value(person_id, key, default)
except Exception as e:
logger.error(f"[ConfigAPI] 获取用户信息失败: {e}")
return default

View File

@@ -1,352 +1,97 @@
"""数据库API模块
提供数据库操作相关功能采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import database_api
records = await database_api.db_query(ActionRecords, query_type="get")
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
"""
import traceback
import time
from typing import Dict, List, Any, Union, Type
from src.common.logger import get_logger
from src.common.database.database_model import ActionRecords
from src.common.database.database import db
from peewee import Model, DoesNotExist
logger = get_logger("database_api")
# =============================================================================
# 通用数据库查询API函数
# =============================================================================
class DatabaseAPI:
"""数据库API模块
async def db_query(
model_class: Type[Model],
query_type: str = "get",
filters: Dict[str, Any] = None,
data: Dict[str, Any] = None,
limit: int = None,
order_by: List[str] = None,
single_result: bool = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
提供了数据库操作相关的功能
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
Args:
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
filters: 过滤条件字典,键为字段名,值为要匹配的值
data: 用于创建或更新的数据字典
limit: 限制结果数量
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回不同的结果:
- "get": 返回查询结果列表或单个结果(如果 single_result=True
- "create": 返回创建的记录
- "update": 返回受影响的行数
- "delete": 返回受影响的行数
- "count": 返回记录数量
示例:
# 查询最近10条消息
messages = await database_api.db_query(
Messages,
query_type="get",
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by=["-time"]
)
# 创建一条记录
new_record = await database_api.db_query(
ActionRecords,
query_type="create",
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
)
# 更新记录
updated_count = await database_api.db_query(
ActionRecords,
query_type="update",
filters={"action_id": "123"},
data={"action_done": True}
)
# 删除记录
deleted_count = await database_api.db_query(
ActionRecords,
query_type="delete",
filters={"action_id": "123"}
)
# 计数
count = await database_api.db_query(
Messages,
query_type="count",
filters={"chat_id": chat_stream.stream_id}
)
"""
async def store_action_info(
self,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: dict = None,
) -> None:
"""存储action信息到数据库
Args:
action_build_into_prompt: 是否构建到提示中
action_prompt_display: 显示的action提示信息
action_done: action是否完成
thinking_id: 思考ID
action_data: action数据如果不提供则使用空字典
"""
try:
chat_stream = self.get_service("chat_stream")
if not chat_stream:
logger.error(f"{self.log_prefix} 无法存储action信息缺少chat_stream服务")
return
action_time = time.time()
action_id = f"{action_time}_{thinking_id}"
ActionRecords.create(
action_id=action_id,
time=action_time,
action_name=self.__class__.__name__,
action_data=str(action_data or {}),
action_done=action_done,
action_build_into_prompt=action_build_into_prompt,
action_prompt_display=action_prompt_display,
chat_id=chat_stream.stream_id,
chat_info_stream_id=chat_stream.stream_id,
chat_info_platform=chat_stream.platform,
user_id=chat_stream.user_info.user_id if chat_stream.user_info else "",
user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "",
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "",
)
logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}")
except Exception as e:
logger.error(f"{self.log_prefix} 存储action信息时出错: {e}")
traceback.print_exc()
async def db_query(
self,
model_class: Type[Model],
query_type: str = "get",
filters: Dict[str, Any] = None,
data: Dict[str, Any] = None,
limit: int = None,
order_by: List[str] = None,
single_result: bool = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
Args:
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
filters: 过滤条件字典,键为字段名,值为要匹配的值
data: 用于创建或更新的数据字典
limit: 限制结果数量
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回不同的结果:
- "get": 返回查询结果列表或单个结果(如果 single_result=True
- "create": 返回创建的记录
- "update": 返回受影响的行数
- "delete": 返回受影响的行数
- "count": 返回记录数量
示例:
# 查询最近10条消息
messages = await self.db_query(
Messages,
query_type="get",
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by=["-time"]
)
# 创建一条记录
new_record = await self.db_query(
ActionRecords,
query_type="create",
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
)
# 更新记录
updated_count = await self.db_query(
ActionRecords,
query_type="update",
filters={"action_id": "123"},
data={"action_done": True}
)
# 删除记录
deleted_count = await self.db_query(
ActionRecords,
query_type="delete",
filters={"action_id": "123"}
)
# 计数
count = await self.db_query(
Messages,
query_type="count",
filters={"chat_id": chat_stream.stream_id}
)
"""
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
# 构建基本查询
if query_type in ["get", "update", "delete", "count"]:
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 执行查询
if query_type == "get":
# 应用排序
if order_by:
for field in order_by:
if field.startswith("-"):
query = query.order_by(getattr(model_class, field[1:]).desc())
else:
query = query.order_by(getattr(model_class, field))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if single_result:
return results[0] if results else None
return results
elif query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建记录
record = model_class.create(**data)
# 返回创建的记录
return model_class.select().where(model_class.id == record.id).dicts().get()
elif query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
# 更新记录
return query.update(**data).execute()
elif query_type == "delete":
# 删除记录
return query.delete().execute()
elif query_type == "count":
# 计数
return query.count()
else:
raise ValueError(f"不支持的查询类型: {query_type}")
except DoesNotExist:
# 记录不存在
if query_type == "get" and single_result:
return None
return []
except Exception as e:
logger.error(f"{self.log_prefix} 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None
async def db_raw_query(
self, sql: str, params: List[Any] = None, fetch_results: bool = True
) -> Union[List[Dict[str, Any]], int, None]:
"""执行原始SQL查询
警告: 使用此方法需要小心确保SQL语句已正确构造以避免SQL注入风险。
Args:
sql: 原始SQL查询字符串
params: 查询参数列表用于替换SQL中的占位符
fetch_results: 是否获取查询结果对于SELECT查询设为True对于
UPDATE/INSERT/DELETE等操作设为False
Returns:
如果fetch_results为True返回查询结果列表
如果fetch_results为False返回受影响的行数
如果出错返回None
"""
try:
cursor = db.execute_sql(sql, params or [])
if fetch_results:
# 获取列名
columns = [col[0] for col in cursor.description]
# 构建结果字典列表
results = []
for row in cursor.fetchall():
results.append(dict(zip(columns, row)))
return results
else:
# 返回受影响的行数
return cursor.rowcount
except Exception as e:
logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}")
traceback.print_exc()
return None
async def db_save(
self, model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
) -> Union[Dict[str, Any], None]:
"""保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
如果没有找到匹配记录或未提供key_field和key_value则创建新记录。
Args:
model_class: Peewee模型类如ActionRecords, Messages等
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名,例如"action_id"
key_value: 用于查找现有记录的字段值
Returns:
Dict[str, Any]: 保存后的记录数据
None: 如果操作失败
示例:
# 创建或更新一条记录
record = await self.db_save(
ActionRecords,
{
"action_id": "123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
},
key_field="action_id",
key_value="123"
)
"""
try:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
# 查找现有记录
existing_records = list(
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
)
if existing_records:
# 更新现有记录
existing_record = existing_records[0]
for field, value in data.items():
setattr(existing_record, field, value)
existing_record.save()
# 返回更新后的记录
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
return updated_record
# 如果没有找到现有记录或未提供key_field和key_value创建新记录
new_record = model_class.create(**data)
# 返回创建的记录
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
return created_record
except Exception as e:
logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}")
traceback.print_exc()
return None
async def db_get(
self, model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
这是db_query方法的简化版本专注于数据检索操作。
Args:
model_class: Peewee模型类
filters: 过滤条件,字段名和值的字典
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
limit: 结果数量限制如果为1则返回单个记录而不是列表
Returns:
如果limit=1返回单个记录字典或None
否则返回记录字典列表或空列表。
示例:
# 获取单个记录
record = await self.db_get(
ActionRecords,
filters={"action_id": "123"},
limit=1
)
# 获取最近10条记录
records = await self.db_get(
Messages,
filters={"chat_id": chat_stream.stream_id},
order_by="-time",
limit=10
)
"""
try:
# 构建查询
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
# 构建基本查询
if query_type in ["get", "update", "delete", "count"]:
query = model_class.select()
# 应用过滤条件
@@ -354,12 +99,15 @@ class DatabaseAPI:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 执行查询
if query_type == "get":
# 应用排序
if order_by:
if order_by.startswith("-"):
query = query.order_by(getattr(model_class, order_by[1:]).desc())
else:
query = query.order_by(getattr(model_class, order_by))
for field in order_by:
if field.startswith("-"):
query = query.order_by(getattr(model_class, field[1:]).desc())
else:
query = query.order_by(getattr(model_class, field))
# 应用限制
if limit:
@@ -369,11 +117,270 @@ class DatabaseAPI:
results = list(query.dicts())
# 返回结果
if limit == 1:
if single_result:
return results[0] if results else None
return results
except Exception as e:
logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}")
traceback.print_exc()
return None if limit == 1 else []
elif query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建记录
record = model_class.create(**data)
# 返回创建的记录
return model_class.select().where(model_class.id == record.id).dicts().get()
elif query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
# 更新记录
return query.update(**data).execute()
elif query_type == "delete":
# 删除记录
return query.delete().execute()
elif query_type == "count":
# 计数
return query.count()
else:
raise ValueError(f"不支持的查询类型: {query_type}")
except DoesNotExist:
# 记录不存在
if query_type == "get" and single_result:
return None
return []
except Exception as e:
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None
async def db_save(
model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
) -> Union[Dict[str, Any], None]:
"""保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
如果没有找到匹配记录或未提供key_field和key_value则创建新记录。
Args:
model_class: Peewee模型类如ActionRecords, Messages等
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名,例如"action_id"
key_value: 用于查找现有记录的字段值
Returns:
Dict[str, Any]: 保存后的记录数据
None: 如果操作失败
示例:
# 创建或更新一条记录
record = await database_api.db_save(
ActionRecords,
{
"action_id": "123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
},
key_field="action_id",
key_value="123"
)
"""
try:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
# 查找现有记录
existing_records = list(
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
)
if existing_records:
# 更新现有记录
existing_record = existing_records[0]
for field, value in data.items():
setattr(existing_record, field, value)
existing_record.save()
# 返回更新后的记录
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
return updated_record
# 如果没有找到现有记录或未提供key_field和key_value创建新记录
new_record = model_class.create(**data)
# 返回创建的记录
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
return created_record
except Exception as e:
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}")
traceback.print_exc()
return None
async def db_get(
model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
这是db_query方法的简化版本专注于数据检索操作。
Args:
model_class: Peewee模型类
filters: 过滤条件,字段名和值的字典
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
limit: 结果数量限制如果为1则返回单个记录而不是列表
Returns:
如果limit=1返回单个记录字典或None
否则返回记录字典列表或空列表。
示例:
# 获取单个记录
record = await database_api.db_get(
ActionRecords,
filters={"action_id": "123"},
limit=1
)
# 获取最近10条记录
records = await database_api.db_get(
Messages,
filters={"chat_id": chat_stream.stream_id},
order_by="-time",
limit=10
)
"""
try:
# 构建查询
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 应用排序
if order_by:
if order_by.startswith("-"):
query = query.order_by(getattr(model_class, order_by[1:]).desc())
else:
query = query.order_by(getattr(model_class, order_by))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if limit == 1:
return results[0] if results else None
return results
except Exception as e:
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
traceback.print_exc()
return None if limit == 1 else []
async def store_action_info(
chat_stream=None,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: dict = None,
action_name: str = "",
) -> Union[Dict[str, Any], None]:
"""存储动作信息到数据库
将Action执行的相关信息保存到ActionRecords表中用于后续的记忆和上下文构建。
Args:
chat_stream: 聊天流对象,包含聊天相关信息
action_build_into_prompt: 是否将此动作构建到提示中
action_prompt_display: 动作的提示显示文本
action_done: 动作是否完成
thinking_id: 关联的思考ID
action_data: 动作数据字典
action_name: 动作名称
Returns:
Dict[str, Any]: 保存的记录数据
None: 如果保存失败
示例:
record = await database_api.store_action_info(
chat_stream=chat_stream,
action_build_into_prompt=True,
action_prompt_display="执行了回复动作",
action_done=True,
thinking_id="thinking_123",
action_data={"content": "Hello"},
action_name="reply_action"
)
"""
try:
import time
import json
from src.common.database.database_model import ActionRecords
# 构建动作记录数据
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)), # 使用thinking_id或生成唯一ID
"time": time.time(),
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
}
# 从chat_stream获取聊天信息
if chat_stream:
record_data.update({
"chat_id": getattr(chat_stream, 'stream_id', ''),
"chat_info_stream_id": getattr(chat_stream, 'stream_id', ''),
"chat_info_platform": getattr(chat_stream, 'platform', ''),
})
else:
# 如果没有chat_stream设置默认值
record_data.update({
"chat_id": "",
"chat_info_stream_id": "",
"chat_info_platform": "",
})
# 使用已有的db_save函数保存记录
saved_record = await db_save(
ActionRecords,
data=record_data,
key_field="action_id",
key_value=record_data["action_id"]
)
if saved_record:
logger.info(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else:
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}")
return saved_record
except Exception as e:
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}")
traceback.print_exc()
return None

View File

@@ -0,0 +1,219 @@
"""
表情API模块
提供表情包相关功能采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import emoji_api
result = await emoji_api.get_by_description("开心")
count = emoji_api.get_count()
"""
from typing import Optional, Tuple
from src.common.logger import get_logger
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.utils.utils_image import image_path_to_base64
logger = get_logger("emoji_api")
# =============================================================================
# 表情包获取API函数
# =============================================================================
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
"""根据描述选择表情包
Args:
description: 表情包的描述文本,例如"开心""难过""愤怒"
Returns:
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
"""
try:
logger.info(f"[EmojiAPI] 根据描述获取表情包: {description}")
emoji_manager = get_emoji_manager()
emoji_result = await emoji_manager.get_emoji_for_text(description)
if not emoji_result:
logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包")
return None
emoji_path, emoji_description, matched_emotion = emoji_result
emoji_base64 = image_path_to_base64(emoji_path)
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}")
return None
logger.info(f"[EmojiAPI] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
return emoji_base64, emoji_description, matched_emotion
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包失败: {e}")
return None
async def get_random() -> Optional[Tuple[str, str, str]]:
"""随机获取表情包
Returns:
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 随机情感标签) 或 None
"""
try:
logger.info("[EmojiAPI] 随机获取表情包")
emoji_manager = get_emoji_manager()
all_emojis = emoji_manager.emoji_objects
if not all_emojis:
logger.warning("[EmojiAPI] 没有可用的表情包")
return None
# 过滤有效表情包
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
if not valid_emojis:
logger.warning("[EmojiAPI] 没有有效的表情包")
return None
# 随机选择
import random
selected_emoji = random.choice(valid_emojis)
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
return None
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
# 记录使用次数
emoji_manager.record_usage(selected_emoji.hash)
logger.info(f"[EmojiAPI] 成功获取随机表情包: {selected_emoji.description}")
return emoji_base64, selected_emoji.description, matched_emotion
except Exception as e:
logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}")
return None
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
"""根据情感标签获取表情包
Args:
emotion: 情感标签,如"happy""sad""angry"
Returns:
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
"""
try:
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
emoji_manager = get_emoji_manager()
all_emojis = emoji_manager.emoji_objects
# 筛选匹配情感的表情包
matching_emojis = []
for emoji_obj in all_emojis:
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]:
matching_emojis.append(emoji_obj)
if not matching_emojis:
logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包")
return None
# 随机选择匹配的表情包
import random
selected_emoji = random.choice(matching_emojis)
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
return None
# 记录使用次数
emoji_manager.record_usage(selected_emoji.hash)
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
return emoji_base64, selected_emoji.description, emotion
except Exception as e:
logger.error(f"[EmojiAPI] 根据情感获取表情包失败: {e}")
return None
# =============================================================================
# 表情包信息查询API函数
# =============================================================================
def get_count() -> int:
"""获取表情包数量
Returns:
int: 当前可用的表情包数量
"""
try:
emoji_manager = get_emoji_manager()
return emoji_manager.emoji_num
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}")
return 0
def get_info() -> dict:
"""获取表情包系统信息
Returns:
dict: 包含表情包数量、最大数量等信息
"""
try:
emoji_manager = get_emoji_manager()
return {
"current_count": emoji_manager.emoji_num,
"max_count": emoji_manager.emoji_num_max,
"available_emojis": len([e for e in emoji_manager.emoji_objects if not e.is_deleted]),
}
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}")
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
def get_emotions() -> list:
"""获取所有可用的情感标签
Returns:
list: 所有表情包的情感标签列表(去重)
"""
try:
emoji_manager = get_emoji_manager()
emotions = set()
for emoji_obj in emoji_manager.emoji_objects:
if not emoji_obj.is_deleted and emoji_obj.emotion:
emotions.update(emoji_obj.emotion)
return sorted(list(emotions))
except Exception as e:
logger.error(f"[EmojiAPI] 获取情感标签失败: {e}")
return []
def get_descriptions() -> list:
"""获取所有表情包描述
Returns:
list: 所有可用表情包的描述列表
"""
try:
emoji_manager = get_emoji_manager()
descriptions = []
for emoji_obj in emoji_manager.emoji_objects:
if not emoji_obj.is_deleted and emoji_obj.description:
descriptions.append(emoji_obj.description)
return descriptions
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")
return []

View File

@@ -0,0 +1,170 @@
"""
回复器API模块
提供回复器相关功能采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import generator_api
replyer = generator_api.get_replyer(chat_stream)
success, reply_set = await generator_api.generate_reply(chat_stream, action_data, reasoning)
"""
from typing import Tuple, Any, Dict, List
from src.common.logger import get_logger
from src.chat.focus_chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import get_chat_manager
logger = get_logger("generator_api")
# =============================================================================
# 回复器获取API函数
# =============================================================================
def get_replyer(chat_stream=None, platform: str = None, chat_id: str = None, is_group: bool = True) -> DefaultReplyer:
"""获取回复器对象
优先使用chat_stream如果没有则使用platform和chat_id组合
Args:
chat_stream: 聊天流对象(优先)
platform: 平台名称,如"qq"
chat_id: 聊天ID群ID或用户ID
is_group: 是否为群聊
Returns:
Optional[Any]: 回复器对象如果获取失败则返回None
"""
try:
# 优先使用聊天流
if chat_stream:
logger.debug("[GeneratorAPI] 使用聊天流获取回复器")
return DefaultReplyer(chat_stream=chat_stream)
# 使用平台和ID组合
if platform and chat_id:
logger.debug("[GeneratorAPI] 使用平台和ID获取回复器")
chat_manager = get_chat_manager()
if not chat_manager:
logger.warning("[GeneratorAPI] 无法获取聊天管理器")
return None
# 查找对应的聊天流
target_stream = None
for _stream_id, stream in chat_manager.streams.items():
if stream.platform == platform:
if is_group and stream.group_info:
if str(stream.group_info.group_id) == str(chat_id):
target_stream = stream
break
elif not is_group and stream.user_info:
if str(stream.user_info.user_id) == str(chat_id):
target_stream = stream
break
return DefaultReplyer(chat_stream=target_stream)
logger.warning("[GeneratorAPI] 缺少必要参数,无法获取回复器")
return None
except Exception as e:
logger.error(f"[GeneratorAPI] 获取回复器失败: {e}")
return None
# =============================================================================
# 回复生成API函数
# =============================================================================
async def generate_reply(
chat_stream=None,
action_data: Dict[str, Any] = None,
platform: str = None,
chat_id: str = None,
is_group: bool = True
) -> Tuple[bool, List[Tuple[str, Any]]]:
"""生成回复
Args:
chat_stream: 聊天流对象(优先)
action_data: 动作数据
reasoning: 推理原因
thinking_id: 思考ID
cycle_timers: 循环计时器
anchor_message: 锚点消息
platform: 平台名称(备用)
chat_id: 聊天ID备用
is_group: 是否为群聊(备用)
Returns:
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, platform, chat_id, is_group)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, []
logger.info("[GeneratorAPI] 开始生成回复")
# 调用回复器生成回复
success, reply_set = await replyer.generate_reply_with_context(
reply_data=action_data or {},
)
if success:
logger.info(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
else:
logger.warning("[GeneratorAPI] 回复生成失败")
return success, reply_set or []
except Exception as e:
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
return False, []
async def rewrite_reply(
chat_stream=None,
reply_data: Dict[str, Any] = None,
platform: str = None,
chat_id: str = None,
is_group: bool = True
) -> Tuple[bool, List[Tuple[str, Any]]]:
"""重写回复
Args:
chat_stream: 聊天流对象(优先)
action_data: 动作数据
platform: 平台名称(备用)
chat_id: 聊天ID备用
is_group: 是否为群聊(备用)
Returns:
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, platform, chat_id, is_group)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, []
logger.info("[GeneratorAPI] 开始重写回复")
# 调用回复器重写回复
success, reply_set = await replyer.rewrite_reply_with_context(
reply_data=reply_data or {},
)
if success:
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项")
else:
logger.warning("[GeneratorAPI] 重写回复失败")
return success, reply_set or []
except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
return False, []

View File

@@ -1,177 +0,0 @@
from typing import Optional, List, Any, Tuple
from src.common.logger import get_logger
logger = get_logger("hearflow_api")
def _get_heartflow():
"""获取heartflow实例的延迟导入函数"""
from src.chat.heart_flow.heartflow import heartflow
return heartflow
def _get_subheartflow_types():
"""获取SubHeartflow和ChatState类型的延迟导入函数"""
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
return SubHeartflow, ChatState
class HearflowAPI:
"""心流API模块
提供与心流和子心流相关的操作接口
"""
def __init__(self):
self.log_prefix = "[HearflowAPI]"
async def get_sub_hearflow_by_chat_id(self, chat_id: str) -> Optional[Any]:
"""根据chat_id获取指定的sub_hearflow实例
Args:
chat_id: 聊天ID与sub_hearflow的subheartflow_id相同
Returns:
Optional[SubHeartflow]: sub_hearflow实例如果不存在则返回None
"""
# 使用延迟导入
heartflow = _get_heartflow()
# 直接从subheartflow_manager获取已存在的子心流
# 使用锁来确保线程安全
async with heartflow.subheartflow_manager._lock:
subflow = heartflow.subheartflow_manager.subheartflows.get(chat_id)
if subflow and not subflow.should_stop:
logger.debug(f"{self.log_prefix} 成功获取子心流实例: {chat_id}")
return subflow
else:
logger.debug(f"{self.log_prefix} 子心流不存在或已停止: {chat_id}")
return None
async def get_or_create_sub_hearflow_by_chat_id(self, chat_id: str) -> Optional[Any]:
"""根据chat_id获取或创建sub_hearflow实例
Args:
chat_id: 聊天ID
Returns:
Optional[SubHeartflow]: sub_hearflow实例创建失败时返回None
"""
heartflow = _get_heartflow()
return await heartflow.get_or_create_subheartflow(chat_id)
def get_all_sub_hearflow_ids(self) -> List[str]:
"""获取所有子心流的ID列表
Returns:
List[str]: 所有子心流的ID列表
"""
heartflow = _get_heartflow()
all_subflows = heartflow.subheartflow_manager.get_all_subheartflows()
chat_ids = [subflow.chat_id for subflow in all_subflows if not subflow.should_stop]
logger.debug(f"{self.log_prefix} 获取到 {len(chat_ids)} 个活跃的子心流ID")
return chat_ids
def get_all_sub_hearflows(self) -> List[Any]:
"""获取所有子心流实例
Returns:
List[SubHeartflow]: 所有活跃的子心流实例列表
"""
heartflow = _get_heartflow()
all_subflows = heartflow.subheartflow_manager.get_all_subheartflows()
active_subflows = [subflow for subflow in all_subflows if not subflow.should_stop]
logger.debug(f"{self.log_prefix} 获取到 {len(active_subflows)} 个活跃的子心流实例")
return active_subflows
async def get_sub_hearflow_chat_state(self, chat_id: str) -> Optional[Any]:
"""获取指定子心流的聊天状态
Args:
chat_id: 聊天ID
Returns:
Optional[ChatState]: 聊天状态如果子心流不存在则返回None
"""
subflow = await self.get_sub_hearflow_by_chat_id(chat_id)
if subflow:
return subflow.chat_state.chat_status
return None
async def set_sub_hearflow_chat_state(self, chat_id: str, target_state: Any) -> bool:
"""设置指定子心流的聊天状态
Args:
chat_id: 聊天ID
target_state: 目标状态(ChatState枚举值)
Returns:
bool: 是否设置成功
"""
heartflow = _get_heartflow()
return await heartflow.subheartflow_manager.force_change_state(chat_id, target_state)
async def get_sub_hearflow_replyer_and_expressor(self, chat_id: str) -> Tuple[Optional[Any], Optional[Any]]:
"""根据chat_id获取指定子心流的replyer和expressor实例
Args:
chat_id: 聊天ID
Returns:
Tuple[Optional[Any], Optional[Any]]: (replyer实例, expressor实例)如果子心流不存在或未处于FOCUSED状态返回(None, None)
"""
subflow = await self.get_sub_hearflow_by_chat_id(chat_id)
if not subflow:
logger.debug(f"{self.log_prefix} 子心流不存在: {chat_id}")
return None, None
# 使用延迟导入获取ChatState
_, ChatState = _get_subheartflow_types()
# 检查子心流是否处于FOCUSED状态且有HeartFC实例
if subflow.chat_state.chat_status != ChatState.FOCUSED:
logger.debug(
f"{self.log_prefix} 子心流 {chat_id} 未处于FOCUSED状态当前状态: {subflow.chat_state.chat_status.value}"
)
return None, None
if not subflow.heart_fc_instance:
logger.debug(f"{self.log_prefix} 子心流 {chat_id} 没有HeartFC实例")
return None, None
# 返回replyer和expressor实例
replyer = subflow.heart_fc_instance.replyer
expressor = subflow.heart_fc_instance.expressor
if replyer and expressor:
logger.debug(f"{self.log_prefix} 成功获取子心流 {chat_id} 的replyer和expressor")
else:
logger.warning(f"{self.log_prefix} 子心流 {chat_id} 的replyer或expressor为空")
return replyer, expressor
async def get_sub_hearflow_replyer(self, chat_id: str) -> Optional[Any]:
"""根据chat_id获取指定子心流的replyer实例
Args:
chat_id: 聊天ID
Returns:
Optional[Any]: replyer实例如果不存在则返回None
"""
replyer, _ = await self.get_sub_hearflow_replyer_and_expressor(chat_id)
return replyer
async def get_sub_hearflow_expressor(self, chat_id: str) -> Optional[Any]:
"""根据chat_id获取指定子心流的expressor实例
Args:
chat_id: 聊天ID
Returns:
Optional[Any]: expressor实例如果不存在则返回None
"""
_, expressor = await self.get_sub_hearflow_replyer_and_expressor(chat_id)
return expressor

View File

@@ -1,134 +0,0 @@
"""
独立API聚合模块
聚合了不需要Action组件依赖的API这些API可以独立使用不需要注入服务对象。
包括LLMAPI、ConfigAPI、UtilsAPI、StreamAPI、HearflowAPI等独立功能的API。
"""
from src.plugin_system.apis.llm_api import LLMAPI
from src.plugin_system.apis.config_api import ConfigAPI
from src.plugin_system.apis.utils_api import UtilsAPI
from src.plugin_system.apis.stream_api import StreamAPI
from src.plugin_system.apis.hearflow_api import HearflowAPI
from src.common.logger import get_logger
logger = get_logger("independent_apis")
class IndependentAPI(LLMAPI, ConfigAPI, UtilsAPI, StreamAPI, HearflowAPI):
"""
独立API聚合类
聚合了不需要Action组件依赖的API功能。这些API的特点
- 不需要chat_stream、expressor等服务对象
- 可以独立调用不依赖Action执行上下文
- 主要是工具类方法和配置查询方法
包含的API
- LLMAPI: LLM模型调用仅需要全局配置
- ConfigAPI: 配置读取(使用全局配置)
- UtilsAPI: 工具方法(文件操作、时间处理等)
- StreamAPI: 聊天流查询使用ChatManager
- HearflowAPI: 心流状态控制使用heartflow
使用场景:
- 在Command组件中使用
- 独立的工具函数调用
- 配置查询和系统状态检查
"""
def __init__(self, log_prefix: str = "[IndependentAPI]"):
"""
初始化独立API
Args:
log_prefix: 日志前缀,用于区分不同的调用来源
"""
self.log_prefix = log_prefix
logger.debug(f"{self.log_prefix} IndependentAPI 初始化完成")
# 提供便捷的静态访问方式
class StaticAPI:
"""
静态API类
提供完全静态的API访问方式不需要实例化适合简单的工具调用。
"""
# LLM相关
@staticmethod
def get_available_models():
"""获取可用的LLM模型"""
api = LLMAPI()
return api.get_available_models()
@staticmethod
async def generate_with_model(prompt: str, model_config: dict, **kwargs):
"""使用LLM生成内容"""
api = LLMAPI()
api.log_prefix = "[StaticAPI]"
return await api.generate_with_model(prompt, model_config, **kwargs)
# 配置相关
@staticmethod
def get_global_config(key: str, default=None):
"""获取全局配置"""
api = ConfigAPI()
return api.get_global_config(key, default)
@staticmethod
async def get_user_id_by_name(person_name: str):
"""根据用户名获取用户ID"""
api = ConfigAPI()
return await api.get_user_id_by_person_name(person_name)
# 工具相关
@staticmethod
def get_timestamp():
"""获取当前时间戳"""
api = UtilsAPI()
return api.get_timestamp()
@staticmethod
def format_time(timestamp=None, format_str="%Y-%m-%d %H:%M:%S"):
"""格式化时间"""
api = UtilsAPI()
return api.format_time(timestamp, format_str)
@staticmethod
def generate_unique_id():
"""生成唯一ID"""
api = UtilsAPI()
return api.generate_unique_id()
# 聊天流相关
@staticmethod
def get_chat_stream_by_group_id(group_id: str, platform: str = "qq"):
"""通过群ID获取聊天流"""
api = StreamAPI()
api.log_prefix = "[StaticAPI]"
return api.get_chat_stream_by_group_id(group_id, platform)
@staticmethod
def get_all_group_chat_streams(platform: str = "qq"):
"""获取所有群聊聊天流"""
api = StreamAPI()
api.log_prefix = "[StaticAPI]"
return api.get_all_group_chat_streams(platform)
# 心流相关
@staticmethod
async def get_sub_hearflow_by_chat_id(chat_id: str):
"""获取子心流"""
api = HearflowAPI()
api.log_prefix = "[StaticAPI]"
return await api.get_sub_hearflow_by_chat_id(chat_id)
@staticmethod
async def set_sub_hearflow_chat_state(chat_id: str, target_state):
"""设置子心流状态"""
api = HearflowAPI()
api.log_prefix = "[StaticAPI]"
return await api.set_sub_hearflow_chat_state(chat_id, target_state)

View File

@@ -1,3 +1,12 @@
"""LLM API模块
提供了与LLM模型交互的功能
使用方式:
from src.plugin_system.apis import llm_api
models = llm_api.get_available_models()
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
"""
from typing import Tuple, Dict, Any
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
@@ -6,49 +15,51 @@ from src.config.config import global_config
logger = get_logger("llm_api")
class LLMAPI:
"""LLM API模块
# =============================================================================
# LLM模型API函数
# =============================================================================
提供了与LLM模型交互的功能
def get_available_models() -> Dict[str, Any]:
"""获取所有可用的模型配置
Returns:
Dict[str, Any]: 模型配置字典key为模型名称value为模型配置
"""
def get_available_models(self) -> Dict[str, Any]:
"""获取所有可用的模型配置
Returns:
Dict[str, Any]: 模型配置字典key为模型名称value为模型配置
"""
try:
if not hasattr(global_config, "model"):
logger.error(f"{self.log_prefix} 无法获取模型列表:全局配置中未找到 model 配置")
logger.error("[LLMAPI] 无法获取模型列表:全局配置中未找到 model 配置")
return {}
models = global_config.model
return models
except Exception as e:
logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
return {}
async def generate_with_model(
self, prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容
Args:
prompt: 提示词
model_config: 模型配置(从 get_available_models 获取的模型配置)
request_type: 请求类型标识
**kwargs: 其他模型特定参数如temperature、max_tokens等
async def generate_with_model(
prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容
Returns:
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
"""
try:
logger.info(f"{self.log_prefix} 使用模型生成内容,提示词: {prompt[:100]}...")
Args:
prompt: 提示词
model_config: 模型配置(从 get_available_models 获取的模型配置)
request_type: 请求类型标识
**kwargs: 其他模型特定参数如temperature、max_tokens等
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
Returns:
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
"""
try:
logger.info(f"[LLMAPI] 使用模型生成内容,提示词: {prompt[:100]}...")
response, (reasoning, model_name) = await llm_request.generate_response_async(prompt)
return True, response, reasoning, model_name
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"{self.log_prefix} {error_msg}")
return False, error_msg, "", ""
response, (reasoning, model_name) = await llm_request.generate_response_async(prompt)
return True, response, reasoning, model_name
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", ""

View File

@@ -1,202 +1,329 @@
import traceback
"""
消息API模块
提供消息查询和构建成字符串的功能采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import message_api
messages = message_api.get_messages_by_time_in_chat(chat_id, start_time, end_time)
readable_text = message_api.build_readable_messages(messages)
"""
from typing import List, Dict, Any, Tuple, Optional
import time
from typing import List, Dict, Any
from src.common.logger import get_logger
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
# 以下为类型注解需要
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.focus_chat.info.obs_info import ObsInfo
# 新增导入
from src.chat.focus_chat.heartFC_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending
from maim_message import Seg, UserInfo
from src.config.config import global_config
logger = get_logger("message_api")
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_users,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
build_readable_messages,
build_readable_messages_with_list,
get_person_id_list,
)
class MessageAPI:
"""消息API模块
# =============================================================================
# 消息查询API函数
# =============================================================================
提供了发送消息、获取消息历史等功能
def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
获取指定时间范围内的消息
Args:
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
Returns:
消息列表
"""
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
async def send_message_to_target(
self,
message_type: str,
content: str,
platform: str,
target_id: str,
is_group: bool = True,
display_message: str = "",
typing: bool = False,
) -> bool:
"""直接向指定目标发送消息
Args:
message_type: 消息类型,如"text""image""emoji"
content: 消息内容
platform: 目标平台,如"qq"
target_id: 目标ID群ID或用户ID
is_group: 是否为群聊True为群聊False为私聊
display_message: 显示消息(可选)
def get_messages_by_time_in_chat(
chat_id: str, start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
获取指定聊天中指定时间范围内的消息
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
Returns:
消息列表
"""
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode)
Returns:
bool: 是否发送成功
"""
try:
# 构建目标聊天流ID
if is_group:
# 群聊:从数据库查找对应的聊天流
target_stream = None
for _, stream in get_chat_manager().streams.items():
if (
stream.group_info
and str(stream.group_info.group_id) == str(target_id)
and stream.platform == platform
):
target_stream = stream
break
if not target_stream:
logger.error(f"{getattr(self, 'log_prefix', '')} 未找到群ID为 {target_id} 的聊天流")
return False
else:
# 私聊:从数据库查找对应的聊天流
target_stream = None
for _, stream in get_chat_manager().streams.items():
if (
not stream.group_info
and str(stream.user_info.user_id) == str(target_id)
and stream.platform == platform
):
target_stream = stream
break
def get_messages_by_time_in_chat_inclusive(
chat_id: str, start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
获取指定聊天中指定时间范围内的消息(包含边界)
Args:
chat_id: 聊天ID
start_time: 开始时间戳(包含)
end_time: 结束时间戳(包含)
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
Returns:
消息列表
"""
return get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode)
if not target_stream:
logger.error(f"{getattr(self, 'log_prefix', '')} 未找到用户ID为 {target_id} 的私聊流")
return False
# 创建HeartFCSender实例
heart_fc_sender = HeartFCSender()
def get_messages_by_time_in_chat_for_users(
chat_id: str,
start_time: float,
end_time: float,
person_ids: list,
limit: int = 0,
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
"""
获取指定聊天中指定用户在指定时间范围内的消息
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳
person_ids: 用户ID列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
Returns:
消息列表
"""
return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode)
# 生成消息ID和thinking_id
current_time = time.time()
message_id = f"plugin_msg_{int(current_time * 1000)}"
# 构建机器人用户信息
bot_user_info = UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=platform,
)
def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
随机选择一个聊天,返回该聊天在指定时间范围内的消息
Args:
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
Returns:
消息列表
"""
return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
# 创建消息段
message_segment = Seg(type=message_type, data=content)
# 创建空锚点消息(用于回复)
anchor_message = await create_empty_anchor_message(platform, target_stream.group_info, target_stream)
def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
获取指定用户在所有聊天中指定时间范围内的消息
Args:
start_time: 开始时间戳
end_time: 结束时间戳
person_ids: 用户ID列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
Returns:
消息列表
"""
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
# 构建发送消息对象
bot_message = MessageSending(
message_id=message_id,
chat_stream=target_stream,
bot_user_info=bot_user_info,
sender_info=target_stream.user_info, # 目标用户信息
message_segment=message_segment,
display_message=display_message,
reply=anchor_message,
is_head=True,
is_emoji=(message_type == "emoji"),
thinking_start_time=current_time,
)
# 发送消息
sent_msg = await heart_fc_sender.send_message(
bot_message, has_thinking=False, typing=typing, set_reply=False
)
def get_messages_before_time(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
"""
获取指定时间戳之前的消息
Args:
timestamp: 时间戳
limit: 限制返回的消息数量0为不限制
Returns:
消息列表
"""
return get_raw_msg_before_timestamp(timestamp, limit)
if sent_msg:
logger.info(f"{getattr(self, 'log_prefix', '')} 成功发送消息到 {platform}:{target_id}")
return True
else:
logger.error(f"{getattr(self, 'log_prefix', '')} 发送消息失败")
return False
except Exception as e:
logger.error(f"{getattr(self, 'log_prefix', '')} 向目标发送消息时出错: {e}")
traceback.print_exc()
return False
def get_messages_before_time_in_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
"""
获取指定聊天中指定时间戳之前的消息
Args:
chat_id: 聊天ID
timestamp: 时间戳
limit: 限制返回的消息数量0为不限制
Returns:
消息列表
"""
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
async def send_text_to_group(self, text: str, group_id: str, platform: str = "qq") -> bool:
"""便捷方法:向指定群聊发送文本消息
Args:
text: 要发送的文本内容
group_id: 群聊ID
platform: 平台,默认为"qq"
def get_messages_before_time_for_users(
timestamp: float, person_ids: list, limit: int = 0
) -> List[Dict[str, Any]]:
"""
获取指定用户在指定时间戳之前的消息
Args:
timestamp: 时间戳
person_ids: 用户ID列表
limit: 限制返回的消息数量0为不限制
Returns:
消息列表
"""
return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit)
Returns:
bool: 是否发送成功
"""
return await self.send_message_to_target(
message_type="text", content=text, platform=platform, target_id=group_id, is_group=True
)
async def send_text_to_user(self, text: str, user_id: str, platform: str = "qq") -> bool:
"""便捷方法:向指定用户发送私聊文本消息
def get_recent_messages(
chat_id: str,
hours: float = 24.0,
limit: int = 100,
limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
获取指定聊天中最近一段时间的消息
Args:
chat_id: 聊天ID
hours: 最近多少小时默认24小时
limit: 限制返回的消息数量默认100条
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
Returns:
消息列表
"""
now = time.time()
start_time = now - hours * 3600
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
Args:
text: 要发送的文本内容
user_id: 用户ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
return await self.send_message_to_target(
message_type="text", content=text, platform=platform, target_id=user_id, is_group=False
)
# =============================================================================
# 消息计数API函数
# =============================================================================
def get_chat_type(self) -> str:
"""获取当前聊天类型
def count_new_messages(
chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None
) -> int:
"""
计算指定聊天中从开始时间到结束时间的新消息数量
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳如果为None则使用当前时间
Returns:
新消息数量
"""
return num_new_messages_since(chat_id, start_time, end_time)
Returns:
str: 聊天类型 ("group""private")
"""
services = getattr(self, "_services", {})
chat_stream: ChatStream = services.get("chat_stream")
if chat_stream and hasattr(chat_stream, "group_info"):
return "group" if chat_stream.group_info else "private"
return "unknown"
def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]:
"""获取最近的消息
def count_new_messages_for_users(
chat_id: str, start_time: float, end_time: float, person_ids: list
) -> int:
"""
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳
person_ids: 用户ID列表
Returns:
新消息数量
"""
return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids)
Args:
count: 要获取的消息数量
Returns:
List[Dict]: 消息列表,每个消息包含发送者、内容等信息
"""
messages = []
services = getattr(self, "_services", {})
observations = services.get("observations", [])
# =============================================================================
# 消息格式化API函数
# =============================================================================
if observations and len(observations) > 0:
obs = observations[0]
if hasattr(obs, "get_talking_message"):
obs: ObsInfo
raw_messages = obs.get_talking_message()
# 转换为简化格式
for msg in raw_messages[-count:]:
simple_msg = {
"sender": msg.get("sender", "未知"),
"content": msg.get("content", ""),
"timestamp": msg.get("timestamp", 0),
}
messages.append(simple_msg)
def build_readable_messages_to_str(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
) -> str:
"""
将消息列表构建成可读的字符串
Args:
messages: 消息列表
replace_bot_name: 是否将机器人的名称替换为""
merge_messages: 是否合并连续消息
timestamp_mode: 时间戳显示模式,'relative''absolute'
read_mark: 已读标记时间戳,用于分割已读和未读消息
truncate: 是否截断长消息
show_actions: 是否显示动作记录
Returns:
格式化后的可读字符串
"""
return build_readable_messages(
messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions
)
return messages
async def build_readable_messages_with_details(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
"""
将消息列表构建成可读的字符串,并返回详细信息
Args:
messages: 消息列表
replace_bot_name: 是否将机器人的名称替换为""
merge_messages: 是否合并连续消息
timestamp_mode: 时间戳显示模式,'relative''absolute'
truncate: 是否截断长消息
Returns:
格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
"""
return await build_readable_messages_with_list(
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
)
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
"""
从消息列表中提取不重复的用户ID列表
Args:
messages: 消息列表
Returns:
用户ID列表
"""
return await get_person_id_list(messages)

View File

@@ -0,0 +1,153 @@
"""个人信息API模块
提供个人信息查询功能,用于插件获取用户相关信息
使用方式:
from src.plugin_system.apis import person_api
person_id = person_api.get_person_id("qq", 123456)
value = await person_api.get_person_value(person_id, "nickname")
"""
from typing import Any
from src.common.logger import get_logger
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
logger = get_logger("person_api")
# =============================================================================
# 个人信息API函数
# =============================================================================
def get_person_id(platform: str, user_id: int) -> str:
"""根据平台和用户ID获取person_id
Args:
platform: 平台名称,如 "qq", "telegram"
user_id: 用户ID
Returns:
str: 唯一的person_idMD5哈希值
示例:
person_id = person_api.get_person_id("qq", 123456)
"""
try:
return PersonInfoManager.get_person_id(platform, user_id)
except Exception as e:
logger.error(f"[PersonAPI] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
return ""
async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
"""根据person_id和字段名获取某个值
Args:
person_id: 用户的唯一标识ID
field_name: 要获取的字段名,如 "nickname", "impression"
default: 当字段不存在或获取失败时返回的默认值
Returns:
Any: 字段值或默认值
示例:
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
impression = await person_api.get_person_value(person_id, "impression")
"""
try:
person_info_manager = get_person_info_manager()
value = await person_info_manager.get_value(person_id, field_name)
return value if value is not None else default
except Exception as e:
logger.error(f"[PersonAPI] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
return default
async def get_person_values(person_id: str, field_names: list, default_dict: dict = None) -> dict:
"""批量获取用户信息字段值
Args:
person_id: 用户的唯一标识ID
field_names: 要获取的字段名列表
default_dict: 默认值字典,键为字段名,值为默认值
Returns:
dict: 字段名到值的映射字典
示例:
values = await person_api.get_person_values(
person_id,
["nickname", "impression", "know_times"],
{"nickname": "未知用户", "know_times": 0}
)
"""
try:
person_info_manager = get_person_info_manager()
values = await person_info_manager.get_values(person_id, field_names)
# 如果获取成功,返回结果
if values:
return values
# 如果获取失败,构建默认值字典
result = {}
if default_dict:
for field in field_names:
result[field] = default_dict.get(field, None)
else:
for field in field_names:
result[field] = None
return result
except Exception as e:
logger.error(f"[PersonAPI] 批量获取用户信息失败: person_id={person_id}, fields={field_names}, error={e}")
# 返回默认值字典
result = {}
if default_dict:
for field in field_names:
result[field] = default_dict.get(field, None)
else:
for field in field_names:
result[field] = None
return result
async def is_person_known(platform: str, user_id: int) -> bool:
"""判断是否认识某个用户
Args:
platform: 平台名称
user_id: 用户ID
Returns:
bool: 是否认识该用户
示例:
known = await person_api.is_person_known("qq", 123456)
"""
try:
person_info_manager = get_person_info_manager()
return await person_info_manager.is_person_known(platform, user_id)
except Exception as e:
logger.error(f"[PersonAPI] 检查用户是否已知失败: platform={platform}, user_id={user_id}, error={e}")
return False
def get_person_id_by_name(person_name: str) -> str:
"""根据用户名获取person_id
Args:
person_name: 用户名
Returns:
str: person_id如果未找到返回空字符串
示例:
person_id = person_api.get_person_id_by_name("张三")
"""
try:
person_info_manager = get_person_info_manager()
return person_info_manager.get_person_id_by_person_name(person_name)
except Exception as e:
logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
return ""

View File

@@ -1,234 +0,0 @@
# -*- coding: utf-8 -*-
"""
统一的插件API聚合模块
提供所有插件API功能的统一访问入口
"""
from src.common.logger import get_logger
# 导入所有API模块
from src.plugin_system.apis.message_api import MessageAPI
from src.plugin_system.apis.llm_api import LLMAPI
from src.plugin_system.apis.database_api import DatabaseAPI
from src.plugin_system.apis.config_api import ConfigAPI
from src.plugin_system.apis.utils_api import UtilsAPI
from src.plugin_system.apis.stream_api import StreamAPI
from src.plugin_system.apis.hearflow_api import HearflowAPI
logger = get_logger("plugin_api")
class PluginAPI(MessageAPI, LLMAPI, DatabaseAPI, ConfigAPI, UtilsAPI, StreamAPI, HearflowAPI):
"""
插件API聚合类
集成了所有可供插件使用的API功能提供统一的访问接口。
插件组件可以直接使用此API实例来访问各种功能。
特性:
- 聚合所有API模块的功能
- 支持依赖注入和配置
- 提供统一的错误处理和日志记录
"""
def __init__(
self,
chat_stream=None,
expressor=None,
replyer=None,
observations=None,
log_prefix: str = "[PluginAPI]",
plugin_config: dict = None,
):
"""
初始化插件API
Args:
chat_stream: 聊天流对象
expressor: 表达器对象
replyer: 回复器对象
observations: 观察列表
log_prefix: 日志前缀
plugin_config: 插件配置字典
"""
# 存储依赖对象
self._services = {
"chat_stream": chat_stream,
"expressor": expressor,
"replyer": replyer,
"observations": observations or [],
}
self.log_prefix = log_prefix
# 存储action上下文信息
self._action_context = {}
# 调用所有父类的初始化
super().__init__()
# 存储插件配置
self._plugin_config = plugin_config or {}
def set_chat_stream(self, chat_stream):
"""设置聊天流对象"""
self._services["chat_stream"] = chat_stream
logger.debug(f"{self.log_prefix} 设置聊天流: {getattr(chat_stream, 'stream_id', 'Unknown')}")
def set_expressor(self, expressor):
"""设置表达器对象"""
self._services["expressor"] = expressor
logger.debug(f"{self.log_prefix} 设置表达器")
def set_replyer(self, replyer):
"""设置回复器对象"""
self._services["replyer"] = replyer
logger.debug(f"{self.log_prefix} 设置回复器")
def set_observations(self, observations):
"""设置观察列表"""
self._services["observations"] = observations or []
logger.debug(f"{self.log_prefix} 设置观察列表,数量: {len(observations or [])}")
def get_service(self, service_name: str):
"""获取指定的服务对象"""
return self._services.get(service_name)
def has_service(self, service_name: str) -> bool:
"""检查是否有指定的服务对象"""
return service_name in self._services and self._services[service_name] is not None
def set_action_context(self, thinking_id: str = None, shutting_down: bool = False, **kwargs):
"""设置action上下文信息"""
if thinking_id:
self._action_context["thinking_id"] = thinking_id
self._action_context["shutting_down"] = shutting_down
self._action_context.update(kwargs)
def get_action_context(self, key: str, default=None):
"""获取action上下文信息"""
return self._action_context.get(key, default)
def get_config(self, key: str, default=None):
"""获取插件配置值,支持嵌套键访问
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self._plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self._plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
def has_config(self, key: str) -> bool:
"""检查是否存在指定的配置项
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
Returns:
bool: 是否存在该配置项
"""
if not self._plugin_config:
return False
keys = key.split(".")
current = self._plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return False
return True
def get_all_config(self) -> dict:
"""获取所有插件配置
Returns:
dict: 插件配置字典的副本
"""
return self._plugin_config.copy() if self._plugin_config else {}
# 便捷的工厂函数
def create_plugin_api(
chat_stream=None,
expressor=None,
replyer=None,
observations=None,
log_prefix: str = "[Plugin]",
plugin_config: dict = None,
) -> PluginAPI:
"""
创建插件API实例的便捷函数
Args:
chat_stream: 聊天流对象
expressor: 表达器对象
replyer: 回复器对象
observations: 观察列表
log_prefix: 日志前缀
plugin_config: 插件配置字典
Returns:
PluginAPI: 配置好的插件API实例
"""
return PluginAPI(
chat_stream=chat_stream,
expressor=expressor,
replyer=replyer,
observations=observations,
log_prefix=log_prefix,
plugin_config=plugin_config,
)
def create_command_api(message, log_prefix: str = "[Command]") -> PluginAPI:
"""
为命令创建插件API实例的便捷函数
Args:
message: 消息对象,应该包含 chat_stream 等信息
log_prefix: 日志前缀
Returns:
PluginAPI: 配置好的插件API实例
"""
chat_stream = getattr(message, "chat_stream", None)
api = PluginAPI(chat_stream=chat_stream, log_prefix=log_prefix)
return api
# 导出主要接口
__all__ = [
"PluginAPI",
"create_plugin_api",
"create_command_api",
# 也可以导出各个API类供单独使用
"MessageAPI",
"LLMAPI",
"DatabaseAPI",
"ConfigAPI",
"UtilsAPI",
"StreamAPI",
"HearflowAPI",
]

View File

@@ -0,0 +1,445 @@
"""
发送API模块
专门负责发送各种类型的消息采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import send_api
await send_api.text_to_group("hello", "123456")
await send_api.emoji_to_group(emoji_base64, "123456")
await send_api.custom_message("video", video_data, "123456", True)
"""
import traceback
import time
import difflib
from typing import Optional
from src.common.logger import get_logger
# 导入依赖
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.focus_chat.heartFC_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending, MessageRecv
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
from src.person_info.person_info import get_person_info_manager
from maim_message import Seg, UserInfo
from src.config.config import global_config
logger = get_logger("send_api")
# =============================================================================
# 内部实现函数(不暴露给外部)
# =============================================================================
async def _send_to_target(
message_type: str,
content: str,
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""向指定目标发送消息的内部实现
Args:
message_type: 消息类型,如"text""image""emoji"
content: 消息内容
stream_id: 目标流ID
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息的格式,如"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
try:
logger.info(f"[SendAPI] 发送{message_type}消息到 {stream_id}")
# 查找目标聊天流
target_stream = get_chat_manager().get_stream(stream_id)
if not target_stream:
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}")
return False
# 创建发送器
heart_fc_sender = HeartFCSender()
# 生成消息ID
current_time = time.time()
message_id = f"send_api_{int(current_time * 1000)}"
# 构建机器人用户信息
bot_user_info = UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=target_stream.platform,
)
# 创建消息段
message_segment = Seg(type=message_type, data=content)
# 处理回复消息
anchor_message = None
if reply_to:
anchor_message = await _find_reply_message(target_stream, reply_to)
# 构建发送消息对象
bot_message = MessageSending(
message_id=message_id,
chat_stream=target_stream,
bot_user_info=bot_user_info,
sender_info=target_stream.user_info,
message_segment=message_segment,
display_message=display_message,
reply=anchor_message,
is_head=True,
is_emoji=(message_type == "emoji"),
thinking_start_time=current_time,
)
# 发送消息
sent_msg = await heart_fc_sender.send_message(
bot_message, typing=typing, set_reply=(anchor_message is not None), storage_message=storage_message
)
if sent_msg:
logger.info(f"[SendAPI] 成功发送消息到 {stream_id}")
return True
else:
logger.error("[SendAPI] 发送消息失败")
return False
except Exception as e:
logger.error(f"[SendAPI] 发送消息时出错: {e}")
traceback.print_exc()
return False
async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]:
"""查找要回复的消息
Args:
target_stream: 目标聊天流
reply_to: 回复格式,如"发送者:消息内容""发送者:消息内容"
Returns:
Optional[MessageRecv]: 找到的消息如果没找到则返回None
"""
try:
# 解析reply_to参数
if ":" in reply_to:
parts = reply_to.split(":", 1)
elif "" in reply_to:
parts = reply_to.split("", 1)
else:
logger.warning(f"[SendAPI] reply_to格式不正确: {reply_to}")
return None
if len(parts) != 2:
logger.warning(f"[SendAPI] reply_to格式不正确: {reply_to}")
return None
sender = parts[0].strip()
text = parts[1].strip()
# 获取聊天流的最新20条消息
reverse_talking_message = get_raw_msg_before_timestamp_with_chat(
target_stream.stream_id,
time.time(), # 当前时间之前的消息
20 # 最新的20条消息
)
# 反转列表,使最新的消息在前面
reverse_talking_message = list(reversed(reverse_talking_message))
find_msg = None
for message in reverse_talking_message:
user_id = message["user_id"]
platform = message["chat_info_platform"]
person_id = get_person_info_manager().get_person_id(platform, user_id)
person_name = await get_person_info_manager().get_value(person_id, "person_name")
if person_name == sender:
similarity = difflib.SequenceMatcher(None, text, message["processed_plain_text"]).ratio()
if similarity >= 0.9:
find_msg = message
break
if not find_msg:
logger.info("[SendAPI] 未找到匹配的回复消息")
return None
# 构建MessageRecv对象
user_info = {
"platform": find_msg.get("user_platform", ""),
"user_id": find_msg.get("user_id", ""),
"user_nickname": find_msg.get("user_nickname", ""),
"user_cardname": find_msg.get("user_cardname", ""),
}
group_info = {}
if find_msg.get("chat_info_group_id"):
group_info = {
"platform": find_msg.get("chat_info_group_platform", ""),
"group_id": find_msg.get("chat_info_group_id", ""),
"group_name": find_msg.get("chat_info_group_name", ""),
}
format_info = {"content_format": "", "accept_format": ""}
template_info = {"template_items": {}}
message_info = {
"platform": target_stream.platform,
"message_id": find_msg.get("message_id"),
"time": find_msg.get("time"),
"group_info": group_info,
"user_info": user_info,
"additional_config": find_msg.get("additional_config"),
"format_info": format_info,
"template_info": template_info,
}
message_dict = {
"message_info": message_info,
"raw_message": find_msg.get("processed_plain_text"),
"detailed_plain_text": find_msg.get("processed_plain_text"),
"processed_plain_text": find_msg.get("processed_plain_text"),
}
find_rec_msg = MessageRecv(message_dict)
find_rec_msg.update_chat_stream(target_stream)
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {sender}")
return find_rec_msg
except Exception as e:
logger.error(f"[SendAPI] 查找回复消息时出错: {e}")
traceback.print_exc()
return None
# =============================================================================
# 公共API函数 - 预定义类型的发送函数
# =============================================================================
async def text_to_group(text: str, group_id: str, platform: str = "qq", typing: bool = False, reply_to: str = "", storage_message: bool = True) -> bool:
"""向群聊发送文本消息
Args:
text: 要发送的文本内容
group_id: 群聊ID
platform: 平台,默认为"qq"
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
async def text_to_user(text: str, user_id: str, platform: str = "qq", typing: bool = False, reply_to: str = "", storage_message: bool = True) -> bool:
"""向用户发送私聊文本消息
Args:
text: 要发送的文本内容
user_id: 用户ID
platform: 平台,默认为"qq"
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向群聊发送表情包
Args:
emoji_base64: 表情包的base64编码
group_id: 群聊ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
async def emoji_to_user(emoji_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向用户发送表情包
Args:
emoji_base64: 表情包的base64编码
user_id: 用户ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
async def image_to_group(image_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向群聊发送图片
Args:
image_base64: 图片的base64编码
group_id: 群聊ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message)
async def image_to_user(image_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向用户发送图片
Args:
image_base64: 图片的base64编码
user_id: 用户ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("image", image_base64, stream_id, "", typing=False)
async def command_to_group(command: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向群聊发送命令
Args:
command: 命令
group_id: 群聊ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
async def command_to_user(command: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向用户发送命令
Args:
command: 命令
user_id: 用户ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
# =============================================================================
# 通用发送函数 - 支持任意消息类型
# =============================================================================
async def custom_to_group(
message_type: str,
content: str,
group_id: str,
platform: str = "qq",
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True
) -> bool:
"""向群聊发送自定义类型消息
Args:
message_type: 消息类型,如"text""image""emoji""video""file"
content: 消息内容通常是base64编码或文本
group_id: 群聊ID
platform: 平台,默认为"qq"
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
async def custom_to_user(
message_type: str,
content: str,
user_id: str,
platform: str = "qq",
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True
) -> bool:
"""向用户发送自定义类型消息
Args:
message_type: 消息类型,如"text""image""emoji""video""file"
content: 消息内容通常是base64编码或文本
user_id: 用户ID
platform: 平台,默认为"qq"
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
async def custom_message(
message_type: str,
content: str,
target_id: str,
is_group: bool = True,
platform: str = "qq",
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True
) -> bool:
"""发送自定义消息的通用接口
Args:
message_type: 消息类型,如"text""image""emoji""video""file""audio"
content: 消息内容
target_id: 目标ID群ID或用户ID
is_group: 是否为群聊True为群聊False为私聊
platform: 平台,默认为"qq"
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
示例:
# 发送视频到群聊
await send_api.custom_message("video", video_base64, "123456", True)
# 发送文件到用户
await send_api.custom_message("file", file_base64, "987654", False)
# 发送音频到群聊并回复特定消息
await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
"""
stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)

View File

@@ -1,220 +0,0 @@
from typing import Optional, List, Dict, Any, Tuple
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatManager, ChatStream
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
import asyncio
logger = get_logger("stream_api")
class StreamAPI:
"""聊天流API模块
提供了获取聊天流、通过群ID查找聊天流等功能
"""
def get_chat_stream_by_group_id(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""通过QQ群ID获取聊天流
Args:
group_id: QQ群ID
platform: 平台标识,默认为"qq"
Returns:
Optional[ChatStream]: 找到的聊天流对象如果未找到则返回None
"""
try:
chat_manager = ChatManager()
# 遍历所有已加载的聊天流查找匹配的群ID
for stream_id, stream in chat_manager.streams.items():
if (
stream.group_info
and str(stream.group_info.group_id) == str(group_id)
and stream.platform == platform
):
logger.info(f"{self.log_prefix} 通过群ID {group_id} 找到聊天流: {stream_id}")
return stream
logger.warning(f"{self.log_prefix} 未找到群ID为 {group_id} 的聊天流")
return None
except Exception as e:
logger.error(f"{self.log_prefix} 通过群ID获取聊天流时出错: {e}")
return None
def get_all_group_chat_streams(self, platform: str = "qq") -> List[ChatStream]:
"""获取所有群聊的聊天流
Args:
platform: 平台标识,默认为"qq"
Returns:
List[ChatStream]: 所有群聊的聊天流列表
"""
try:
chat_manager = ChatManager()
group_streams = []
for stream in chat_manager.streams.values():
if stream.group_info and stream.platform == platform:
group_streams.append(stream)
logger.info(f"{self.log_prefix} 找到 {len(group_streams)} 个群聊聊天流")
return group_streams
except Exception as e:
logger.error(f"{self.log_prefix} 获取所有群聊聊天流时出错: {e}")
return []
def get_chat_stream_by_user_id(self, user_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""通过用户ID获取私聊聊天流
Args:
user_id: 用户ID
platform: 平台标识,默认为"qq"
Returns:
Optional[ChatStream]: 找到的私聊聊天流对象如果未找到则返回None
"""
try:
chat_manager = ChatManager()
# 遍历所有已加载的聊天流查找匹配的用户ID私聊
for stream_id, stream in chat_manager.streams.items():
if (
not stream.group_info # 私聊没有群信息
and stream.user_info
and str(stream.user_info.user_id) == str(user_id)
and stream.platform == platform
):
logger.info(f"{self.log_prefix} 通过用户ID {user_id} 找到私聊聊天流: {stream_id}")
return stream
logger.warning(f"{self.log_prefix} 未找到用户ID为 {user_id} 的私聊聊天流")
return None
except Exception as e:
logger.error(f"{self.log_prefix} 通过用户ID获取私聊聊天流时出错: {e}")
return None
def get_chat_streams_info(self) -> List[Dict[str, Any]]:
"""获取所有聊天流的基本信息
Returns:
List[Dict[str, Any]]: 包含聊天流基本信息的字典列表
"""
try:
chat_manager = ChatManager()
streams_info = []
for stream_id, stream in chat_manager.streams.items():
info = {
"stream_id": stream_id,
"platform": stream.platform,
"chat_type": "group" if stream.group_info else "private",
"create_time": stream.create_time,
"last_active_time": stream.last_active_time,
}
if stream.group_info:
info.update({"group_id": stream.group_info.group_id, "group_name": stream.group_info.group_name})
if stream.user_info:
info.update({"user_id": stream.user_info.user_id, "user_nickname": stream.user_info.user_nickname})
streams_info.append(info)
logger.info(f"{self.log_prefix} 获取到 {len(streams_info)} 个聊天流信息")
return streams_info
except Exception as e:
logger.error(f"{self.log_prefix} 获取聊天流信息时出错: {e}")
return []
async def get_chat_stream_by_group_id_async(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""异步通过QQ群ID获取聊天流包括从数据库搜索
Args:
group_id: QQ群ID
platform: 平台标识,默认为"qq"
Returns:
Optional[ChatStream]: 找到的聊天流对象如果未找到则返回None
"""
try:
# 首先尝试从内存中查找
stream = self.get_chat_stream_by_group_id(group_id, platform)
if stream:
return stream
# 如果内存中没有,尝试从数据库加载所有聊天流后再查找
chat_manager = ChatManager()
await chat_manager.load_all_streams()
# 再次尝试从内存中查找
stream = self.get_chat_stream_by_group_id(group_id, platform)
return stream
except Exception as e:
logger.error(f"{self.log_prefix} 异步通过群ID获取聊天流时出错: {e}")
return None
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
"""等待新消息或超时
Args:
timeout: 超时时间默认1200秒
Returns:
Tuple[bool, str]: (是否收到新消息, 空字符串)
"""
try:
# 获取必要的服务对象
observations = self.get_service("observations")
if not observations:
logger.warning(f"{self.log_prefix} 无法获取observations服务无法等待新消息")
return False, ""
# 获取第一个观察对象通常是ChattingObservation
observation = observations[0] if observations else None
if not observation:
logger.warning(f"{self.log_prefix} 无观察对象,无法等待新消息")
return False, ""
# 从action上下文获取thinking_id
thinking_id = self.get_action_context("thinking_id")
if not thinking_id:
logger.warning(f"{self.log_prefix} 无thinking_id无法等待新消息")
return False, ""
logger.info(f"{self.log_prefix} 开始等待新消息... (超时: {timeout}秒)")
wait_start_time = asyncio.get_event_loop().time()
while True:
# 检查关闭标志
shutting_down = self.get_action_context("shutting_down", False)
if shutting_down:
logger.info(f"{self.log_prefix} 等待新消息时检测到关闭信号,中断等待")
return False, ""
# 检查新消息
thinking_id_timestamp = parse_thinking_id_to_timestamp(thinking_id)
if await observation.has_new_messages_since(thinking_id_timestamp):
logger.info(f"{self.log_prefix} 检测到新消息")
return True, ""
# 检查超时
if asyncio.get_event_loop().time() - wait_start_time > timeout:
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒)")
return False, ""
# 短暂休眠
await asyncio.sleep(0.5)
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
return False, ""
except Exception as e:
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
return False, f"等待新消息失败: {str(e)}"

View File

@@ -1,126 +1,165 @@
"""工具类API模块
提供了各种辅助功能
使用方式:
from src.plugin_system.apis import utils_api
plugin_path = utils_api.get_plugin_path()
data = utils_api.read_json_file("data.json")
timestamp = utils_api.get_timestamp()
"""
import os
import json
import time
import inspect
import datetime
import uuid
from typing import Any, Optional
from src.common.logger import get_logger
logger = get_logger("utils_api")
class UtilsAPI:
"""工具类API模块
# =============================================================================
# 文件操作API函数
# =============================================================================
提供了各种辅助功能
def get_plugin_path(caller_frame=None) -> str:
"""获取调用者插件的路径
Args:
caller_frame: 调用者的栈帧默认为None自动获取
Returns:
str: 插件目录的绝对路径
"""
try:
if caller_frame is None:
caller_frame = inspect.currentframe().f_back
def get_plugin_path(self) -> str:
"""获取当前插件的路径
Returns:
str: 插件目录的绝对路径
"""
import inspect
plugin_module_path = inspect.getfile(self.__class__)
plugin_module_path = inspect.getfile(caller_frame)
plugin_dir = os.path.dirname(plugin_module_path)
return plugin_dir
except Exception as e:
logger.error(f"[UtilsAPI] 获取插件路径失败: {e}")
return ""
def read_json_file(self, file_path: str, default: Any = None) -> Any:
"""读取JSON文件
Args:
file_path: 文件路径,可以是相对于插件目录的路径
default: 如果文件不存在或读取失败时返回的默认值
def read_json_file(file_path: str, default: Any = None) -> Any:
"""读取JSON文件
Returns:
Any: JSON数据或默认值
"""
try:
# 如果是相对路径,则相对于插件目录
if not os.path.isabs(file_path):
file_path = os.path.join(self.get_plugin_path(), file_path)
Args:
file_path: 文件路径,可以是相对于插件目录的路径
default: 如果文件不存在或读取失败时返回的默认值
if not os.path.exists(file_path):
logger.warning(f"{self.log_prefix} 文件不存在: {file_path}")
return default
Returns:
Any: JSON数据或默认值
"""
try:
# 如果是相对路径,则相对于调用者的插件目录
if not os.path.isabs(file_path):
caller_frame = inspect.currentframe().f_back
plugin_dir = get_plugin_path(caller_frame)
file_path = os.path.join(plugin_dir, file_path)
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error(f"{self.log_prefix} 读取JSON文件出错: {e}")
if not os.path.exists(file_path):
logger.warning(f"[UtilsAPI] 文件不存在: {file_path}")
return default
def write_json_file(self, file_path: str, data: Any, indent: int = 2) -> bool:
"""写入JSON文件
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error(f"[UtilsAPI] 读取JSON文件出错: {e}")
return default
Args:
file_path: 文件路径,可以是相对于插件目录的路径
data: 要写入的数据
indent: JSON缩进
Returns:
bool: 是否写入成功
"""
try:
# 如果是相对路径,则相对于插件目录
if not os.path.isabs(file_path):
file_path = os.path.join(self.get_plugin_path(), file_path)
def write_json_file(file_path: str, data: Any, indent: int = 2) -> bool:
"""写入JSON文件
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)
Args:
file_path: 文件路径,可以是相对于插件目录的路径
data: 要写入的数据
indent: JSON缩进
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=indent)
return True
except Exception as e:
logger.error(f"{self.log_prefix} 写入JSON文件出错: {e}")
return False
Returns:
bool: 是否写入成功
"""
try:
# 如果是相对路径,则相对于调用者的插件目录
if not os.path.isabs(file_path):
caller_frame = inspect.currentframe().f_back
plugin_dir = get_plugin_path(caller_frame)
file_path = os.path.join(plugin_dir, file_path)
def get_timestamp(self) -> int:
"""获取当前时间戳
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)
Returns:
int: 当前时间戳(秒)
"""
return int(time.time())
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=indent)
return True
except Exception as e:
logger.error(f"[UtilsAPI] 写入JSON文件出错: {e}")
return False
def format_time(self, timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
"""格式化时间
Args:
timestamp: 时间戳如果为None则使用当前时间
format_str: 时间格式字符串
# =============================================================================
# 时间相关API函数
# =============================================================================
Returns:
str: 格式化后的时间字符串
"""
import datetime
def get_timestamp() -> int:
"""获取当前时间戳
Returns:
int: 当前时间戳(秒)
"""
return int(time.time())
def format_time(timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
"""格式化时间
Args:
timestamp: 时间戳如果为None则使用当前时间
format_str: 时间格式字符串
Returns:
str: 格式化后的时间字符串
"""
try:
if timestamp is None:
timestamp = time.time()
return datetime.datetime.fromtimestamp(timestamp).strftime(format_str)
except Exception as e:
logger.error(f"[UtilsAPI] 格式化时间失败: {e}")
return ""
def parse_time(self, time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int:
"""解析时间字符串为时间戳
Args:
time_str: 时间字符串
format_str: 时间格式字符串
def parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int:
"""解析时间字符串为时间戳
Returns:
int: 时间戳(秒)
"""
import datetime
Args:
time_str: 时间字符串
format_str: 时间格式字符串
Returns:
int: 时间戳(秒)
"""
try:
dt = datetime.datetime.strptime(time_str, format_str)
return int(dt.timestamp())
except Exception as e:
logger.error(f"[UtilsAPI] 解析时间失败: {e}")
return 0
def generate_unique_id(self) -> str:
"""生成唯一ID
Returns:
str: 唯一ID
"""
import uuid
# =============================================================================
# 其他工具函数
# =============================================================================
return str(uuid.uuid4())
def generate_unique_id() -> str:
"""生成唯一ID
Returns:
str: 唯一ID
"""
return str(uuid.uuid4())

View File

@@ -1,435 +1,469 @@
from abc import ABC, abstractmethod
from typing import Tuple
from src.common.logger import get_logger
from src.plugin_system.apis.plugin_api import PluginAPI
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
logger = get_logger("base_action")
class BaseAction(ABC):
"""Action组件基类
Action是插件的一种组件类型用于处理聊天中的动作逻辑
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
- focus_activation_type: 专注模式激活类型
- normal_activation_type: 普通模式激活类型
- activation_keywords: 激活关键词列表
- keyword_case_sensitive: 关键词是否区分大小写
- mode_enable: 启用的聊天模式
- parallel_action: 是否允许并行执行
- random_activation_probability: 随机激活概率
- llm_judge_prompt: LLM判断提示词
"""
def __init__(
self,
action_data: dict,
reasoning: str,
cycle_timers: dict,
thinking_id: str,
observations: list = None,
expressor=None,
replyer=None,
chat_stream=None,
log_prefix: str = "",
shutting_down: bool = False,
plugin_config: dict = None,
**kwargs,
):
"""初始化Action组件
Args:
action_data: 动作数据
reasoning: 执行该动作的理由
cycle_timers: 计时器字典
thinking_id: 思考ID
observations: 观察列表
expressor: 表达器对象
replyer: 回复器对象
chat_stream: 聊天流对象
log_prefix: 日志前缀
shutting_down: 是否正在关闭
plugin_config: 插件配置字典
**kwargs: 其他参数
"""
self.action_data = action_data
self.reasoning = reasoning
self.cycle_timers = cycle_timers
self.thinking_id = thinking_id
self.log_prefix = log_prefix
self.shutting_down = shutting_down
# 设置动作基本信息实例属性
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
# 设置激活类型实例属性(从类属性复制,提供默认值)
self.focus_activation_type: str = self._get_activation_type_value("focus_activation_type", "never")
self.normal_activation_type: str = self._get_activation_type_value("normal_activation_type", "never")
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
self.mode_enable: str = self._get_mode_value("mode_enable", "all")
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
self.enable_plugin: bool = True # 默认启用
# 创建API实例传递所有服务对象
self.api = PluginAPI(
chat_stream=chat_stream or kwargs.get("chat_stream"),
expressor=expressor or kwargs.get("expressor"),
replyer=replyer or kwargs.get("replyer"),
observations=observations or kwargs.get("observations", []),
log_prefix=log_prefix,
plugin_config=plugin_config or kwargs.get("plugin_config"),
)
# 设置API的action上下文
self.api.set_action_context(thinking_id=thinking_id, shutting_down=shutting_down)
logger.debug(f"{self.log_prefix} Action组件初始化完成")
def _get_activation_type_value(self, attr_name: str, default: str) -> str:
"""获取激活类型的字符串值"""
attr = getattr(self.__class__, attr_name, None)
if attr is None:
return default
if hasattr(attr, "value"):
return attr.value
return str(attr)
def _get_mode_value(self, attr_name: str, default: str) -> str:
"""获取模式的字符串值"""
attr = getattr(self.__class__, attr_name, None)
if attr is None:
return default
if hasattr(attr, "value"):
return attr.value
return str(attr)
async def send_text(self, content: str) -> bool:
"""发送回复消息
Args:
content: 回复内容
Returns:
bool: 是否发送成功
"""
chat_stream = self.api.get_service("chat_stream")
if not chat_stream:
logger.error(f"{self.log_prefix} 没有可用的聊天流发送回复")
return False
if chat_stream.group_info:
# 群聊
return await self.api.send_text_to_group(
text=content, group_id=str(chat_stream.group_info.group_id), platform=chat_stream.platform
)
else:
# 私聊
return await self.api.send_text_to_user(
text=content, user_id=str(chat_stream.user_info.user_id), platform=chat_stream.platform
)
async def send_type(self, type: str, text: str, typing: bool = False) -> bool:
"""发送回复消息
Args:
text: 回复内容
Returns:
bool: 是否发送成功
"""
chat_stream = self.api.get_service("chat_stream")
if not chat_stream:
logger.error(f"{self.log_prefix} 没有可用的聊天流发送回复")
return False
if chat_stream.group_info:
# 群聊
return await self.api.send_message_to_target(
message_type=type,
content=text,
platform=chat_stream.platform,
target_id=str(chat_stream.group_info.group_id),
is_group=True,
typing=typing,
)
else:
# 私聊
return await self.api.send_message_to_target(
message_type=type,
content=text,
platform=chat_stream.platform,
target_id=str(chat_stream.user_info.user_id),
is_group=False,
typing=typing,
)
async def send_command(self, command_name: str, args: dict = None, display_message: str = None) -> bool:
"""发送命令消息
使用和send_text相同的方式通过MessageAPI发送命令
Args:
command_name: 命令名称
args: 命令参数
display_message: 显示消息
Returns:
bool: 是否发送成功
"""
try:
# 构造命令数据
command_data = {"name": command_name, "args": args or {}}
# 使用send_message_to_target方法发送命令
chat_stream = self.api.get_service("chat_stream")
if not chat_stream:
logger.error(f"{self.log_prefix} 没有可用的聊天流发送命令")
return False
if chat_stream.group_info:
# 群聊
success = await self.api.send_message_to_target(
message_type="command",
content=command_data,
platform=chat_stream.platform,
target_id=str(chat_stream.group_info.group_id),
is_group=True,
display_message=display_message or f"执行命令: {command_name}",
)
else:
# 私聊
success = await self.api.send_message_to_target(
message_type="command",
content=command_data,
platform=chat_stream.platform,
target_id=str(chat_stream.user_info.user_id),
is_group=False,
display_message=display_message or f"执行命令: {command_name}",
)
if success:
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
else:
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
return False
async def send_message_by_expressor(self, text: str, target: str = "") -> bool:
"""通过expressor发送文本消息的Action专用方法
Args:
text: 要发送的消息文本
target: 目标消息(可选)
Returns:
bool: 是否发送成功
"""
try:
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
# 获取服务
expressor = self.api.get_service("expressor")
chat_stream = self.api.get_service("chat_stream")
observations = self.api.get_service("observations") or []
if not expressor or not chat_stream:
logger.error(f"{self.log_prefix} 无法通过expressor发送消息缺少必要的服务")
return False
# 构造动作数据
reply_data = {"text": text, "target": target, "emojis": []}
# 查找 ChattingObservation 实例
chatting_observation = None
for obs in observations:
if isinstance(obs, ChattingObservation):
chatting_observation = obs
break
if not chatting_observation:
logger.warning(f"{self.log_prefix} 未找到 ChattingObservation 实例,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message = chatting_observation.search_message_by_text(target)
if not anchor_message:
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message.update_chat_stream(chat_stream)
# 使用Action上下文信息发送消息
success, _ = await expressor.deal_reply(
cycle_timers=self.cycle_timers,
action_data=reply_data,
anchor_message=anchor_message,
reasoning=self.reasoning,
thinking_id=self.thinking_id,
)
if success:
logger.info(f"{self.log_prefix} 成功通过expressor发送消息")
else:
logger.error(f"{self.log_prefix} 通过expressor发送消息失败")
return success
except Exception as e:
logger.error(f"{self.log_prefix} 通过expressor发送消息时出错: {e}")
return False
async def send_message_by_replyer(self, target: str = "", extra_info_block: str = None) -> bool:
"""通过replyer发送消息的Action专用方法
Args:
target: 目标消息(可选)
extra_info_block: 额外信息块(可选)
Returns:
bool: 是否发送成功
"""
try:
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
# 获取服务
replyer = self.api.get_service("replyer")
chat_stream = self.api.get_service("chat_stream")
observations = self.api.get_service("observations") or []
if not replyer or not chat_stream:
logger.error(f"{self.log_prefix} 无法通过replyer发送消息缺少必要的服务")
return False
# 构造动作数据
reply_data = {"target": target, "extra_info_block": extra_info_block}
# 查找 ChattingObservation 实例
chatting_observation = None
for obs in observations:
if isinstance(obs, ChattingObservation):
chatting_observation = obs
break
if not chatting_observation:
logger.warning(f"{self.log_prefix} 未找到 ChattingObservation 实例,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message = chatting_observation.search_message_by_text(target)
if not anchor_message:
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
anchor_message = await create_empty_anchor_message(
chat_stream.platform, chat_stream.group_info, chat_stream
)
else:
anchor_message.update_chat_stream(chat_stream)
# 使用Action上下文信息发送消息
success, _ = await replyer.deal_reply(
cycle_timers=self.cycle_timers,
action_data=reply_data,
anchor_message=anchor_message,
reasoning=self.reasoning,
thinking_id=self.thinking_id,
)
if success:
logger.info(f"{self.log_prefix} 成功通过replyer发送消息")
else:
logger.error(f"{self.log_prefix} 通过replyer发送消息失败")
return success
except Exception as e:
logger.error(f"{self.log_prefix} 通过replyer发送消息时出错: {e}")
return False
@classmethod
def get_action_info(cls) -> "ActionInfo":
"""从类属性生成ActionInfo
所有信息都从类属性中读取,确保一致性和完整性。
Action类必须定义所有必要的类属性。
Returns:
ActionInfo: 生成的Action信息对象
"""
# 从类属性读取名称,如果没有定义则使用类名自动生成
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
# 从类属性读取描述,如果没有定义则使用文档字符串的第一行
description = getattr(cls, "action_description", None)
if description is None:
description = "Action动作"
# 安全获取激活类型值
def get_enum_value(attr_name, default):
attr = getattr(cls, attr_name, None)
if attr is None:
# 如果没有定义,返回默认的枚举值
return getattr(ActionActivationType, default.upper(), ActionActivationType.NEVER)
return attr
def get_mode_value(attr_name, default):
attr = getattr(cls, attr_name, None)
if attr is None:
return getattr(ChatMode, default.upper(), ChatMode.ALL)
return attr
return ActionInfo(
name=name,
component_type=ComponentType.ACTION,
description=description,
focus_activation_type=get_enum_value("focus_activation_type", "never"),
normal_activation_type=get_enum_value("normal_activation_type", "never"),
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
mode_enable=get_mode_value("mode_enable", "all"),
parallel_action=getattr(cls, "parallel_action", True),
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
# 使用正确的字段名
action_parameters=getattr(cls, "action_parameters", {}).copy(),
action_require=getattr(cls, "action_require", []).copy(),
associated_types=getattr(cls, "associated_types", []).copy(),
)
@abstractmethod
async def execute(self) -> Tuple[bool, str]:
"""执行Action的抽象方法子类必须实现
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
pass
async def handle_action(self) -> Tuple[bool, str]:
"""兼容旧系统的handle_action接口委托给execute方法
为了保持向后兼容性旧系统的代码可能会调用handle_action方法。
此方法将调用委托给新的execute方法。
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
return await self.execute()
from abc import ABC, abstractmethod
from typing import Tuple, Optional
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
from src.plugin_system.apis import send_api, database_api,message_api
import time
import asyncio
logger = get_logger("base_action")
class BaseAction(ABC):
"""Action组件基类
Action是插件的一种组件类型用于处理聊天中的动作逻辑
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
- focus_activation_type: 专注模式激活类型
- normal_activation_type: 普通模式激活类型
- activation_keywords: 激活关键词列表
- keyword_case_sensitive: 关键词是否区分大小写
- mode_enable: 启用的聊天模式
- parallel_action: 是否允许并行执行
- random_activation_probability: 随机激活概率
- llm_judge_prompt: LLM判断提示词
"""
def __init__(
self,
action_data: dict,
reasoning: str,
cycle_timers: dict,
thinking_id: str,
chat_stream=None,
log_prefix: str = "",
shutting_down: bool = False,
plugin_config: dict = None,
**kwargs,
):
"""初始化Action组件
Args:
action_data: 动作数据
reasoning: 执行该动作的理由
cycle_timers: 计时器字典
thinking_id: 思考ID
observations: 观察列表
expressor: 表达器对象
replyer: 回复器对象
chat_stream: 聊天流对象
log_prefix: 日志前缀
shutting_down: 是否正在关闭
plugin_config: 插件配置字典
**kwargs: 其他参数
"""
self.action_data = action_data
self.reasoning = reasoning
self.cycle_timers = cycle_timers
self.thinking_id = thinking_id
self.log_prefix = log_prefix
self.shutting_down = shutting_down
# 保存插件配置
self.plugin_config = plugin_config or {}
# 设置动作基本信息实例属性
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
# 设置激活类型实例属性(从类属性复制,提供默认值)
self.focus_activation_type: str = self._get_activation_type_value("focus_activation_type", "always")
self.normal_activation_type: str = self._get_activation_type_value("normal_activation_type", "always")
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
self.mode_enable: str = self._get_mode_value("mode_enable", "all")
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
# =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
# =============================================================================
# 获取聊天流对象
self.chat_stream = chat_stream or kwargs.get("chat_stream")
self.chat_id = self.chat_stream.stream_id
# 初始化基础信息(带类型注解)
self.is_group: bool = False
self.platform: Optional[str] = None
self.group_id: Optional[str] = None
self.user_id: Optional[str] = None
self.target_id: Optional[str] = None
self.group_name: Optional[str] = None
self.user_nickname: Optional[str] = None
# 如果有聊天流,提取所有信息
if self.chat_stream:
self.platform = getattr(self.chat_stream, 'platform', None)
# 获取群聊信息
# print(self.chat_stream)
# print(self.chat_stream.group_info)
if self.chat_stream.group_info:
self.is_group = True
self.group_id = str(self.chat_stream.group_info.group_id)
self.group_name = getattr(self.chat_stream.group_info, 'group_name', None)
else:
self.is_group = False
self.user_id = str(self.chat_stream.user_info.user_id)
self.user_nickname = getattr(self.chat_stream.user_info, 'user_nickname', None)
# 设置目标ID群聊用群ID私聊用户ID
self.target_id = self.group_id if self.is_group else self.user_id
logger.debug(f"{self.log_prefix} Action组件初始化完成")
logger.debug(f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}")
def _get_activation_type_value(self, attr_name: str, default: str) -> str:
"""获取激活类型的字符串值"""
attr = getattr(self.__class__, attr_name, None)
if attr is None:
return default
if hasattr(attr, "value"):
return attr.value
return str(attr)
def _get_mode_value(self, attr_name: str, default: str) -> str:
"""获取模式的字符串值"""
attr = getattr(self.__class__, attr_name, None)
if attr is None:
return default
if hasattr(attr, "value"):
return attr.value
return str(attr)
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
"""等待新消息或超时
在loop_start_time之后等待新消息如果没有新消息且没有超时就一直等待。
使用message_api检查self.chat_id对应的聊天中是否有新消息。
Args:
timeout: 超时时间默认1200秒
Returns:
Tuple[bool, str]: (是否收到新消息, 空字符串)
"""
try:
# 获取循环开始时间,如果没有则使用当前时间
loop_start_time = self.action_data.get("loop_start_time", time.time())
logger.info(f"{self.log_prefix} 开始等待新消息... (最长等待: {timeout}秒, 从时间点: {loop_start_time})")
# 确保有有效的chat_id
if not self.chat_id:
logger.error(f"{self.log_prefix} 等待新消息失败: 没有有效的chat_id")
return False, "没有有效的chat_id"
wait_start_time = asyncio.get_event_loop().time()
while True:
# 检查关闭标志
# shutting_down = self.get_action_context("shutting_down", False)
# if shutting_down:
# logger.info(f"{self.log_prefix} 等待新消息时检测到关闭信号,中断等待")
# return False, ""
# 检查新消息
current_time = time.time()
new_message_count = message_api.count_new_messages(
chat_id=self.chat_id,
start_time=loop_start_time,
end_time=current_time
)
if new_message_count > 0:
logger.info(f"{self.log_prefix} 检测到{new_message_count}条新消息聊天ID: {self.chat_id}")
return True, ""
# 检查超时
elapsed_time = asyncio.get_event_loop().time() - wait_start_time
if elapsed_time > timeout:
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒)聊天ID: {self.chat_id}")
return False, ""
# 每30秒记录一次等待状态
if int(elapsed_time) % 15 == 0 and int(elapsed_time) > 0:
logger.debug(f"{self.log_prefix} 已等待{int(elapsed_time)}秒,继续等待新消息...")
# 短暂休眠
await asyncio.sleep(0.5)
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
return False, ""
except Exception as e:
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
return False, f"等待新消息失败: {str(e)}"
async def send_text(self, content: str, reply_to: str = "") -> bool:
"""发送文本消息
Args:
content: 文本内容
Returns:
bool: 是否发送成功
"""
if not self.target_id or not self.platform:
logger.error(f"{self.log_prefix} 缺少发送消息所需的信息")
return False
if self.is_group:
return await send_api.text_to_group(
text=content, group_id=self.target_id, platform=self.platform, reply_to=reply_to
)
else:
return await send_api.text_to_user(
text=content, user_id=self.target_id, platform=self.platform, reply_to=reply_to
)
async def send_emoji(self, emoji_base64: str) -> bool:
"""发送表情包
Args:
emoji_base64: 表情包的base64编码
Returns:
bool: 是否发送成功
"""
# 导入send_api
from src.plugin_system.apis import send_api
if not self.target_id or not self.platform:
logger.error(f"{self.log_prefix} 缺少发送消息所需的信息")
return False
if self.is_group:
return await send_api.emoji_to_group(emoji_base64, self.target_id, self.platform)
else:
return await send_api.emoji_to_user(emoji_base64, self.target_id, self.platform)
async def send_image(self, image_base64: str) -> bool:
"""发送图片
Args:
image_base64: 图片的base64编码
Returns:
bool: 是否发送成功
"""
# 导入send_api
from src.plugin_system.apis import send_api
if not self.target_id or not self.platform:
logger.error(f"{self.log_prefix} 缺少发送消息所需的信息")
return False
if self.is_group:
return await send_api.image_to_group(image_base64, self.target_id, self.platform)
else:
return await send_api.image_to_user(image_base64, self.target_id, self.platform)
async def send_custom(self, message_type: str, content: str, typing: bool = False) -> bool:
"""发送自定义类型消息
Args:
message_type: 消息类型,如"video""file""audio"
content: 消息内容
typing: 是否显示正在输入
Returns:
bool: 是否发送成功
"""
# 导入send_api
from src.plugin_system.apis import send_api
if not self.target_id or not self.platform:
logger.error(f"{self.log_prefix} 缺少发送消息所需的信息")
return False
return await send_api.custom_message(
message_type=message_type,
content=content,
target_id=self.target_id,
is_group=self.is_group,
platform=self.platform,
typing=typing
)
async def store_action_info(
self,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
) -> None:
"""存储动作信息到数据库
Args:
action_build_into_prompt: 是否构建到提示中
action_prompt_display: 显示的action提示信息
action_done: action是否完成
"""
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=action_build_into_prompt,
action_prompt_display=action_prompt_display,
action_done=action_done,
thinking_id=self.thinking_id,
action_data=self.action_data,
action_name=self.action_name,
)
async def send_command(self, command_name: str, args: dict = None, display_message: str = None, storage_message: bool = True) -> bool:
"""发送命令消息
使用和send_text相同的方式通过MessageAPI发送命令
Args:
command_name: 命令名称
args: 命令参数
display_message: 显示消息
Returns:
bool: 是否发送成功
"""
try:
# 构造命令数据
command_data = {"name": command_name, "args": args or {}}
if self.is_group:
# 群聊
success = await send_api.command_to_group(
command=command_data,
group_id=str(self.group_id),
platform=self.platform,
storage_message=storage_message
)
else:
# 私聊
success = await send_api.command_to_user(
command=command_data,
user_id=str(self.user_id),
platform=self.platform,
storage_message=storage_message
)
if success:
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
else:
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
return False
@classmethod
def get_action_info(cls) -> "ActionInfo":
"""从类属性生成ActionInfo
所有信息都从类属性中读取,确保一致性和完整性。
Action类必须定义所有必要的类属性。
Returns:
ActionInfo: 生成的Action信息对象
"""
# 从类属性读取名称,如果没有定义则使用类名自动生成
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
# 从类属性读取描述,如果没有定义则使用文档字符串的第一行
description = getattr(cls, "action_description", None)
if description is None:
description = "Action动作"
# 安全获取激活类型值
def get_enum_value(attr_name, default):
attr = getattr(cls, attr_name, None)
if attr is None:
# 如果没有定义,返回默认的枚举值
return getattr(ActionActivationType, default.upper(), ActionActivationType.NEVER)
return attr
def get_mode_value(attr_name, default):
attr = getattr(cls, attr_name, None)
if attr is None:
return getattr(ChatMode, default.upper(), ChatMode.ALL)
return attr
return ActionInfo(
name=name,
component_type=ComponentType.ACTION,
description=description,
focus_activation_type=get_enum_value("focus_activation_type", "always"),
normal_activation_type=get_enum_value("normal_activation_type", "always"),
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
mode_enable=get_mode_value("mode_enable", "all"),
parallel_action=getattr(cls, "parallel_action", True),
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
# 使用正确的字段名
action_parameters=getattr(cls, "action_parameters", {}).copy(),
action_require=getattr(cls, "action_require", []).copy(),
associated_types=getattr(cls, "associated_types", []).copy(),
)
@abstractmethod
async def execute(self) -> Tuple[bool, str]:
"""执行Action的抽象方法子类必须实现
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
pass
async def handle_action(self) -> Tuple[bool, str]:
"""兼容旧系统的handle_action接口委托给execute方法
为了保持向后兼容性旧系统的代码可能会调用handle_action方法。
此方法将调用委托给新的execute方法
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
return await self.execute()
def get_action_context(self, key: str, default=None):
"""获取action上下文信息
Args:
key: 上下文键名
default: 默认值
Returns:
Any: 上下文值或默认值
"""
return self.api.get_action_context(key, default)
def get_config(self, key: str, default=None):
"""获取插件配置值,支持嵌套键访问
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current

View File

@@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional, List
from src.common.logger import get_logger
from src.plugin_system.apis.plugin_api import PluginAPI
from src.plugin_system.base.component_types import CommandInfo, ComponentType
from src.chat.message_receive.message import MessageRecv
from src.plugin_system.apis import send_api
logger = get_logger("base_command")
@@ -20,6 +20,9 @@ class BaseCommand(ABC):
- intercept_message: 是否拦截消息处理默认True拦截False继续传递
"""
command_name: str = ""
command_description: str = ""
# 默认命令设置(子类可以覆盖)
command_pattern: str = ""
command_help: str = ""
@@ -35,9 +38,7 @@ class BaseCommand(ABC):
"""
self.message = message
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
# 创建API实例
self.api = PluginAPI(chat_stream=message.chat_stream, log_prefix="[Command]", plugin_config=plugin_config)
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
self.log_prefix = "[Command]"
@@ -60,6 +61,31 @@ class BaseCommand(ABC):
"""
pass
def get_config(self, key: str, default=None):
"""获取插件配置值,支持嵌套键访问
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
async def send_text(self, content: str) -> None:
"""发送回复消息
@@ -71,13 +97,19 @@ class BaseCommand(ABC):
if chat_stream.group_info:
# 群聊
await self.api.send_text_to_group(
text=content, group_id=str(chat_stream.group_info.group_id), platform=chat_stream.platform
await send_api.text_to_group(
text=content,
group_id=str(chat_stream.group_info.group_id),
platform=chat_stream.platform
)
else:
# 私聊
await self.api.send_text_to_user(
text=content, user_id=str(chat_stream.user_info.user_id), platform=chat_stream.platform
await send_api.text_to_user(
text=content,
user_id=str(chat_stream.user_info.user_id),
platform=chat_stream.platform
)
async def send_type(
@@ -98,31 +130,30 @@ class BaseCommand(ABC):
if chat_stream.group_info:
# 群聊
return await self.api.send_message_to_target(
from src.plugin_system.apis import send_api
return await send_api.custom_message(
message_type=message_type,
content=content,
platform=chat_stream.platform,
target_id=str(chat_stream.group_info.group_id),
is_group=True,
display_message=display_message,
platform=chat_stream.platform,
typing=typing,
)
else:
# 私聊
return await self.api.send_message_to_target(
from src.plugin_system.apis import send_api
return await send_api.custom_message(
message_type=message_type,
content=content,
platform=chat_stream.platform,
target_id=str(chat_stream.user_info.user_id),
is_group=False,
display_message=display_message,
platform=chat_stream.platform,
typing=typing,
)
async def send_command(self, command_name: str, args: dict = None, display_message: str = None) -> bool:
"""发送命令消息
使用和send_text相同的方式通过MessageAPI发送命令
Args:
command_name: 命令名称
args: 命令参数
@@ -135,29 +166,28 @@ class BaseCommand(ABC):
# 构造命令数据
command_data = {"name": command_name, "args": args or {}}
# 使用send_message_to_target方法发送命令
# 获取聊天流信息
chat_stream = self.message.chat_stream
command_content = command_data
if chat_stream.group_info:
# 群聊
success = await self.api.send_message_to_target(
from src.plugin_system.apis import send_api
success = await send_api.custom_message(
message_type="command",
content=command_content,
platform=chat_stream.platform,
content=command_data,
target_id=str(chat_stream.group_info.group_id),
is_group=True,
display_message=display_message or f"执行命令: {command_name}",
platform=chat_stream.platform,
)
else:
# 私聊
success = await self.api.send_message_to_target(
from src.plugin_system.apis import send_api
success = await send_api.custom_message(
message_type="command",
content=command_content,
platform=chat_stream.platform,
content=command_data,
target_id=str(chat_stream.user_info.user_id),
is_group=False,
display_message=display_message or f"执行命令: {command_name}",
platform=chat_stream.platform,
)
if success:
@@ -172,7 +202,7 @@ class BaseCommand(ABC):
return False
@classmethod
def get_command_info(cls, name: str = None, description: str = None) -> "CommandInfo":
def get_command_info(cls) -> "CommandInfo":
"""从类属性生成CommandInfo
Args:
@@ -183,19 +213,10 @@ class BaseCommand(ABC):
CommandInfo: 生成的Command信息对象
"""
# 优先使用类属性,然后自动生成
if name is None:
name = getattr(cls, "command_name", cls.__name__.lower().replace("command", ""))
if description is None:
description = getattr(cls, "command_description", None)
if description is None:
description = cls.__doc__ or f"{cls.__name__} Command组件"
description = description.strip().split("\n")[0] # 取第一行作为描述
return CommandInfo(
name=name,
name=cls.command_name,
component_type=ComponentType.COMMAND,
description=description,
description=cls.command_description,
command_pattern=cls.command_pattern,
command_help=cls.command_help,
command_examples=cls.command_examples.copy() if cls.command_examples else [],

View File

@@ -54,23 +54,43 @@ class ComponentRegistry:
"""
component_name = component_info.name
component_type = component_info.component_type
plugin_name = getattr(component_info, 'plugin_name', 'unknown')
if component_name in self._components:
logger.warning(f"组件 {component_name} 已存在,跳过注册")
# 🔥 系统级别自动区分:为不同类型的组件添加命名空间前缀
if component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}"
elif component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}"
else:
# 未来扩展的组件类型
namespaced_name = f"{component_type.value}.{component_name}"
# 检查命名空间化的名称是否冲突
if namespaced_name in self._components:
existing_info = self._components[namespaced_name]
existing_plugin = getattr(existing_info, 'plugin_name', 'unknown')
logger.warning(
f"组件冲突: {component_type.value}组件 '{component_name}' "
f"已被插件 '{existing_plugin}' 注册,跳过插件 '{plugin_name}' 的注册"
)
return False
# 注册到通用注册表
self._components[component_name] = component_info
self._components_by_type[component_type][component_name] = component_info
self._component_classes[component_name] = component_class
# 注册到通用注册表(使用命名空间化的名称)
self._components[namespaced_name] = component_info
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名
self._component_classes[namespaced_name] = component_class
# 根据组件类型进行特定注册
# 根据组件类型进行特定注册(使用原始名称)
if component_type == ComponentType.ACTION:
self._register_action_component(component_info, component_class)
elif component_type == ComponentType.COMMAND:
self._register_command_component(component_info, component_class)
logger.debug(f"已注册{component_type.value}组件: {component_name} ({component_class.__name__})")
logger.debug(
f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' "
f"({component_class.__name__}) [插件: {plugin_name}]"
)
return True
def _register_action_component(self, action_info: ActionInfo, action_class: Type):
@@ -94,13 +114,103 @@ class ComponentRegistry:
# === 组件查询方法 ===
def get_component_info(self, component_name: str) -> Optional[ComponentInfo]:
"""获取组件信息"""
return self._components.get(component_name)
def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]:
"""获取组件信息,支持自动命名空间解析
Args:
component_name: 组件名称,可以是原始名称或命名空间化的名称
component_type: 组件类型,如果提供则优先在该类型中查找
Returns:
Optional[ComponentInfo]: 组件信息或None
"""
# 1. 如果已经是命名空间化的名称,直接查找
if '.' in component_name:
return self._components.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type:
if component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}"
elif component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}"
else:
namespaced_name = f"{component_type.value}.{component_name}"
return self._components.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = []
for namespace_prefix in ["action", "command"]:
namespaced_name = f"{namespace_prefix}.{component_name}"
component_info = self._components.get(namespaced_name)
if component_info:
candidates.append((namespace_prefix, namespaced_name, component_info))
if len(candidates) == 1:
# 只有一个匹配,直接返回
return candidates[0][2]
elif len(candidates) > 1:
# 多个匹配,记录警告并返回第一个
namespaces = [ns for ns, _, _ in candidates]
logger.warning(
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces}"
f"使用第一个匹配项: {candidates[0][1]}"
)
return candidates[0][2]
# 4. 都没找到
return None
def get_component_class(self, component_name: str) -> Optional[Type]:
"""获取组件类"""
return self._component_classes.get(component_name)
def get_component_class(self, component_name: str, component_type: ComponentType = None) -> Optional[Type]:
"""获取组件类,支持自动命名空间解析
Args:
component_name: 组件名称,可以是原始名称或命名空间化的名称
component_type: 组件类型,如果提供则优先在该类型中查找
Returns:
Optional[Type]: 组件类或None
"""
# 1. 如果已经是命名空间化的名称,直接查找
if '.' in component_name:
return self._component_classes.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type:
if component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}"
elif component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}"
else:
namespaced_name = f"{component_type.value}.{component_name}"
return self._component_classes.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = []
for namespace_prefix in ["action", "command"]:
namespaced_name = f"{namespace_prefix}.{component_name}"
component_class = self._component_classes.get(namespaced_name)
if component_class:
candidates.append((namespace_prefix, namespaced_name, component_class))
if len(candidates) == 1:
# 只有一个匹配,直接返回
namespace, full_name, cls = candidates[0]
logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'")
return cls
elif len(candidates) > 1:
# 多个匹配,记录警告并返回第一个
namespaces = [ns for ns, _, _ in candidates]
logger.warning(
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces}"
f"使用第一个匹配项: {candidates[0][1]}"
)
return candidates[0][2]
# 4. 都没找到
return None
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取指定类型的所有组件"""
@@ -123,7 +233,7 @@ class ComponentRegistry:
def get_action_info(self, action_name: str) -> Optional[ActionInfo]:
"""获取Action信息"""
info = self.get_component_info(action_name)
info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None
# === Command特定查询方法 ===
@@ -138,7 +248,7 @@ class ComponentRegistry:
def get_command_info(self, command_name: str) -> Optional[CommandInfo]:
"""获取Command信息"""
info = self.get_component_info(command_name)
info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None
def find_command_by_text(self, text: str) -> Optional[tuple[Type, dict, bool, str]]:
@@ -150,7 +260,9 @@ class ComponentRegistry:
Returns:
Optional[tuple[Type, dict, bool, str]]: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
"""
for pattern, command_class in self._command_patterns.items():
match = pattern.match(text)
if match:
command_name = None
@@ -159,17 +271,18 @@ class ComponentRegistry:
if cls == command_class:
command_name = name
break
# 检查命令是否启用
if command_name:
command_info = self.get_command_info(command_name)
if command_info and command_info.enabled:
return (
command_class,
match.groupdict(),
command_info.intercept_message,
command_info.plugin_name,
)
if command_info:
if command_info.enabled:
return (
command_class,
match.groupdict(),
command_info.intercept_message,
command_info.plugin_name,
)
return None
# === 插件管理方法 ===
@@ -227,26 +340,51 @@ class ComponentRegistry:
# === 状态管理方法 ===
def enable_component(self, component_name: str) -> bool:
"""启用组件"""
if component_name in self._components:
self._components[component_name].enabled = True
def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
"""启用组件,支持命名空间解析"""
# 首先尝试找到正确的命名空间化名称
component_info = self.get_component_info(component_name, component_type)
if not component_info:
return False
# 根据组件类型构造正确的命名空间化名称
if component_info.component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}" if '.' not in component_name else component_name
elif component_info.component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}" if '.' not in component_name else component_name
else:
namespaced_name = f"{component_info.component_type.value}.{component_name}" if '.' not in component_name else component_name
if namespaced_name in self._components:
self._components[namespaced_name].enabled = True
# 如果是Action更新默认动作集
component_info = self._components[component_name]
if isinstance(component_info, ActionInfo):
self._default_actions[component_name] = component_info.description
logger.debug(f"已启用组件: {component_name}")
logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
return True
return False
def disable_component(self, component_name: str) -> bool:
"""禁用组件"""
if component_name in self._components:
self._components[component_name].enabled = False
def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
"""禁用组件,支持命名空间解析"""
# 首先尝试找到正确的命名空间化名称
component_info = self.get_component_info(component_name, component_type)
if not component_info:
return False
# 根据组件类型构造正确的命名空间化名称
if component_info.component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}" if '.' not in component_name else component_name
elif component_info.component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}" if '.' not in component_name else component_name
else:
namespaced_name = f"{component_info.component_type.value}.{component_name}" if '.' not in component_name else component_name
if namespaced_name in self._components:
self._components[namespaced_name].enabled = False
# 如果是Action从默认动作集中移除
if component_name in self._default_actions:
del self._default_actions[component_name]
logger.debug(f"已禁用组件: {component_name}")
logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
return True
return False