初始化
This commit is contained in:
41
src/plugin_system/apis/__init__.py
Normal file
41
src/plugin_system/apis/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
插件系统API模块
|
||||
|
||||
提供了插件开发所需的各种API
|
||||
"""
|
||||
|
||||
# 导入所有API模块
|
||||
from src.plugin_system.apis import (
|
||||
chat_api,
|
||||
component_manage_api,
|
||||
config_api,
|
||||
database_api,
|
||||
emoji_api,
|
||||
generator_api,
|
||||
llm_api,
|
||||
message_api,
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
send_api,
|
||||
tool_api,
|
||||
)
|
||||
from .logging_api import get_logger
|
||||
from .plugin_register_api import register_plugin
|
||||
|
||||
# 导出所有API模块,使它们可以通过 apis.xxx 方式访问
|
||||
__all__ = [
|
||||
"chat_api",
|
||||
"component_manage_api",
|
||||
"config_api",
|
||||
"database_api",
|
||||
"emoji_api",
|
||||
"generator_api",
|
||||
"llm_api",
|
||||
"message_api",
|
||||
"person_api",
|
||||
"plugin_manage_api",
|
||||
"send_api",
|
||||
"get_logger",
|
||||
"register_plugin",
|
||||
"tool_api",
|
||||
]
|
||||
325
src/plugin_system/apis/chat_api.py
Normal file
325
src/plugin_system/apis/chat_api.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
聊天API模块
|
||||
|
||||
专门负责聊天信息的查询和管理,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import chat_api
|
||||
streams = chat_api.get_all_group_streams()
|
||||
chat_type = chat_api.get_stream_type(stream)
|
||||
|
||||
或者:
|
||||
from src.plugin_system.apis.chat_api import ChatManager as chat
|
||||
streams = chat.get_all_group_streams()
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
|
||||
logger = get_logger("chat_api")
|
||||
|
||||
|
||||
class SpecialTypes(Enum):
|
||||
"""特殊枚举类型"""
|
||||
|
||||
ALL_PLATFORMS = "all_platforms"
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
|
||||
|
||||
@staticmethod
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 聊天流列表
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
|
||||
"""
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流失败: {e}")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有群聊聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 群聊聊天流列表
|
||||
"""
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取群聊流失败: {e}")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有私聊聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 私聊聊天流列表
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
|
||||
"""
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取私聊流失败: {e}")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_stream_by_group_id(
|
||||
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
|
||||
"""根据群ID获取聊天流
|
||||
|
||||
Args:
|
||||
group_id: 群聊ID
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 group_id 为空字符串
|
||||
TypeError: 如果 group_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
|
||||
"""
|
||||
if not isinstance(group_id, str):
|
||||
raise TypeError("group_id 必须是字符串类型")
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
if not group_id:
|
||||
raise ValueError("group_id 不能为空")
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (
|
||||
stream.group_info
|
||||
and str(stream.group_info.group_id) == str(group_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流")
|
||||
return stream
|
||||
logger.warning(f"[ChatAPI] 未找到群ID {group_id} 的聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 查找群聊流失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_private_stream_by_user_id(
|
||||
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
|
||||
"""根据用户ID获取私聊流
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 user_id 为空字符串
|
||||
TypeError: 如果 user_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
|
||||
"""
|
||||
if not isinstance(user_id, str):
|
||||
raise TypeError("user_id 必须是字符串类型")
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
if not user_id:
|
||||
raise ValueError("user_id 不能为空")
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (
|
||||
not stream.group_info
|
||||
and str(stream.user_info.user_id) == str(user_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流")
|
||||
return stream
|
||||
logger.warning(f"[ChatAPI] 未找到用户ID {user_id} 的私聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 查找私聊流失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
"""获取聊天流类型
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
|
||||
Returns:
|
||||
str: 聊天类型 ("group", "private", "unknown")
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||
ValueError: 如果 chat_stream 为空
|
||||
"""
|
||||
if not isinstance(chat_stream, ChatStream):
|
||||
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||
if not chat_stream:
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
|
||||
if hasattr(chat_stream, "group_info"):
|
||||
return "group" if chat_stream.group_info else "private"
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
"""获取聊天流详细信息
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
|
||||
Returns:
|
||||
Dict ({str: Any}): 聊天流信息字典
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||
ValueError: 如果 chat_stream 为空
|
||||
"""
|
||||
if not chat_stream:
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
if not isinstance(chat_stream, ChatStream):
|
||||
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||
|
||||
try:
|
||||
info: Dict[str, Any] = {
|
||||
"stream_id": chat_stream.stream_id,
|
||||
"platform": chat_stream.platform,
|
||||
"type": ChatManager.get_stream_type(chat_stream),
|
||||
}
|
||||
|
||||
if chat_stream.group_info:
|
||||
info.update(
|
||||
{
|
||||
"group_id": chat_stream.group_info.group_id,
|
||||
"group_name": getattr(chat_stream.group_info, "group_name", "未知群聊"),
|
||||
}
|
||||
)
|
||||
|
||||
if chat_stream.user_info:
|
||||
info.update(
|
||||
{
|
||||
"user_id": chat_stream.user_info.user_id,
|
||||
"user_name": chat_stream.user_info.user_nickname,
|
||||
}
|
||||
)
|
||||
|
||||
return info
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流信息失败: {e}")
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
"""获取聊天流统计摘要
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: 包含各种统计信息的字典
|
||||
"""
|
||||
try:
|
||||
all_streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS)
|
||||
group_streams = ChatManager.get_group_streams(SpecialTypes.ALL_PLATFORMS)
|
||||
private_streams = ChatManager.get_private_streams(SpecialTypes.ALL_PLATFORMS)
|
||||
|
||||
summary = {
|
||||
"total_streams": len(all_streams),
|
||||
"group_streams": len(group_streams),
|
||||
"private_streams": len(private_streams),
|
||||
"qq_streams": len([s for s in all_streams if s.platform == "qq"]),
|
||||
}
|
||||
|
||||
logger.debug(f"[ChatAPI] 聊天流统计: {summary}")
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}")
|
||||
return {
|
||||
"total_streams": 0,
|
||||
"group_streams": 0,
|
||||
"private_streams": 0,
|
||||
"qq_streams": 0,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 模块级别的便捷函数 - 类似 requests.get(), requests.post() 的设计
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取所有聊天流的便捷函数"""
|
||||
return ChatManager.get_all_streams(platform)
|
||||
|
||||
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取群聊聊天流的便捷函数"""
|
||||
return ChatManager.get_group_streams(platform)
|
||||
|
||||
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取私聊聊天流的便捷函数"""
|
||||
return ChatManager.get_private_streams(platform)
|
||||
|
||||
|
||||
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
|
||||
"""根据群ID获取聊天流的便捷函数"""
|
||||
return ChatManager.get_group_stream_by_group_id(group_id, platform)
|
||||
|
||||
|
||||
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
|
||||
"""根据用户ID获取私聊流的便捷函数"""
|
||||
return ChatManager.get_private_stream_by_user_id(user_id, platform)
|
||||
|
||||
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
"""获取聊天流类型的便捷函数"""
|
||||
return ChatManager.get_stream_type(chat_stream)
|
||||
|
||||
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
"""获取聊天流信息的便捷函数"""
|
||||
return ChatManager.get_stream_info(chat_stream)
|
||||
|
||||
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
"""获取聊天流统计摘要的便捷函数"""
|
||||
return ChatManager.get_streams_summary()
|
||||
268
src/plugin_system/apis/component_manage_api.py
Normal file
268
src/plugin_system/apis/component_manage_api.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from typing import Optional, Union, Dict
|
||||
from src.plugin_system.base.component_types import (
|
||||
CommandInfo,
|
||||
ActionInfo,
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
|
||||
# === 插件信息查询 ===
|
||||
def get_all_plugin_info() -> Dict[str, PluginInfo]:
|
||||
"""
|
||||
获取所有插件的信息。
|
||||
|
||||
Returns:
|
||||
dict: 包含所有插件信息的字典,键为插件名称,值为 PluginInfo 对象。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_all_plugins()
|
||||
|
||||
|
||||
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
|
||||
"""
|
||||
获取指定插件的信息。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件名称。
|
||||
|
||||
Returns:
|
||||
PluginInfo: 插件信息对象,如果插件不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_plugin_info(plugin_name)
|
||||
|
||||
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
component_name: str, component_type: ComponentType
|
||||
) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
"""
|
||||
获取指定组件的信息。
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
Returns:
|
||||
Union[CommandInfo, ActionInfo, EventHandlerInfo]: 组件信息对象,如果组件不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_component_info(component_name, component_type) # type: ignore
|
||||
|
||||
|
||||
def get_components_info_by_type(
|
||||
component_type: ComponentType,
|
||||
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
"""
|
||||
获取指定类型的所有组件信息。
|
||||
|
||||
Args:
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
dict: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_components_by_type(component_type) # type: ignore
|
||||
|
||||
|
||||
def get_enabled_components_info_by_type(
|
||||
component_type: ComponentType,
|
||||
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
"""
|
||||
获取指定类型的所有启用的组件信息。
|
||||
|
||||
Args:
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
dict: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_enabled_components_by_type(component_type) # type: ignore
|
||||
|
||||
|
||||
# === Action 查询方法 ===
|
||||
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
|
||||
"""
|
||||
获取指定 Action 的注册信息。
|
||||
|
||||
Args:
|
||||
action_name (str): Action 名称。
|
||||
|
||||
Returns:
|
||||
ActionInfo: Action 信息对象,如果 Action 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_action_info(action_name)
|
||||
|
||||
|
||||
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
|
||||
"""
|
||||
获取指定 Command 的注册信息。
|
||||
|
||||
Args:
|
||||
command_name (str): Command 名称。
|
||||
|
||||
Returns:
|
||||
CommandInfo: Command 信息对象,如果 Command 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_command_info(command_name)
|
||||
|
||||
|
||||
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
|
||||
"""
|
||||
获取指定 Tool 的注册信息。
|
||||
|
||||
Args:
|
||||
tool_name (str): Tool 名称。
|
||||
|
||||
Returns:
|
||||
ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_tool_info(tool_name)
|
||||
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
def get_registered_event_handler_info(
|
||||
event_handler_name: str,
|
||||
) -> Optional[EventHandlerInfo]:
|
||||
"""
|
||||
获取指定 EventHandler 的注册信息。
|
||||
|
||||
Args:
|
||||
event_handler_name (str): EventHandler 名称。
|
||||
|
||||
Returns:
|
||||
EventHandlerInfo: EventHandler 信息对象,如果 EventHandler 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_event_handler_info(event_handler_name)
|
||||
|
||||
|
||||
# === 组件管理方法 ===
|
||||
def globally_enable_component(component_name: str, component_type: ComponentType) -> bool:
|
||||
"""
|
||||
全局启用指定组件。
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
bool: 启用成功返回 True,否则返回 False。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.enable_component(component_name, component_type)
|
||||
|
||||
|
||||
async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool:
|
||||
"""
|
||||
全局禁用指定组件。
|
||||
|
||||
**此函数是异步的,确保在异步环境中调用。**
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
bool: 禁用成功返回 True,否则返回 False。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return await component_registry.disable_component(component_name, component_type)
|
||||
|
||||
|
||||
def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
|
||||
"""
|
||||
局部启用指定组件。
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
stream_id (str): 消息流 ID。
|
||||
|
||||
Returns:
|
||||
bool: 启用成功返回 True,否则返回 False。
|
||||
"""
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
|
||||
case ComponentType.TOOL:
|
||||
return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
|
||||
case _:
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
|
||||
|
||||
def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
|
||||
"""
|
||||
局部禁用指定组件。
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
stream_id (str): 消息流 ID。
|
||||
|
||||
Returns:
|
||||
bool: 禁用成功返回 True,否则返回 False。
|
||||
"""
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
|
||||
case ComponentType.TOOL:
|
||||
return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
|
||||
case _:
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
|
||||
|
||||
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
|
||||
"""
|
||||
获取指定消息流中禁用的组件列表。
|
||||
|
||||
Args:
|
||||
stream_id (str): 消息流 ID。
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
list[str]: 禁用的组件名称列表。
|
||||
"""
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return global_announcement_manager.get_disabled_chat_actions(stream_id)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.get_disabled_chat_commands(stream_id)
|
||||
case ComponentType.TOOL:
|
||||
return global_announcement_manager.get_disabled_chat_tools(stream_id)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
|
||||
case _:
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
77
src/plugin_system/apis/config_api.py
Normal file
77
src/plugin_system/apis/config_api.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""配置API模块
|
||||
|
||||
提供了配置读取和用户信息获取等功能
|
||||
使用方式:
|
||||
from src.plugin_system.apis import config_api
|
||||
value = config_api.get_global_config("section.key")
|
||||
platform, user_id = await config_api.get_user_id_by_person_name("用户名")
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("config_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 配置访问API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_global_config(key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
安全地从全局配置中获取一个值。
|
||||
插件应使用此方法读取全局配置,以保证只读和隔离性。
|
||||
|
||||
Args:
|
||||
key: 命名空间式配置键名,使用嵌套访问,如 "section.subsection.key",大小写敏感
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = global_config
|
||||
|
||||
try:
|
||||
for k in keys:
|
||||
if hasattr(current, k):
|
||||
current = getattr(current, k)
|
||||
else:
|
||||
raise KeyError(f"配置中不存在子空间或键 '{k}'")
|
||||
return current
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConfigAPI] 获取全局配置 {key} 失败: {e}")
|
||||
return default
|
||||
|
||||
|
||||
def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
从插件配置中获取值,支持嵌套键访问
|
||||
|
||||
Args:
|
||||
plugin_config: 插件配置字典
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key",大小写敏感
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = plugin_config
|
||||
|
||||
try:
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
elif hasattr(current, k):
|
||||
current = getattr(current, k)
|
||||
else:
|
||||
raise KeyError(f"配置中不存在子空间或键 '{k}'")
|
||||
return current
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConfigAPI] 获取插件配置 {key} 失败: {e}")
|
||||
return default
|
||||
29
src/plugin_system/apis/database_api.py
Normal file
29
src/plugin_system/apis/database_api.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""数据库API模块
|
||||
|
||||
提供数据库操作相关功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import database_api
|
||||
records = await database_api.db_query(ActionRecords, query_type="get")
|
||||
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
|
||||
|
||||
注意:此模块现在使用SQLAlchemy实现,提供更好的连接管理和错误处理
|
||||
"""
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import (
|
||||
db_query,
|
||||
db_save,
|
||||
db_get,
|
||||
store_action_info,
|
||||
get_model_class,
|
||||
MODEL_MAPPING
|
||||
)
|
||||
|
||||
# 保持向后兼容性
|
||||
__all__ = [
|
||||
'db_query',
|
||||
'db_save',
|
||||
'db_get',
|
||||
'store_action_info',
|
||||
'get_model_class',
|
||||
'MODEL_MAPPING'
|
||||
]
|
||||
268
src/plugin_system/apis/emoji_api.py
Normal file
268
src/plugin_system/apis/emoji_api.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
表情API模块
|
||||
|
||||
提供表情包相关功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import emoji_api
|
||||
result = await emoji_api.get_by_description("开心")
|
||||
count = emoji_api.get_count()
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
from typing import Optional, Tuple, List
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
|
||||
logger = get_logger("emoji_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包获取API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
|
||||
"""根据描述选择表情包
|
||||
|
||||
Args:
|
||||
description: 表情包的描述文本,例如"开心"、"难过"、"愤怒"等
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果描述为空字符串
|
||||
TypeError: 如果描述不是字符串类型
|
||||
"""
|
||||
if not description:
|
||||
raise ValueError("描述不能为空")
|
||||
if not isinstance(description, str):
|
||||
raise TypeError("描述必须是字符串类型")
|
||||
try:
|
||||
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
emoji_result = await emoji_manager.get_emoji_for_text(description)
|
||||
|
||||
if not emoji_result:
|
||||
logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包")
|
||||
return None
|
||||
|
||||
emoji_path, emoji_description, matched_emotion = emoji_result
|
||||
emoji_base64 = image_path_to_base64(emoji_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}")
|
||||
return None
|
||||
|
||||
logger.debug(f"[EmojiAPI] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
|
||||
return emoji_base64, emoji_description, matched_emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||
"""随机获取指定数量的表情包
|
||||
|
||||
Args:
|
||||
count: 要获取的表情包数量,默认为1
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,失败则返回空列表
|
||||
|
||||
Raises:
|
||||
TypeError: 如果count不是整数类型
|
||||
ValueError: 如果count为负数
|
||||
"""
|
||||
if not isinstance(count, int):
|
||||
raise TypeError("count 必须是整数类型")
|
||||
if count < 0:
|
||||
raise ValueError("count 不能为负数")
|
||||
if count == 0:
|
||||
logger.warning("[EmojiAPI] count 为0,返回空列表")
|
||||
return []
|
||||
|
||||
try:
|
||||
logger.info(f"[EmojiAPI] 随机获取 {count} 个表情包")
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
all_emojis = emoji_manager.emoji_objects
|
||||
|
||||
if not all_emojis:
|
||||
logger.warning("[EmojiAPI] 没有可用的表情包")
|
||||
return []
|
||||
|
||||
# 过滤有效表情包
|
||||
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
|
||||
if not valid_emojis:
|
||||
logger.warning("[EmojiAPI] 没有有效的表情包")
|
||||
return []
|
||||
|
||||
if len(valid_emojis) < count:
|
||||
logger.warning(
|
||||
f"[EmojiAPI] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
|
||||
)
|
||||
count = len(valid_emojis)
|
||||
|
||||
# 随机选择
|
||||
selected_emojis = random.sample(valid_emojis, count)
|
||||
|
||||
results = []
|
||||
for selected_emoji in selected_emojis:
|
||||
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||
continue
|
||||
|
||||
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
|
||||
|
||||
# 记录使用次数
|
||||
emoji_manager.record_usage(selected_emoji.hash)
|
||||
results.append((emoji_base64, selected_emoji.description, matched_emotion))
|
||||
|
||||
if not results and count > 0:
|
||||
logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理")
|
||||
return []
|
||||
|
||||
logger.info(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
"""根据情感标签获取表情包
|
||||
|
||||
Args:
|
||||
emotion: 情感标签,如"happy"、"sad"、"angry"等
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果情感标签为空字符串
|
||||
TypeError: 如果情感标签不是字符串类型
|
||||
"""
|
||||
if not emotion:
|
||||
raise ValueError("情感标签不能为空")
|
||||
if not isinstance(emotion, str):
|
||||
raise TypeError("情感标签必须是字符串类型")
|
||||
try:
|
||||
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
all_emojis = emoji_manager.emoji_objects
|
||||
|
||||
# 筛选匹配情感的表情包
|
||||
matching_emojis = []
|
||||
matching_emojis.extend(
|
||||
emoji_obj
|
||||
for emoji_obj in all_emojis
|
||||
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
|
||||
)
|
||||
if not matching_emojis:
|
||||
logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包")
|
||||
return None
|
||||
|
||||
# 随机选择匹配的表情包
|
||||
selected_emoji = random.choice(matching_emojis)
|
||||
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||
return None
|
||||
|
||||
# 记录使用次数
|
||||
emoji_manager.record_usage(selected_emoji.hash)
|
||||
|
||||
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
|
||||
return emoji_base64, selected_emoji.description, emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 根据情感获取表情包失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包信息查询API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_count() -> int:
|
||||
"""获取表情包数量
|
||||
|
||||
Returns:
|
||||
int: 当前可用的表情包数量
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
return emoji_manager.emoji_num
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def get_info():
|
||||
"""获取表情包系统信息
|
||||
|
||||
Returns:
|
||||
dict: 包含表情包数量、最大数量、可用数量信息
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
return {
|
||||
"current_count": emoji_manager.emoji_num,
|
||||
"max_count": emoji_manager.emoji_num_max,
|
||||
"available_emojis": len([e for e in emoji_manager.emoji_objects if not e.is_deleted]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}")
|
||||
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
||||
|
||||
|
||||
def get_emotions() -> List[str]:
|
||||
"""获取所有可用的情感标签
|
||||
|
||||
Returns:
|
||||
list: 所有表情包的情感标签列表(去重)
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
emotions = set()
|
||||
|
||||
for emoji_obj in emoji_manager.emoji_objects:
|
||||
if not emoji_obj.is_deleted and emoji_obj.emotion:
|
||||
emotions.update(emoji_obj.emotion)
|
||||
|
||||
return sorted(list(emotions))
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取情感标签失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_descriptions() -> List[str]:
|
||||
"""获取所有表情包描述
|
||||
|
||||
Returns:
|
||||
list: 所有可用表情包的描述列表
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
descriptions = []
|
||||
|
||||
descriptions.extend(
|
||||
emoji_obj.description
|
||||
for emoji_obj in emoji_manager.emoji_objects
|
||||
if not emoji_obj.is_deleted and emoji_obj.description
|
||||
)
|
||||
return descriptions
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")
|
||||
return []
|
||||
280
src/plugin_system/apis/generator_api.py
Normal file
280
src/plugin_system/apis/generator_api.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
回复器API模块
|
||||
|
||||
提供回复器相关功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import generator_api
|
||||
replyer = generator_api.get_replyer(chat_stream)
|
||||
success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Tuple, Any, Dict, List, Optional
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.chat.replyer.replyer_manager import replyer_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("generator_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 回复器获取API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_replyer(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer]:
|
||||
"""获取回复器对象
|
||||
|
||||
优先使用chat_stream,如果没有则使用chat_id直接查找。
|
||||
使用 ReplyerManager 来管理实例,避免重复创建。
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
chat_id: 聊天ID(实际上就是stream_id)
|
||||
model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
|
||||
request_type: 请求类型
|
||||
|
||||
Returns:
|
||||
Optional[DefaultReplyer]: 回复器对象,如果获取失败则返回None
|
||||
|
||||
Raises:
|
||||
ValueError: chat_stream 和 chat_id 均为空
|
||||
"""
|
||||
if not chat_id and not chat_stream:
|
||||
raise ValueError("chat_stream 和 chat_id 不可均为空")
|
||||
try:
|
||||
logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
|
||||
return replyer_manager.get_replyer(
|
||||
chat_stream=chat_stream,
|
||||
chat_id=chat_id,
|
||||
model_set_with_weight=model_set_with_weight,
|
||||
request_type=request_type,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 获取回复器时发生意外错误: {e}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 回复生成API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def generate_reply(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
action_data: Optional[Dict[str, Any]] = None,
|
||||
reply_to: str = "",
|
||||
extra_info: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
enable_tool: bool = False,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
return_prompt: bool = False,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
request_type: str = "generator_api",
|
||||
from_plugin: bool = True,
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
chat_id: 聊天ID(备用)
|
||||
action_data: 动作数据(向下兼容,包含reply_to和extra_info)
|
||||
reply_to: 回复对象,格式为 "发送者:消息内容"
|
||||
extra_info: 额外信息,用于补充上下文
|
||||
available_actions: 可用动作
|
||||
enable_tool: 是否启用工具调用
|
||||
enable_splitter: 是否启用消息分割器
|
||||
enable_chinese_typo: 是否启用错字生成器
|
||||
return_prompt: 是否返回提示词
|
||||
model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
|
||||
request_type: 请求类型(可选,记录LLM使用)
|
||||
from_plugin: 是否来自插件
|
||||
Returns:
|
||||
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(
|
||||
chat_stream, chat_id, model_set_with_weight=model_set_with_weight, request_type=request_type
|
||||
)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, [], None
|
||||
|
||||
logger.debug("[GeneratorAPI] 开始生成回复")
|
||||
|
||||
if not reply_to and action_data:
|
||||
reply_to = action_data.get("reply_to", "")
|
||||
if not extra_info and action_data:
|
||||
extra_info = action_data.get("extra_info", "")
|
||||
|
||||
# 调用回复器生成回复
|
||||
success, llm_response_dict, prompt = await replyer.generate_reply_with_context(
|
||||
reply_to=reply_to,
|
||||
extra_info=extra_info,
|
||||
available_actions=available_actions,
|
||||
enable_tool=enable_tool,
|
||||
from_plugin=from_plugin,
|
||||
stream_id=chat_stream.stream_id if chat_stream else chat_id,
|
||||
)
|
||||
if not success:
|
||||
logger.warning("[GeneratorAPI] 回复生成失败")
|
||||
return False, [], None
|
||||
assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
|
||||
if content := llm_response_dict.get("content", ""):
|
||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
else:
|
||||
reply_set = []
|
||||
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
|
||||
|
||||
if return_prompt:
|
||||
return success, reply_set, prompt
|
||||
else:
|
||||
return success, reply_set, None
|
||||
|
||||
except ValueError as ve:
|
||||
raise ve
|
||||
|
||||
except UserWarning as uw:
|
||||
logger.warning(f"[GeneratorAPI] 中断了生成: {uw}")
|
||||
return False, [], None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False, [], None
|
||||
|
||||
|
||||
async def rewrite_reply(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
reply_data: Optional[Dict[str, Any]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
raw_reply: str = "",
|
||||
reason: str = "",
|
||||
reply_to: str = "",
|
||||
return_prompt: bool = False,
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
"""重写回复
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
reply_data: 回复数据字典(向下兼容备用,当其他参数缺失时从此获取)
|
||||
chat_id: 聊天ID(备用)
|
||||
enable_splitter: 是否启用消息分割器
|
||||
enable_chinese_typo: 是否启用错字生成器
|
||||
model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
|
||||
raw_reply: 原始回复内容
|
||||
reason: 回复原因
|
||||
reply_to: 回复对象
|
||||
return_prompt: 是否返回提示词
|
||||
|
||||
Returns:
|
||||
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, [], None
|
||||
|
||||
logger.info("[GeneratorAPI] 开始重写回复")
|
||||
|
||||
# 如果参数缺失,从reply_data中获取
|
||||
if reply_data:
|
||||
raw_reply = raw_reply or reply_data.get("raw_reply", "")
|
||||
reason = reason or reply_data.get("reason", "")
|
||||
reply_to = reply_to or reply_data.get("reply_to", "")
|
||||
|
||||
# 调用回复器重写回复
|
||||
success, content, prompt = await replyer.rewrite_reply_with_context(
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
reply_to=reply_to,
|
||||
return_prompt=return_prompt,
|
||||
)
|
||||
reply_set = []
|
||||
if content:
|
||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
|
||||
if success:
|
||||
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项")
|
||||
else:
|
||||
logger.warning("[GeneratorAPI] 重写回复失败")
|
||||
|
||||
return success, reply_set, prompt if return_prompt else None
|
||||
|
||||
except ValueError as ve:
|
||||
raise ve
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
|
||||
return False, [], None
|
||||
|
||||
|
||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
|
||||
"""将文本处理为更拟人化的文本
|
||||
|
||||
Args:
|
||||
content: 文本内容
|
||||
enable_splitter: 是否启用消息分割器
|
||||
enable_chinese_typo: 是否启用错字生成器
|
||||
"""
|
||||
if not isinstance(content, str):
|
||||
raise ValueError("content 必须是字符串类型")
|
||||
try:
|
||||
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
|
||||
|
||||
reply_set = []
|
||||
for text in processed_response:
|
||||
reply_seg = ("text", text)
|
||||
reply_set.append(reply_seg)
|
||||
|
||||
return reply_set
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def generate_response_custom(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
prompt: str = "",
|
||||
) -> Optional[str]:
|
||||
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.debug("[GeneratorAPI] 开始生成自定义回复")
|
||||
response, _, _, _ = await replyer.llm_generate_content(prompt)
|
||||
if response:
|
||||
logger.debug("[GeneratorAPI] 自定义回复生成成功")
|
||||
return response
|
||||
else:
|
||||
logger.warning("[GeneratorAPI] 自定义回复生成失败")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 生成自定义回复时出错: {e}")
|
||||
return None
|
||||
122
src/plugin_system/apis/llm_api.py
Normal file
122
src/plugin_system/apis/llm_api.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""LLM API模块
|
||||
|
||||
提供了与LLM模型交互的功能
|
||||
使用方式:
|
||||
from src.plugin_system.apis import llm_api
|
||||
models = llm_api.get_available_models()
|
||||
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
|
||||
"""
|
||||
|
||||
from typing import Tuple, Dict, List, Any, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
|
||||
logger = get_logger("llm_api")
|
||||
|
||||
# =============================================================================
|
||||
# LLM模型API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, TaskConfig]:
|
||||
"""获取所有可用的模型配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置
|
||||
"""
|
||||
try:
|
||||
# 自动获取所有属性并转换为字典形式
|
||||
models = model_config.model_task_config
|
||||
attrs = dir(models)
|
||||
rets: Dict[str, TaskConfig] = {}
|
||||
for attr in attrs:
|
||||
if not attr.startswith("__"):
|
||||
try:
|
||||
value = getattr(models, attr)
|
||||
if not callable(value) and isinstance(value, TaskConfig):
|
||||
rets[attr] = value
|
||||
except Exception as e:
|
||||
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
|
||||
continue
|
||||
return rets
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
async def generate_with_model(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str]:
|
||||
"""使用指定模型生成内容
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
request_type: 请求类型标识
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
|
||||
"""
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
|
||||
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
|
||||
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens)
|
||||
return True, response, reasoning_content, model_name
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
|
||||
async def generate_with_model_with_tools(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
tool_options: List[Dict[str, Any]] | None = None,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
|
||||
"""使用指定模型和工具生成内容
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
tool_options: 工具选项列表
|
||||
request_type: 请求类型标识
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
|
||||
"""
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
|
||||
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
|
||||
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
|
||||
prompt,
|
||||
tools=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", "", None
|
||||
3
src/plugin_system/apis/logging_api.py
Normal file
3
src/plugin_system/apis/logging_api.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.common.logger import get_logger
|
||||
|
||||
__all__ = ["get_logger"]
|
||||
483
src/plugin_system/apis/message_api.py
Normal file
483
src/plugin_system/apis/message_api.py
Normal file
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
消息API模块
|
||||
|
||||
提供消息查询和构建成字符串的功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import message_api
|
||||
messages = message_api.get_messages_by_time_in_chat(chat_id, start_time, end_time)
|
||||
readable_text = message_api.build_readable_messages(messages)
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
get_raw_msg_by_timestamp_with_chat_users,
|
||||
get_raw_msg_by_timestamp_random,
|
||||
get_raw_msg_by_timestamp_with_users,
|
||||
get_raw_msg_before_timestamp,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
get_raw_msg_before_timestamp_with_users,
|
||||
num_new_messages_since,
|
||||
num_new_messages_since_with_users,
|
||||
build_readable_messages,
|
||||
build_readable_messages_with_list,
|
||||
get_person_id_list,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息查询API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_messages_by_time(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode))
|
||||
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
filter_command: 是否过滤命令消息,默认为False
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command))
|
||||
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat_inclusive(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间范围内的消息(包含边界)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳(包含)
|
||||
end_time: 结束时间戳(包含)
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
)
|
||||
return get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat_for_users(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
person_ids: List[str],
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定用户在指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
person_ids: 用户ID列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode)
|
||||
|
||||
|
||||
def get_random_chat_messages(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
随机选择一个聊天,返回该聊天在指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode))
|
||||
return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
|
||||
|
||||
|
||||
def get_messages_by_time_for_users(
|
||||
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户在所有聊天中指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
person_ids: 用户ID列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
|
||||
|
||||
|
||||
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定时间戳之前的消息
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(timestamp, (int, float)):
|
||||
raise ValueError("timestamp 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit))
|
||||
return get_raw_msg_before_timestamp(timestamp, limit)
|
||||
|
||||
|
||||
def get_messages_before_time_in_chat(
|
||||
chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间戳之前的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
timestamp: 时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(timestamp, (int, float)):
|
||||
raise ValueError("timestamp 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit))
|
||||
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
|
||||
|
||||
|
||||
def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户在指定时间戳之前的消息
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
person_ids: 用户ID列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(timestamp, (int, float)):
|
||||
raise ValueError("timestamp 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit)
|
||||
|
||||
|
||||
def get_recent_messages(
|
||||
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中最近一段时间的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
hours: 最近多少小时,默认24小时
|
||||
limit: 限制返回的消息数量,默认100条
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法s
|
||||
"""
|
||||
if not isinstance(hours, (int, float)) or hours < 0:
|
||||
raise ValueError("hours 不能是负数")
|
||||
if not isinstance(limit, int) or limit < 0:
|
||||
raise ValueError("limit 必须是非负整数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
now = time.time()
|
||||
start_time = now - hours * 3600
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode))
|
||||
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息计数API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
|
||||
"""
|
||||
计算指定聊天中从开始时间到结束时间的新消息数量
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳,如果为None则使用当前时间
|
||||
|
||||
Returns:
|
||||
int: 新消息数量
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)):
|
||||
raise ValueError("start_time 必须是数字类型")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
return num_new_messages_since(chat_id, start_time, end_time)
|
||||
|
||||
|
||||
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
|
||||
"""
|
||||
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
person_ids: 用户ID列表
|
||||
|
||||
Returns:
|
||||
int: 新消息数量
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息格式化API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def build_readable_messages_to_str(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
read_mark: float = 0.0,
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
将消息列表构建成可读的字符串
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
replace_bot_name: 是否将机器人的名称替换为"你"
|
||||
merge_messages: 是否合并连续消息
|
||||
timestamp_mode: 时间戳显示模式,'relative'或'absolute'
|
||||
read_mark: 已读标记时间戳,用于分割已读和未读消息
|
||||
truncate: 是否截断长消息
|
||||
show_actions: 是否显示动作记录
|
||||
|
||||
Returns:
|
||||
格式化后的可读字符串
|
||||
"""
|
||||
return build_readable_messages(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions
|
||||
)
|
||||
|
||||
|
||||
async def build_readable_messages_with_details(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
"""
|
||||
将消息列表构建成可读的字符串,并返回详细信息
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
replace_bot_name: 是否将机器人的名称替换为"你"
|
||||
merge_messages: 是否合并连续消息
|
||||
timestamp_mode: 时间戳显示模式,'relative'或'absolute'
|
||||
truncate: 是否截断长消息
|
||||
|
||||
Returns:
|
||||
格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
|
||||
"""
|
||||
return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate)
|
||||
|
||||
|
||||
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的用户ID列表
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
用户ID列表
|
||||
"""
|
||||
return await get_person_id_list(messages)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息过滤函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从消息列表中移除麦麦的消息
|
||||
Args:
|
||||
messages: 消息列表,每个元素是消息字典
|
||||
Returns:
|
||||
过滤后的消息列表
|
||||
"""
|
||||
return [msg for msg in messages if msg.get("user_id") != str(global_config.bot.qq_account)]
|
||||
154
src/plugin_system/apis/person_api.py
Normal file
154
src/plugin_system/apis/person_api.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""个人信息API模块
|
||||
|
||||
提供个人信息查询功能,用于插件获取用户相关信息
|
||||
使用方式:
|
||||
from src.plugin_system.apis import person_api
|
||||
person_id = person_api.get_person_id("qq", 123456)
|
||||
value = await person_api.get_person_value(person_id, "nickname")
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
|
||||
|
||||
logger = get_logger("person_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 个人信息API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_person_id(platform: str, user_id: int) -> str:
|
||||
"""根据平台和用户ID获取person_id
|
||||
|
||||
Args:
|
||||
platform: 平台名称,如 "qq", "telegram" 等
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
str: 唯一的person_id(MD5哈希值)
|
||||
|
||||
示例:
|
||||
person_id = person_api.get_person_id("qq", 123456)
|
||||
"""
|
||||
try:
|
||||
return PersonInfoManager.get_person_id(platform, user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
|
||||
return ""
|
||||
|
||||
|
||||
async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
|
||||
"""根据person_id和字段名获取某个值
|
||||
|
||||
Args:
|
||||
person_id: 用户的唯一标识ID
|
||||
field_name: 要获取的字段名,如 "nickname", "impression" 等
|
||||
default: 当字段不存在或获取失败时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 字段值或默认值
|
||||
|
||||
示例:
|
||||
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
|
||||
impression = await person_api.get_person_value(person_id, "impression")
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
value = await person_info_manager.get_value(person_id, field_name)
|
||||
return value if value is not None else default
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
|
||||
return default
|
||||
|
||||
|
||||
async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict:
|
||||
"""批量获取用户信息字段值
|
||||
|
||||
Args:
|
||||
person_id: 用户的唯一标识ID
|
||||
field_names: 要获取的字段名列表
|
||||
default_dict: 默认值字典,键为字段名,值为默认值
|
||||
|
||||
Returns:
|
||||
dict: 字段名到值的映射字典
|
||||
|
||||
示例:
|
||||
values = await person_api.get_person_values(
|
||||
person_id,
|
||||
["nickname", "impression", "know_times"],
|
||||
{"nickname": "未知用户", "know_times": 0}
|
||||
)
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
values = await person_info_manager.get_values(person_id, field_names)
|
||||
|
||||
# 如果获取成功,返回结果
|
||||
if values:
|
||||
return values
|
||||
|
||||
# 如果获取失败,构建默认值字典
|
||||
result = {}
|
||||
if default_dict:
|
||||
for field in field_names:
|
||||
result[field] = default_dict.get(field, None)
|
||||
else:
|
||||
for field in field_names:
|
||||
result[field] = None
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 批量获取用户信息失败: person_id={person_id}, fields={field_names}, error={e}")
|
||||
# 返回默认值字典
|
||||
result = {}
|
||||
if default_dict:
|
||||
for field in field_names:
|
||||
result[field] = default_dict.get(field, None)
|
||||
else:
|
||||
for field in field_names:
|
||||
result[field] = None
|
||||
return result
|
||||
|
||||
|
||||
async def is_person_known(platform: str, user_id: int) -> bool:
|
||||
"""判断是否认识某个用户
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
bool: 是否认识该用户
|
||||
|
||||
示例:
|
||||
known = await person_api.is_person_known("qq", 123456)
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
return await person_info_manager.is_person_known(platform, user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 检查用户是否已知失败: platform={platform}, user_id={user_id}, error={e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_person_id_by_name(person_name: str) -> str:
|
||||
"""根据用户名获取person_id
|
||||
|
||||
Args:
|
||||
person_name: 用户名
|
||||
|
||||
Returns:
|
||||
str: person_id,如果未找到返回空字符串
|
||||
|
||||
示例:
|
||||
person_id = person_api.get_person_id_by_name("张三")
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
return person_info_manager.get_person_id_by_person_name(person_name)
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
|
||||
return ""
|
||||
120
src/plugin_system/apis/plugin_manage_api.py
Normal file
120
src/plugin_system/apis/plugin_manage_api.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from typing import Tuple, List
|
||||
|
||||
|
||||
def list_loaded_plugins() -> List[str]:
|
||||
"""
|
||||
列出所有当前加载的插件。
|
||||
|
||||
Returns:
|
||||
List[str]: 当前加载的插件名称列表。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.list_loaded_plugins()
|
||||
|
||||
|
||||
def list_registered_plugins() -> List[str]:
|
||||
"""
|
||||
列出所有已注册的插件。
|
||||
|
||||
Returns:
|
||||
List[str]: 已注册的插件名称列表。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.list_registered_plugins()
|
||||
|
||||
|
||||
def get_plugin_path(plugin_name: str) -> str:
|
||||
"""
|
||||
获取指定插件的路径。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件名称。
|
||||
|
||||
Returns:
|
||||
str: 插件目录的绝对路径。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果插件不存在。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
if plugin_path := plugin_manager.get_plugin_path(plugin_name):
|
||||
return plugin_path
|
||||
else:
|
||||
raise ValueError(f"插件 '{plugin_name}' 不存在。")
|
||||
|
||||
|
||||
async def remove_plugin(plugin_name: str) -> bool:
|
||||
"""
|
||||
卸载指定的插件。
|
||||
|
||||
**此函数是异步的,确保在异步环境中调用。**
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要卸载的插件名称。
|
||||
|
||||
Returns:
|
||||
bool: 卸载是否成功。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return await plugin_manager.remove_registered_plugin(plugin_name)
|
||||
|
||||
|
||||
async def reload_plugin(plugin_name: str) -> bool:
|
||||
"""
|
||||
重新加载指定的插件。
|
||||
|
||||
**此函数是异步的,确保在异步环境中调用。**
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要重新加载的插件名称。
|
||||
|
||||
Returns:
|
||||
bool: 重新加载是否成功。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return await plugin_manager.reload_registered_plugin(plugin_name)
|
||||
|
||||
|
||||
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
|
||||
"""
|
||||
加载指定的插件。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要加载的插件名称。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, int]: 加载是否成功,成功或失败个数。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.load_registered_plugin_classes(plugin_name)
|
||||
|
||||
|
||||
def add_plugin_directory(plugin_directory: str) -> bool:
|
||||
"""
|
||||
添加插件目录。
|
||||
|
||||
Args:
|
||||
plugin_directory (str): 要添加的插件目录路径。
|
||||
Returns:
|
||||
bool: 添加是否成功。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.add_plugin_directory(plugin_directory)
|
||||
|
||||
|
||||
def rescan_plugin_directory() -> Tuple[int, int]:
|
||||
"""
|
||||
重新扫描插件目录,加载新插件。
|
||||
Returns:
|
||||
Tuple[int, int]: 成功加载的插件数量和失败的插件数量。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.rescan_plugin_directory()
|
||||
46
src/plugin_system/apis/plugin_register_api.py
Normal file
46
src/plugin_system/apis/plugin_register_api.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plugin_manager") # 复用plugin_manager名称
|
||||
|
||||
|
||||
def register_plugin(cls):
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
|
||||
"""插件注册装饰器
|
||||
|
||||
用法:
|
||||
@register_plugin
|
||||
class MyPlugin(BasePlugin):
|
||||
plugin_name = "my_plugin"
|
||||
plugin_description = "我的插件"
|
||||
...
|
||||
"""
|
||||
if not issubclass(cls, BasePlugin):
|
||||
logger.error(f"类 {cls.__name__} 不是 BasePlugin 的子类")
|
||||
return cls
|
||||
|
||||
# 只是注册插件类,不立即实例化
|
||||
# 插件管理器会负责实例化和注册
|
||||
plugin_name: str = cls.plugin_name # type: ignore
|
||||
if "." in plugin_name:
|
||||
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
splitted_name = cls.__module__.split(".")
|
||||
root_path = Path(__file__)
|
||||
|
||||
# 查找项目根目录
|
||||
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
|
||||
root_path = root_path.parent
|
||||
|
||||
if not (root_path / "pyproject.toml").exists():
|
||||
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
|
||||
return cls
|
||||
|
||||
plugin_manager.plugin_classes[plugin_name] = cls
|
||||
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
|
||||
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
|
||||
|
||||
return cls
|
||||
369
src/plugin_system/apis/send_api.py
Normal file
369
src/plugin_system/apis/send_api.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""
|
||||
发送API模块
|
||||
|
||||
专门负责发送各种类型的消息,采用标准Python包设计模式
|
||||
|
||||
使用方式:
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
# 方式1:直接使用stream_id(推荐)
|
||||
await send_api.text_to_stream("hello", stream_id)
|
||||
await send_api.emoji_to_stream(emoji_base64, stream_id)
|
||||
await send_api.custom_to_stream("video", video_data, stream_id)
|
||||
|
||||
# 方式2:使用群聊/私聊指定函数
|
||||
await send_api.text_to_group("hello", "123456")
|
||||
await send_api.text_to_user("hello", "987654")
|
||||
|
||||
# 方式3:使用通用custom_message函数
|
||||
await send_api.custom_message("video", video_data, "123456", True)
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
import difflib
|
||||
from typing import Optional, Union
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入依赖
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, replace_user_references_async
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from maim_message import Seg, UserInfo
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("send_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 内部实现函数(不暴露给外部)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def _send_to_target(
|
||||
message_type: str,
|
||||
content: Union[str, dict],
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
reply_to_platform_id: Optional[str] = None,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
"""向指定目标发送消息的内部实现
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"等
|
||||
content: 消息内容
|
||||
stream_id: 目标流ID
|
||||
display_message: 显示消息
|
||||
typing: 是否模拟打字等待。
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!)
|
||||
storage_message: 是否存储消息到数据库
|
||||
show_log: 发送是否显示日志
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
if show_log:
|
||||
logger.debug(f"[SendAPI] 发送{message_type}消息到 {stream_id}")
|
||||
|
||||
# 查找目标聊天流
|
||||
target_stream = get_chat_manager().get_stream(stream_id)
|
||||
if not target_stream:
|
||||
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}")
|
||||
return False
|
||||
|
||||
# 创建发送器
|
||||
heart_fc_sender = HeartFCSender()
|
||||
|
||||
# 生成消息ID
|
||||
current_time = time.time()
|
||||
message_id = f"send_api_{int(current_time * 1000)}"
|
||||
|
||||
# 构建机器人用户信息
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=target_stream.platform,
|
||||
)
|
||||
|
||||
# 创建消息段
|
||||
message_segment = Seg(type=message_type, data=content) # type: ignore
|
||||
|
||||
# 处理回复消息
|
||||
anchor_message = None
|
||||
if reply_to:
|
||||
anchor_message = await _find_reply_message(target_stream, reply_to)
|
||||
if anchor_message and anchor_message.message_info.user_info and not reply_to_platform_id:
|
||||
reply_to_platform_id = (
|
||||
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
|
||||
)
|
||||
|
||||
# 构建发送消息对象
|
||||
bot_message = MessageSending(
|
||||
message_id=message_id,
|
||||
chat_stream=target_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=target_stream.user_info,
|
||||
message_segment=message_segment,
|
||||
display_message=display_message,
|
||||
reply=anchor_message,
|
||||
is_head=True,
|
||||
is_emoji=(message_type == "emoji"),
|
||||
thinking_start_time=current_time,
|
||||
reply_to=reply_to_platform_id,
|
||||
)
|
||||
|
||||
# 发送消息
|
||||
sent_msg = await heart_fc_sender.send_message(
|
||||
bot_message,
|
||||
typing=typing,
|
||||
set_reply=(anchor_message is not None),
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
if sent_msg:
|
||||
logger.debug(f"[SendAPI] 成功发送消息到 {stream_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error("[SendAPI] 发送消息失败")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SendAPI] 发送消息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]:
|
||||
# sourcery skip: inline-variable, use-named-expression
|
||||
"""查找要回复的消息
|
||||
|
||||
Args:
|
||||
target_stream: 目标聊天流
|
||||
reply_to: 回复格式,如"发送者:消息内容"或"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
Optional[MessageRecv]: 找到的消息,如果没找到则返回None
|
||||
"""
|
||||
try:
|
||||
# 解析reply_to参数
|
||||
if ":" in reply_to:
|
||||
parts = reply_to.split(":", 1)
|
||||
elif ":" in reply_to:
|
||||
parts = reply_to.split(":", 1)
|
||||
else:
|
||||
logger.warning(f"[SendAPI] reply_to格式不正确: {reply_to}")
|
||||
return None
|
||||
|
||||
if len(parts) != 2:
|
||||
logger.warning(f"[SendAPI] reply_to格式不正确: {reply_to}")
|
||||
return None
|
||||
|
||||
sender = parts[0].strip()
|
||||
text = parts[1].strip()
|
||||
|
||||
# 获取聊天流的最新20条消息
|
||||
reverse_talking_message = get_raw_msg_before_timestamp_with_chat(
|
||||
target_stream.stream_id,
|
||||
time.time(), # 当前时间之前的消息
|
||||
20, # 最新的20条消息
|
||||
)
|
||||
|
||||
# 反转列表,使最新的消息在前面
|
||||
reverse_talking_message = list(reversed(reverse_talking_message))
|
||||
|
||||
find_msg = None
|
||||
for message in reverse_talking_message:
|
||||
user_id = message["user_id"]
|
||||
platform = message["chat_info_platform"]
|
||||
person_id = get_person_info_manager().get_person_id(platform, user_id)
|
||||
person_name = await get_person_info_manager().get_value(person_id, "person_name")
|
||||
if person_name == sender:
|
||||
translate_text = message["processed_plain_text"]
|
||||
|
||||
# 使用独立函数处理用户引用格式
|
||||
translate_text = await replace_user_references_async(translate_text, platform)
|
||||
|
||||
similarity = difflib.SequenceMatcher(None, text, translate_text).ratio()
|
||||
if similarity >= 0.9:
|
||||
find_msg = message
|
||||
break
|
||||
|
||||
if not find_msg:
|
||||
logger.info("[SendAPI] 未找到匹配的回复消息")
|
||||
return None
|
||||
|
||||
# 构建MessageRecv对象
|
||||
user_info = {
|
||||
"platform": find_msg.get("user_platform", ""),
|
||||
"user_id": find_msg.get("user_id", ""),
|
||||
"user_nickname": find_msg.get("user_nickname", ""),
|
||||
"user_cardname": find_msg.get("user_cardname", ""),
|
||||
}
|
||||
|
||||
group_info = {}
|
||||
if find_msg.get("chat_info_group_id"):
|
||||
group_info = {
|
||||
"platform": find_msg.get("chat_info_group_platform", ""),
|
||||
"group_id": find_msg.get("chat_info_group_id", ""),
|
||||
"group_name": find_msg.get("chat_info_group_name", ""),
|
||||
}
|
||||
|
||||
format_info = {"content_format": "", "accept_format": ""}
|
||||
template_info = {"template_items": {}}
|
||||
|
||||
message_info = {
|
||||
"platform": target_stream.platform,
|
||||
"message_id": find_msg.get("message_id"),
|
||||
"time": find_msg.get("time"),
|
||||
"group_info": group_info,
|
||||
"user_info": user_info,
|
||||
"additional_config": find_msg.get("additional_config"),
|
||||
"format_info": format_info,
|
||||
"template_info": template_info,
|
||||
}
|
||||
|
||||
message_dict = {
|
||||
"message_info": message_info,
|
||||
"raw_message": find_msg.get("processed_plain_text"),
|
||||
"processed_plain_text": find_msg.get("processed_plain_text"),
|
||||
}
|
||||
|
||||
find_rec_msg = MessageRecv(message_dict)
|
||||
find_rec_msg.update_chat_stream(target_stream)
|
||||
|
||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {sender}")
|
||||
return find_rec_msg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SendAPI] 查找回复消息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 公共API函数 - 预定义类型的发送函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def text_to_stream(
|
||||
text: str,
|
||||
stream_id: str,
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
reply_to_platform_id: str = "",
|
||||
storage_message: bool = True,
|
||||
) -> bool:
|
||||
"""向指定流发送文本消息
|
||||
|
||||
Args:
|
||||
text: 要发送的文本内容
|
||||
stream_id: 聊天流ID
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!)
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(
|
||||
"text",
|
||||
text,
|
||||
stream_id,
|
||||
"",
|
||||
typing,
|
||||
reply_to,
|
||||
reply_to_platform_id=reply_to_platform_id,
|
||||
storage_message=storage_message,
|
||||
)
|
||||
|
||||
|
||||
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool:
|
||||
"""向指定流发送表情包
|
||||
|
||||
Args:
|
||||
emoji_base64: 表情包的base64编码
|
||||
stream_id: 聊天流ID
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True) -> bool:
|
||||
"""向指定流发送图片
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
stream_id: 聊天流ID
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
async def command_to_stream(
|
||||
command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = ""
|
||||
) -> bool:
|
||||
"""向指定流发送命令
|
||||
|
||||
Args:
|
||||
command: 命令
|
||||
stream_id: 聊天流ID
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(
|
||||
"command", command, stream_id, display_message, typing=False, storage_message=storage_message
|
||||
)
|
||||
|
||||
|
||||
async def custom_to_stream(
|
||||
message_type: str,
|
||||
content: str | dict,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
"""向指定流发送自定义类型消息
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
|
||||
content: 消息内容(通常是base64编码或文本)
|
||||
stream_id: 聊天流ID
|
||||
display_message: 显示消息
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
storage_message: 是否存储消息到数据库
|
||||
show_log: 是否显示日志
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
reply_to=reply_to,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
34
src/plugin_system/apis/tool_api.py
Normal file
34
src/plugin_system/apis/tool_api.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Optional, Type
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_api")
|
||||
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
"""获取公开工具实例"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
# 获取插件配置
|
||||
tool_info = component_registry.get_component_info(tool_name, ComponentType.TOOL)
|
||||
if tool_info:
|
||||
plugin_config = component_registry.get_plugin_config(tool_info.plugin_name)
|
||||
else:
|
||||
plugin_config = None
|
||||
|
||||
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
||||
return tool_class(plugin_config) if tool_class else None
|
||||
|
||||
|
||||
def get_llm_available_tool_definitions():
|
||||
"""获取LLM可用的工具定义列表
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)]
|
||||
"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
llm_available_tools = component_registry.get_llm_available_tools()
|
||||
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
|
||||
Reference in New Issue
Block a user